Skip to content

Commit b387b2a

Browse files
authored
Merge pull request #80 from ngoldbaum/float-cast-fix
use numpy scalar in float to string cast
2 parents a1c811f + 13ca7e8 commit b387b2a

File tree

2 files changed

+39
-33
lines changed

2 files changed

+39
-33
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -826,39 +826,39 @@ string_to_pyfloat(char *in)
826826
return NPY_UNSAFE_CASTING; \
827827
}
828828

829-
#define FLOAT_TO_STRING_CAST(typename, shortname, float_to_double) \
830-
static int typename##_to_string( \
831-
PyArrayMethod_Context *NPY_UNUSED(context), char *const data[], \
832-
npy_intp const dimensions[], npy_intp const strides[], \
833-
NpyAuxData *NPY_UNUSED(auxdata)) \
834-
{ \
835-
npy_intp N = dimensions[0]; \
836-
npy_##typename *in = (npy_##typename *)data[0]; \
837-
char *out = data[1]; \
838-
\
839-
npy_intp in_stride = strides[0] / sizeof(npy_##typename); \
840-
npy_intp out_stride = strides[1]; \
841-
\
842-
while (N--) { \
843-
PyObject *pyfloat_val = \
844-
PyFloat_FromDouble((float_to_double)(*in)); \
845-
if (pyobj_to_string(pyfloat_val, out) == -1) { \
846-
return -1; \
847-
} \
848-
\
849-
in += in_stride; \
850-
out += out_stride; \
851-
} \
852-
\
853-
return 0; \
854-
} \
855-
\
856-
static PyType_Slot shortname##2s_slots [] = { \
857-
{NPY_METH_resolve_descriptors, \
858-
&any_to_string_UNSAFE_resolve_descriptors}, \
859-
{NPY_METH_strided_loop, &typename##_to_string}, \
860-
{0, NULL}}; \
861-
\
829+
#define FLOAT_TO_STRING_CAST(typename, shortname, float_to_double) \
830+
static int typename##_to_string( \
831+
PyArrayMethod_Context *context, char *const data[], \
832+
npy_intp const dimensions[], npy_intp const strides[], \
833+
NpyAuxData *NPY_UNUSED(auxdata)) \
834+
{ \
835+
npy_intp N = dimensions[0]; \
836+
npy_##typename *in = (npy_##typename *)data[0]; \
837+
char *out = data[1]; \
838+
PyArray_Descr *float_descr = context->descriptors[0]; \
839+
\
840+
npy_intp in_stride = strides[0] / sizeof(npy_##typename); \
841+
npy_intp out_stride = strides[1]; \
842+
\
843+
while (N--) { \
844+
PyObject *scalar_val = PyArray_Scalar(in, float_descr, NULL); \
845+
if (pyobj_to_string(scalar_val, out) == -1) { \
846+
return -1; \
847+
} \
848+
\
849+
in += in_stride; \
850+
out += out_stride; \
851+
} \
852+
\
853+
return 0; \
854+
} \
855+
\
856+
static PyType_Slot shortname##2s_slots [] = { \
857+
{NPY_METH_resolve_descriptors, \
858+
&any_to_string_UNSAFE_resolve_descriptors}, \
859+
{NPY_METH_strided_loop, &typename##_to_string}, \
860+
{0, NULL}}; \
861+
\
862862
static char *shortname##2s_name = "cast_" #typename "_to_StringDType";
863863

864864
STRING_TO_FLOAT_RESOLVE_DESCRIPTORS(float64, DOUBLE)

stringdtype/tests/test_stringdtype.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,12 @@ def test_float_casts(dtype, typename):
420420
res = np.array(inp, dtype=typename).astype(dtype).astype(typename)
421421
np.testing.assert_array_equal(eres, res)
422422

423+
inp = [0.1]
424+
sres = np.array(inp, dtype=typename).astype(dtype)
425+
res = sres.astype(typename)
426+
np.testing.assert_array_equal(np.array(inp, dtype=typename), res)
427+
assert sres[0] == "0.1"
428+
423429

424430
def test_take(dtype, string_list):
425431
sarr = np.array(string_list, dtype=dtype)

0 commit comments

Comments
 (0)