@@ -77,8 +77,6 @@ static PyType_Slot s2s_slots[] = {
77
77
78
78
static char * s2s_name = "cast_StringDType_to_StringDType" ;
79
79
80
- static PyArray_DTypeMeta * s2s_dtypes [] = {NULL , NULL };
81
-
82
80
// unicode to string
83
81
84
82
static NPY_CASTING
@@ -374,6 +372,66 @@ static PyType_Slot s2u_slots[] = {
374
372
375
373
static char * s2u_name = "cast_StringDType_to_Unicode" ;
376
374
375
+ // string to bool
376
+
377
+ static NPY_CASTING
378
+ string_to_bool_resolve_descriptors (PyObject * NPY_UNUSED (self ),
379
+ PyArray_DTypeMeta * NPY_UNUSED (dtypes [2 ]),
380
+ PyArray_Descr * given_descrs [2 ],
381
+ PyArray_Descr * loop_descrs [2 ],
382
+ npy_intp * NPY_UNUSED (view_offset ))
383
+ {
384
+ if (given_descrs [1 ] == NULL ) {
385
+ loop_descrs [1 ] = PyArray_DescrNewFromType (NPY_BOOL );
386
+ }
387
+ else {
388
+ Py_INCREF (given_descrs [1 ]);
389
+ loop_descrs [1 ] = given_descrs [1 ];
390
+ }
391
+
392
+ Py_INCREF (given_descrs [0 ]);
393
+ loop_descrs [0 ] = given_descrs [0 ];
394
+
395
+ return NPY_UNSAFE_CASTING ;
396
+ }
397
+
398
+ static int
399
+ string_to_bool (PyArrayMethod_Context * context , char * const data [],
400
+ npy_intp const dimensions [], npy_intp const strides [],
401
+ NpyAuxData * NPY_UNUSED (auxdata ))
402
+ {
403
+ npy_intp N = dimensions [0 ];
404
+ char * in = data [0 ];
405
+ char * out = data [1 ];
406
+
407
+ npy_intp in_stride = strides [0 ];
408
+ npy_intp out_stride = strides [1 ];
409
+
410
+ ss * s = NULL ;
411
+
412
+ while (N -- ) {
413
+ load_string (in , & s );
414
+ if (s -> len == 0 ) {
415
+ * out = (npy_bool )0 ;
416
+ }
417
+ else {
418
+ * out = (npy_bool )1 ;
419
+ }
420
+
421
+ in += in_stride ;
422
+ out += out_stride ;
423
+ }
424
+
425
+ return 0 ;
426
+ }
427
+
428
+ static PyType_Slot s2b_slots [] = {
429
+ {NPY_METH_resolve_descriptors , & string_to_bool_resolve_descriptors },
430
+ {NPY_METH_strided_loop , & string_to_bool },
431
+ {0 , NULL }};
432
+
433
+ static char * s2b_name = "cast_StringDType_to_Bool" ;
434
+
377
435
PyArrayMethod_Spec *
378
436
get_cast_spec (const char * name , NPY_CASTING casting ,
379
437
NPY_ARRAYMETHOD_FLAGS flags , PyArray_DTypeMeta * * dtypes ,
@@ -406,6 +464,8 @@ get_dtypes(PyArray_DTypeMeta *dt1, PyArray_DTypeMeta *dt2)
406
464
PyArrayMethod_Spec * *
407
465
get_casts (void )
408
466
{
467
+ PyArray_DTypeMeta * * s2s_dtypes = get_dtypes (NULL , NULL );
468
+
409
469
PyArrayMethod_Spec * StringToStringCastSpec =
410
470
get_cast_spec (s2s_name , NPY_NO_CASTING ,
411
471
NPY_METH_SUPPORTS_UNALIGNED , s2s_dtypes , s2s_slots );
@@ -422,11 +482,18 @@ get_casts(void)
422
482
s2u_name , NPY_SAFE_CASTING , NPY_METH_NO_FLOATINGPOINT_ERRORS ,
423
483
s2u_dtypes , s2u_slots );
424
484
425
- PyArrayMethod_Spec * * casts = malloc (4 * sizeof (PyArrayMethod_Spec * ));
485
+ PyArray_DTypeMeta * * s2b_dtypes = get_dtypes (NULL , & PyArray_BoolDType );
486
+
487
+ PyArrayMethod_Spec * StringToBoolCastSpec = get_cast_spec (
488
+ s2b_name , NPY_UNSAFE_CASTING , NPY_METH_NO_FLOATINGPOINT_ERRORS ,
489
+ s2b_dtypes , s2b_slots );
490
+
491
+ PyArrayMethod_Spec * * casts = malloc (5 * sizeof (PyArrayMethod_Spec * ));
426
492
casts [0 ] = StringToStringCastSpec ;
427
493
casts [1 ] = UnicodeToStringCastSpec ;
428
494
casts [2 ] = StringToUnicodeCastSpec ;
429
- casts [3 ] = NULL ;
495
+ casts [3 ] = StringToBoolCastSpec ;
496
+ casts [4 ] = NULL ;
430
497
431
498
return casts ;
432
499
}
0 commit comments