Skip to content

Commit 6668883

Browse files
committed
implementing fexpr ufunc
1 parent 0397aa2 commit 6668883

File tree

6 files changed

+208
-2
lines changed

6 files changed

+208
-2
lines changed

quaddtype/meson.build

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ incdir_numpy = run_command(py,
2424
check : true
2525
).stdout().strip()
2626

27+
# print the version of numpy being used
28+
numpy_version = run_command(py,
29+
['-c', 'import numpy; print(numpy.__version__)'],
30+
check : true
31+
).stdout().strip()
32+
message('Using NumPy version: ' + numpy_version)
33+
2734
npymath_path = incdir_numpy / '..' / 'lib'
2835
npymath_lib = c.find_library('npymath', dirs: npymath_path)
2936

@@ -117,6 +124,8 @@ srcs = [
117124
'numpy_quaddtype/src/umath/promoters.hpp',
118125
'numpy_quaddtype/src/umath/matmul.h',
119126
'numpy_quaddtype/src/umath/matmul.cpp',
127+
'numpy_quaddtype/src/umath/frexp_op.h',
128+
'numpy_quaddtype/src/umath/frexp_op.cpp',
120129
]
121130

122131
py.install_sources(

quaddtype/numpy_quaddtype/src/dtype.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,6 @@ static PyType_Slot QuadPrecDType_Slots[] = {
184184
{NPY_DT_setitem, &quadprec_setitem},
185185
{NPY_DT_getitem, &quadprec_getitem},
186186
{NPY_DT_default_descr, &quadprec_default_descr},
187-
{NPY_DT_PyArray_ArrFuncs_dotfunc, NULL},
188187
{0, NULL}};
189188

190189
static PyObject *

quaddtype/numpy_quaddtype/src/scalar.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ QuadPrecision_dealloc(QuadPrecisionObject *self)
239239
}
240240

241241
PyTypeObject QuadPrecision_Type = {
242-
PyVarObject_HEAD_INIT(NULL, 0).tp_name = "numpy_quaddtype.QuadPrecision",
242+
PyVarObject_HEAD_INIT(NULL, 0).tp_name = "numpy_quaddtype.QuadPrecDType",
243243
.tp_basicsize = sizeof(QuadPrecisionObject),
244244
.tp_itemsize = 0,
245245
.tp_new = QuadPrecision_new,
@@ -253,5 +253,6 @@ PyTypeObject QuadPrecision_Type = {
253253
int
254254
init_quadprecision_scalar(void)
255255
{
256+
QuadPrecision_Type.tp_base = &PyFloatingArrType_Type;
256257
return PyType_Ready(&QuadPrecision_Type);
257258
}
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
#define PY_ARRAY_UNIQUE_SYMBOL QuadPrecType_ARRAY_API
2+
#define PY_UFUNC_UNIQUE_SYMBOL QuadPrecType_UFUNC_API
3+
#define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION
4+
#define NPY_TARGET_VERSION NPY_2_0_API_VERSION
5+
#define NO_IMPORT_ARRAY
6+
#define NO_IMPORT_UFUNC
7+
8+
extern "C" {
9+
#include <Python.h>
10+
#include <cstdio>
11+
12+
#include "numpy/arrayobject.h"
13+
#include "numpy/ndarraytypes.h"
14+
#include "numpy/ufuncobject.h"
15+
#include "numpy/dtype_api.h"
16+
}
17+
#include "../quad_common.h"
18+
#include "../scalar.h"
19+
#include "../dtype.h"
20+
#include "../ops.hpp"
21+
22+
// Forward declarations for frexp operations
23+
static Sleef_quad quad_frexp_mantissa(const Sleef_quad *op, int *exp);
24+
static long double ld_frexp_mantissa(const long double *op, int *exp);
25+
26+
static Sleef_quad
27+
quad_frexp_mantissa(const Sleef_quad *op, int *exp)
28+
{
29+
return Sleef_frexpq1(*op, exp);
30+
}
31+
32+
static long double
33+
ld_frexp_mantissa(const long double *op, int *exp)
34+
{
35+
return frexpl(*op, exp);
36+
}
37+
38+
static NPY_CASTING
39+
quad_frexp_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtypes[],
40+
PyArray_Descr *const given_descrs[], PyArray_Descr *loop_descrs[],
41+
npy_intp *NPY_UNUSED(view_offset))
42+
{
43+
Py_INCREF(given_descrs[0]);
44+
loop_descrs[0] = given_descrs[0];
45+
46+
// Output 1: QuadPrecDType (mantissa)
47+
if (given_descrs[1] == NULL) {
48+
Py_INCREF(given_descrs[0]);
49+
loop_descrs[1] = given_descrs[0];
50+
}
51+
else {
52+
Py_INCREF(given_descrs[1]);
53+
loop_descrs[1] = given_descrs[1];
54+
}
55+
56+
// Output 2: Int32 (exponent)
57+
if (given_descrs[2] == NULL) {
58+
loop_descrs[2] = PyArray_DescrFromType(NPY_INT32);
59+
}
60+
else {
61+
Py_INCREF(given_descrs[2]);
62+
loop_descrs[2] = given_descrs[2];
63+
}
64+
65+
return NPY_NO_CASTING;
66+
}
67+
68+
static int
69+
quad_frexp_strided_loop_unaligned(PyArrayMethod_Context *context, char *const data[],
70+
npy_intp const dimensions[], npy_intp const strides[],
71+
NpyAuxData *auxdata)
72+
{
73+
npy_intp N = dimensions[0];
74+
char *in_ptr = data[0];
75+
char *mantissa_ptr = data[1];
76+
char *exp_ptr = data[2];
77+
npy_intp in_stride = strides[0];
78+
npy_intp mantissa_stride = strides[1];
79+
npy_intp exp_stride = strides[2];
80+
81+
QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
82+
QuadBackendType backend = descr->backend;
83+
size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);
84+
85+
quad_value in, mantissa;
86+
int exp;
87+
while (N--) {
88+
memcpy(&in, in_ptr, elem_size);
89+
90+
if (backend == BACKEND_SLEEF) {
91+
mantissa.sleef_value = quad_frexp_mantissa(&in.sleef_value, &exp);
92+
}
93+
else {
94+
mantissa.longdouble_value = ld_frexp_mantissa(&in.longdouble_value, &exp);
95+
}
96+
97+
memcpy(mantissa_ptr, &mantissa, elem_size);
98+
*(npy_int32 *)exp_ptr = (npy_int32)exp;
99+
100+
in_ptr += in_stride;
101+
mantissa_ptr += mantissa_stride;
102+
exp_ptr += exp_stride;
103+
}
104+
return 0;
105+
}
106+
107+
static int
108+
quad_frexp_strided_loop_aligned(PyArrayMethod_Context *context, char *const data[],
109+
npy_intp const dimensions[], npy_intp const strides[],
110+
NpyAuxData *auxdata)
111+
{
112+
npy_intp N = dimensions[0];
113+
char *in_ptr = data[0];
114+
char *mantissa_ptr = data[1];
115+
char *exp_ptr = data[2];
116+
npy_intp in_stride = strides[0];
117+
npy_intp mantissa_stride = strides[1];
118+
npy_intp exp_stride = strides[2];
119+
120+
QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
121+
QuadBackendType backend = descr->backend;
122+
123+
int exp;
124+
while (N--) {
125+
if (backend == BACKEND_SLEEF) {
126+
*(Sleef_quad *)mantissa_ptr = quad_frexp_mantissa((Sleef_quad *)in_ptr, &exp);
127+
}
128+
else {
129+
*(long double *)mantissa_ptr = ld_frexp_mantissa((long double *)in_ptr, &exp);
130+
}
131+
132+
*(npy_int32 *)exp_ptr = (npy_int32)exp;
133+
134+
in_ptr += in_stride;
135+
mantissa_ptr += mantissa_stride;
136+
exp_ptr += exp_stride;
137+
}
138+
return 0;
139+
}
140+
141+
int
142+
create_quad_frexp_ufunc(PyObject *numpy)
143+
{
144+
PyObject *ufunc = PyObject_GetAttrString(numpy, "frexp");
145+
if (ufunc == NULL) {
146+
return -1;
147+
}
148+
149+
PyArray_DTypeMeta *dtypes[3] = {&QuadPrecDType, &QuadPrecDType, &PyArray_Int32DType};
150+
151+
PyType_Slot slots[] = {
152+
{NPY_METH_resolve_descriptors, (void *)&quad_frexp_resolve_descriptors},
153+
{NPY_METH_strided_loop, (void *)&quad_frexp_strided_loop_aligned},
154+
{NPY_METH_unaligned_strided_loop, (void *)&quad_frexp_strided_loop_unaligned},
155+
{0, NULL}};
156+
157+
PyArrayMethod_Spec Spec = {
158+
.name = "quad_frexp",
159+
.nin = 1,
160+
.nout = 2,
161+
.casting = NPY_NO_CASTING,
162+
.flags = NPY_METH_SUPPORTS_UNALIGNED,
163+
.dtypes = dtypes,
164+
.slots = slots,
165+
};
166+
167+
if (PyUFunc_AddLoopFromSpec(ufunc, &Spec) < 0) {
168+
return -1;
169+
}
170+
171+
Py_DECREF(ufunc);
172+
return 0;
173+
}
174+
175+
int
176+
init_quad_frexp(PyObject *numpy)
177+
{
178+
if (create_quad_frexp_ufunc(numpy) < 0) {
179+
return -1;
180+
}
181+
return 0;
182+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#ifndef _QUADDTYPE_FREXP_OP_H
2+
#define _QUADDTYPE_FREXP_OP_H
3+
4+
#include <Python.h>
5+
6+
int
7+
init_quad_frexp(PyObject *numpy);
8+
9+
#endif

quaddtype/numpy_quaddtype/src/umath/umath.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ extern "C" {
2424
#include "binary_ops.h"
2525
#include "comparison_ops.h"
2626
#include "matmul.h"
27+
#include "frexp_op.h"
2728

2829
// helper debugging function
2930
static const char *
@@ -113,6 +114,11 @@ init_quad_umath(void)
113114
goto err;
114115
}
115116

117+
if (init_quad_frexp(numpy) < 0) {
118+
PyErr_SetString(PyExc_RuntimeError, "Failed to initialize quad frexp operation");
119+
goto err;
120+
}
121+
116122
Py_DECREF(numpy);
117123
return 0;
118124

0 commit comments

Comments
 (0)