Skip to content

Commit 84c4dfd

Browse files
committed
adapt to size
1 parent 77def74 commit 84c4dfd

File tree

4 files changed

+307
-15
lines changed

4 files changed

+307
-15
lines changed

quaddtype/numpy_quaddtype/src/casts.cpp

Lines changed: 270 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ extern "C" {
1313
#include "numpy/ndarraytypes.h"
1414
#include "numpy/dtype_api.h"
1515
}
16+
#include <cstring>
1617
#include "sleef.h"
1718
#include "sleefquad.h"
1819

@@ -21,8 +22,11 @@ extern "C" {
2122
#include "casts.h"
2223
#include "dtype.h"
2324
#include "utilities.h"
25+
#include "lock.h"
26+
#include "dragon4.h"
27+
#include "ops.hpp"
2428

25-
#define NUM_CASTS 35 // 16 to_casts + 16 from_casts + 1 quad_to_quad + 1 void_to_quad + 1 unicode_to_quad
29+
#define NUM_CASTS 36 // 17 to_casts + 17 from_casts + 1 quad_to_quad + 1 void_to_quad
2630
#define QUAD_STR_WIDTH 50 // 42 is enough for scientific notation float128, just keeping some buffer
2731

2832
static NPY_CASTING
@@ -318,6 +322,246 @@ unicode_to_quad_strided_loop_aligned(PyArrayMethod_Context *context, char *const
318322
return 0;
319323
}
320324

325+
// QuadDType to unicode/string
326+
static NPY_CASTING
327+
quad_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMeta *dtypes[2],
328+
PyArray_Descr *given_descrs[2], PyArray_Descr *loop_descrs[2],
329+
npy_intp *view_offset)
330+
{
331+
Py_INCREF(given_descrs[0]);
332+
loop_descrs[0] = given_descrs[0];
333+
334+
if (given_descrs[1] == NULL) {
335+
PyArray_Descr *unicode_descr = PyArray_DescrNewFromType(NPY_UNICODE);
336+
if (unicode_descr == nullptr) {
337+
return (NPY_CASTING)-1;
338+
}
339+
340+
unicode_descr->elsize = QUAD_STR_WIDTH * 4; // bytes
341+
loop_descrs[1] = unicode_descr;
342+
}
343+
else {
344+
Py_INCREF(given_descrs[1]);
345+
loop_descrs[1] = given_descrs[1];
346+
}
347+
348+
*view_offset = 0;
349+
return NPY_UNSAFE_CASTING;
350+
}
351+
352+
static int
353+
quad_to_unicode_loop_unaligned(PyArrayMethod_Context *context, char *const data[],
354+
npy_intp const dimensions[], npy_intp const strides[], void *NPY_UNUSED(auxdata))
355+
{
356+
npy_intp N = dimensions[0];
357+
char *in_ptr = data[0];
358+
char *out_ptr = data[1];
359+
npy_intp in_stride = strides[0];
360+
npy_intp out_stride = strides[1];
361+
362+
PyArray_Descr *const *descrs = context->descriptors;
363+
QuadPrecDTypeObject *descr_in = (QuadPrecDTypeObject *)descrs[0];
364+
QuadBackendType backend = descr_in->backend;
365+
366+
npy_intp unicode_size_chars = descrs[1]->elsize / 4;
367+
size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);
368+
369+
while (N--)
370+
{
371+
quad_value in_val;
372+
if(backend == BACKEND_SLEEF) {
373+
memcpy(&in_val.sleef_value, in_ptr, sizeof(Sleef_quad));
374+
} else {
375+
memcpy(&in_val.longdouble_value, in_ptr, sizeof(long double));
376+
}
377+
378+
// Convert to Sleef_quad for Dragon4
379+
Sleef_quad sleef_val;
380+
if(backend == BACKEND_SLEEF) {
381+
sleef_val = in_val.sleef_value;
382+
} else {
383+
sleef_val = Sleef_cast_from_doubleq1(in_val.longdouble_value);
384+
}
385+
386+
// If positional format fits, use it; otherwise use scientific notation
387+
PyObject *py_str;
388+
PyObject *positional_str = Dragon4_Positional_QuadDType(&sleef_val,
389+
DigitMode_Unique,
390+
CutoffMode_TotalLength,
391+
SLEEF_QUAD_DECIMAL_DIG,
392+
0,
393+
1,
394+
TrimMode_LeaveOneZero,
395+
1,
396+
0);
397+
398+
if (positional_str == NULL) {
399+
return -1;
400+
}
401+
402+
const char *pos_str = PyUnicode_AsUTF8(positional_str);
403+
if (pos_str == NULL) {
404+
Py_DECREF(positional_str);
405+
return -1;
406+
}
407+
408+
npy_intp pos_len = strlen(pos_str);
409+
410+
411+
if (pos_len <= unicode_size_chars) {
412+
py_str = positional_str; // Keep the positional string
413+
} else {
414+
Py_DECREF(positional_str);
415+
// Use scientific notation with full precision
416+
py_str = Dragon4_Scientific_QuadDType(&sleef_val,
417+
DigitMode_Unique,
418+
SLEEF_QUAD_DECIMAL_DIG,
419+
0,
420+
1,
421+
TrimMode_LeaveOneZero,
422+
1,
423+
2);
424+
if (py_str == NULL) {
425+
return -1;
426+
}
427+
}
428+
429+
const char *temp_str = PyUnicode_AsUTF8(py_str);
430+
if (temp_str == NULL) {
431+
Py_DECREF(py_str);
432+
return -1;
433+
}
434+
435+
// Convert char string to UCS4 and store in output
436+
Py_UCS4* out_ucs4 = (Py_UCS4 *)out_ptr;
437+
npy_intp str_len = strlen(temp_str);
438+
439+
for(npy_intp i = 0; i < unicode_size_chars; i++)
440+
{
441+
if(i < str_len)
442+
{
443+
out_ucs4[i] = (Py_UCS4)temp_str[i];
444+
}
445+
else
446+
{
447+
out_ucs4[i] = 0;
448+
}
449+
}
450+
451+
Py_DECREF(py_str);
452+
453+
in_ptr += in_stride;
454+
out_ptr += out_stride;
455+
}
456+
457+
return 0;
458+
}
459+
460+
static int
461+
quad_to_unicode_loop_aligned(PyArrayMethod_Context *context, char *const data[],
462+
npy_intp const dimensions[], npy_intp const strides[], void *NPY_UNUSED(auxdata))
463+
{
464+
npy_intp N = dimensions[0];
465+
char *in_ptr = data[0];
466+
char *out_ptr = data[1];
467+
npy_intp in_stride = strides[0];
468+
npy_intp out_stride = strides[1];
469+
470+
PyArray_Descr *const *descrs = context->descriptors;
471+
QuadPrecDTypeObject *descr_in = (QuadPrecDTypeObject *)descrs[0];
472+
QuadBackendType backend = descr_in->backend;
473+
474+
npy_intp unicode_size_chars = descrs[1]->elsize / 4;
475+
476+
while (N--)
477+
{
478+
quad_value in_val;
479+
if(backend == BACKEND_SLEEF) {
480+
in_val.sleef_value = *(Sleef_quad *)in_ptr;
481+
} else {
482+
in_val.longdouble_value = *(long double *)in_ptr;
483+
}
484+
485+
// Convert to Sleef_quad for Dragon4
486+
Sleef_quad sleef_val;
487+
if(backend == BACKEND_SLEEF) {
488+
sleef_val = in_val.sleef_value;
489+
} else {
490+
sleef_val = Sleef_cast_from_doubleq1(in_val.longdouble_value);
491+
}
492+
493+
494+
PyObject *py_str;
495+
PyObject *positional_str = Dragon4_Positional_QuadDType(&sleef_val,
496+
DigitMode_Unique,
497+
CutoffMode_TotalLength,
498+
SLEEF_QUAD_DECIMAL_DIG,
499+
0,
500+
1,
501+
TrimMode_LeaveOneZero,
502+
1,
503+
0);
504+
505+
if (positional_str == NULL) {
506+
return -1;
507+
}
508+
509+
const char *pos_str = PyUnicode_AsUTF8(positional_str);
510+
if (pos_str == NULL) {
511+
Py_DECREF(positional_str);
512+
return -1;
513+
}
514+
515+
npy_intp pos_len = strlen(pos_str);
516+
517+
if (pos_len <= unicode_size_chars) {
518+
py_str = positional_str;
519+
} else {
520+
Py_DECREF(positional_str);
521+
// Use scientific notation with full precision
522+
py_str = Dragon4_Scientific_QuadDType(&sleef_val,
523+
DigitMode_Unique,
524+
SLEEF_QUAD_DECIMAL_DIG,
525+
0,
526+
1,
527+
TrimMode_LeaveOneZero,
528+
1,
529+
2);
530+
if (py_str == NULL) {
531+
return -1;
532+
}
533+
}
534+
535+
const char *temp_str = PyUnicode_AsUTF8(py_str);
536+
if (temp_str == NULL) {
537+
Py_DECREF(py_str);
538+
return -1;
539+
}
540+
541+
// Convert char string to UCS4 and store in output
542+
Py_UCS4* out_ucs4 = (Py_UCS4 *)out_ptr;
543+
npy_intp str_len = strlen(temp_str);
544+
545+
for(npy_intp i = 0; i < unicode_size_chars; i++)
546+
{
547+
if(i < str_len)
548+
{
549+
out_ucs4[i] = (Py_UCS4)temp_str[i];
550+
}
551+
else
552+
{
553+
out_ucs4[i] = 0;
554+
}
555+
}
556+
557+
Py_DECREF(py_str);
558+
559+
in_ptr += in_stride;
560+
out_ptr += out_stride;
561+
}
562+
563+
return 0;
564+
}
321565

322566
// Tag dispatching to ensure npy_bool/npy_ubyte and npy_half/npy_ushort do not alias in templates
323567
// see e.g. https://stackoverflow.com/q/32522279
@@ -1027,23 +1271,42 @@ init_casts_internal(void)
10271271
add_cast_from<long double>(&PyArray_LongDoubleDType);
10281272

10291273
// Unicode/String to QuadPrecision cast
1030-
PyArray_DTypeMeta **unicode_dtypes = new PyArray_DTypeMeta *[2]{&PyArray_UnicodeDType, &QuadPrecDType};
1031-
PyType_Slot *unicode_slots = new PyType_Slot[4]{
1274+
PyArray_DTypeMeta **unicode_to_quad_dtypes = new PyArray_DTypeMeta *[2]{&PyArray_UnicodeDType, &QuadPrecDType};
1275+
PyType_Slot *unicode_to_quad_slots = new PyType_Slot[4]{
10321276
{NPY_METH_resolve_descriptors, (void *)&unicode_to_quad_resolve_descriptors},
10331277
{NPY_METH_strided_loop, (void *)&unicode_to_quad_strided_loop_aligned},
10341278
{NPY_METH_unaligned_strided_loop, (void *)&unicode_to_quad_strided_loop_unaligned},
10351279
{0, nullptr}};
10361280

1037-
PyArrayMethod_Spec *unicode_spec = new PyArrayMethod_Spec{
1281+
PyArrayMethod_Spec *unicode_to_quad_spec = new PyArrayMethod_Spec{
10381282
.name = "cast_Unicode_to_QuadPrec",
10391283
.nin = 1,
10401284
.nout = 1,
10411285
.casting = NPY_UNSAFE_CASTING,
10421286
.flags = NPY_METH_SUPPORTS_UNALIGNED,
1043-
.dtypes = unicode_dtypes,
1044-
.slots = unicode_slots,
1287+
.dtypes = unicode_to_quad_dtypes,
1288+
.slots = unicode_to_quad_slots,
1289+
};
1290+
add_spec(unicode_to_quad_spec);
1291+
1292+
// QuadPrecision to Unicode
1293+
PyArray_DTypeMeta **quad_to_unicode_dtypes = new PyArray_DTypeMeta *[2]{&QuadPrecDType, &PyArray_UnicodeDType};
1294+
PyType_Slot *quad_to_unicode_slots = new PyType_Slot[4]{
1295+
{NPY_METH_resolve_descriptors, (void *)&quad_to_unicode_resolve_descriptors},
1296+
{NPY_METH_strided_loop, (void *)&quad_to_unicode_loop_aligned},
1297+
{NPY_METH_unaligned_strided_loop, (void *)&quad_to_unicode_loop_unaligned},
1298+
{0, nullptr}};
1299+
1300+
PyArrayMethod_Spec *quad_to_unicode_spec = new PyArrayMethod_Spec{
1301+
.name = "cast_QuadPrec_to_Unicode",
1302+
.nin = 1,
1303+
.nout = 1,
1304+
.casting = NPY_UNSAFE_CASTING,
1305+
.flags = NPY_METH_SUPPORTS_UNALIGNED,
1306+
.dtypes = quad_to_unicode_dtypes,
1307+
.slots = quad_to_unicode_slots,
10451308
};
1046-
add_spec(unicode_spec);
1309+
add_spec(quad_to_unicode_spec);
10471310

10481311
specs[spec_count] = nullptr;
10491312
return specs;

quaddtype/numpy_quaddtype/src/quad_common.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ typedef union {
1919
long double longdouble_value;
2020
} quad_value;
2121

22+
23+
// For IEEE 754 binary128 (quad precision), we need 36 decimal digits
24+
// to guarantee round-trip conversion (string -> parse -> equals original value)
25+
// Formula: ceil(1 + MANT_DIG * log10(2)) = ceil(1 + 113 * 0.30103) = 36
26+
// src: https://en.wikipedia.org/wiki/Quadruple-precision_floating-point_format
27+
#define SLEEF_QUAD_DECIMAL_DIG 36
28+
2229
#ifdef __cplusplus
2330
}
2431
#endif

quaddtype/numpy_quaddtype/src/scalar.c

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,6 @@
1818
#include "lock.h"
1919
#include "utilities.h"
2020

21-
// For IEEE 754 binary128 (quad precision), we need 36 decimal digits
22-
// to guarantee round-trip conversion (string -> parse -> equals original value)
23-
// Formula: ceil(1 + MANT_DIG * log10(2)) = ceil(1 + 113 * 0.30103) = 36
24-
// src: https://en.wikipedia.org/wiki/Quadruple-precision_floating-point_format
25-
#define SLEEF_QUAD_DECIMAL_DIG 36
26-
2721

2822
QuadPrecisionObject *
2923
QuadPrecision_raw_new(QuadBackendType backend)

quaddtype/tests/test_quaddtype.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def test_supported_astype(dtype):
553553
# np.array(QuadPrecision(1)).astype(dtype, casting="unsafe")
554554

555555
class TestArrayCastStringBytes:
556-
@pytest.mark.parametrize("strtype", [np.str_, str, np.bytes_])
556+
@pytest.mark.parametrize("strtype", [np.str_, str])
557557
@pytest.mark.parametrize("input_val", [
558558
"3.141592653589793238462643383279502884197",
559559
"2.71828182845904523536028747135266249775",
@@ -566,14 +566,42 @@ class TestArrayCastStringBytes:
566566
"nan",
567567
"-nan",
568568
])
569-
def test_cast_string_to_quad(self, input_val, strtype):
569+
def test_cast_string_to_quad_roundtrip(self, input_val, strtype):
570+
# Test 1: String to Quad conversion
570571
str_array = np.array(input_val, dtype=strtype)
571572
quad_array = str_array.astype(QuadPrecDType())
572573
expected = np.array(input_val, dtype=QuadPrecDType())
574+
575+
# Verify string to quad conversion
573576
if np.isnan(float(expected)):
574577
np.testing.assert_array_equal(np.isnan(quad_array), np.isnan(expected))
575578
else:
576579
np.testing.assert_array_equal(quad_array, expected)
580+
581+
# Test 2: Quad to String conversion
582+
quad_to_string_array = quad_array.astype(strtype)
583+
584+
# Test 3: Round-trip - String -> Quad -> String -> Quad should preserve value
585+
roundtrip_quad_array = quad_to_string_array.astype(QuadPrecDType())
586+
587+
if np.isnan(float(expected)):
588+
# For NaN, just verify both are NaN
589+
np.testing.assert_array_equal(np.isnan(roundtrip_quad_array), np.isnan(quad_array))
590+
else:
591+
# For non-NaN values, the round-trip should preserve the exact value
592+
np.testing.assert_array_equal(roundtrip_quad_array, quad_array,
593+
err_msg=f"Round-trip failed for {input_val}")
594+
595+
# Test 4: Verify the string representation can be parsed back
596+
# (This ensures the quad->string cast produces valid parseable strings)
597+
scalar_str = str(quad_array[()])
598+
scalar_from_str = QuadPrecision(scalar_str)
599+
600+
if np.isnan(float(quad_array[()])):
601+
assert np.isnan(float(scalar_from_str))
602+
else:
603+
assert scalar_from_str == quad_array[()], \
604+
f"Scalar round-trip failed: {scalar_str} -> {scalar_from_str} != {quad_array[()]}"
577605

578606

579607
def test_basic_equality():

0 commit comments

Comments
 (0)