Skip to content

Commit 1c26cbf

Browse files
committed
umath refactor complete
1 parent 2a81075 commit 1c26cbf

File tree

17 files changed

+865
-532
lines changed

17 files changed

+865
-532
lines changed

.github/workflows/build_wheels.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name: Build Wheels
33
on:
44
push:
55
branches:
6-
- main
6+
- matmul-ufunc
77
tags:
88
- "quaddtype-v*"
99
paths:

quaddtype/README.md

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,16 @@ cmake --build build/ --clean-first -j
3939
cd ..
4040
```
4141

42-
Building the `numpy-quaddtype` package from locally installed sleef:
42+
Building the `numpy-quaddtype` package with [QBLAS](https://github.com/SwayamInSync/QBLAS) from locally installed sleef:
43+
44+
> Currently, QBLAS is supported only with the GCC and Clang compilers and is incompatible with MSVC.
45+
4346
```bash
4447
export SLEEF_DIR=$PWD/sleef/build
4548
export LIBRARY_PATH=$SLEEF_DIR/lib
46-
export C_INCLUDE_PATH=$SLEEF_DIR/include
47-
export CPLUS_INCLUDE_PATH=$SLEEF_DIR/include
49+
export QBLAS_DIR=$PWD/numpy_quaddtype/QBLAS
50+
export C_INCLUDE_PATH=$SLEEF_DIR/include:$QBLAS_DIR/include
51+
export CPLUS_INCLUDE_PATH=$SLEEF_DIR/include:$QBLAS_DIR/include
4852

4953
# setup the virtual env
5054
python3 -m venv temp
@@ -57,10 +61,17 @@ export LDFLAGS="-Wl,-rpath,$SLEEF_DIR/lib -fopenmp -latomic -lpthread"
5761
export CFLAGS="-fPIC"
5862
export CXXFLAGS="-fPIC"
5963

64+
# Disable QBLAS for MSVC builds (falls back to naive implementations of the linear algebra operations)
65+
# export CFLAGS="-fPIC -DDISABLE_QUADBLAS"
66+
# export CXXFLAGS="-fPIC -DDISABLE_QUADBLAS"
67+
6068
python -m pip install . -v --no-build-isolation -Cbuilddir=build -C'compile-args=-v'
6169

6270
# Run the tests
6371
cd ..
6472
python -m pytest
6573
```
6674

75+
76+
77+

quaddtype/meson.build

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,19 @@ srcs = [
5050
'numpy_quaddtype/src/scalar_ops.h',
5151
'numpy_quaddtype/src/scalar_ops.cpp',
5252
'numpy_quaddtype/src/ops.hpp',
53-
'numpy_quaddtype/src/umath.h',
54-
'numpy_quaddtype/src/umath.cpp',
5553
'numpy_quaddtype/src/dragon4.h',
5654
'numpy_quaddtype/src/dragon4.c',
5755
'numpy_quaddtype/src/quadblas_interface.h',
58-
'numpy_quaddtype/src/quadblas_interface.cpp'
56+
'numpy_quaddtype/src/quadblas_interface.cpp',
57+
'numpy_quaddtype/src/umath/umath.h',
58+
'numpy_quaddtype/src/umath/umath.cpp',
59+
'numpy_quaddtype/src/umath/binary_ops.h',
60+
'numpy_quaddtype/src/umath/binary_ops.cpp',
61+
'numpy_quaddtype/src/umath/unary_ops.h',
62+
'numpy_quaddtype/src/umath/unary_ops.cpp',
63+
'numpy_quaddtype/src/umath/comparison_ops.h',
64+
'numpy_quaddtype/src/umath/comparison_ops.cpp',
65+
'numpy_quaddtype/src/umath/promoters.hpp',
5966
]
6067

6168
py.install_sources(

quaddtype/numpy_quaddtype/src/quadblas_interface.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,20 @@
55
#define NO_IMPORT_ARRAY
66
#define NO_IMPORT_UFUNC
77

8-
extern "C" {
8+
99
#include <Python.h>
1010
#include "numpy/arrayobject.h"
1111
#include "numpy/ndarraytypes.h"
1212
#include "numpy/dtype_api.h"
13-
}
1413

1514
#include "scalar.h"
1615
#include "dtype.h"
1716
#include "quad_common.h"
1817
#include "quadblas_interface.h"
1918

20-
extern "C" {
2119
#include <sleef.h>
2220
#include <sleefquad.h>
23-
}
21+
2422

2523
#ifndef DISABLE_QUADBLAS
2624
#include "../QBLAS/include/quadblas/quadblas.hpp"
Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
#ifndef _QUADDTYPE_QUADBLAS_INTERFACE_H
22
#define _QUADDTYPE_QUADBLAS_INTERFACE_H
33

4-
#ifdef __cplusplus
5-
extern "C" {
6-
#endif
7-
84
#include <Python.h>
95

106

@@ -16,8 +12,6 @@ PyObject* py_quadblas_get_num_threads(PyObject* self, PyObject* args);
1612

1713
PyObject* py_quadblas_get_version(PyObject* self, PyObject* args);
1814

19-
#ifdef __cplusplus
20-
}
21-
#endif
15+
2216

2317
#endif

quaddtype/numpy_quaddtype/src/quaddtype_main.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
#include "scalar.h"
1616
#include "dtype.h"
17-
#include "umath.h"
17+
#include "umath/umath.h"
1818
#include "quad_common.h"
1919
#include "quadblas_interface.h"
2020
#include "float.h"
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
#define PY_ARRAY_UNIQUE_SYMBOL QuadPrecType_ARRAY_API
2+
#define PY_UFUNC_UNIQUE_SYMBOL QuadPrecType_UFUNC_API
3+
#define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION
4+
#define NPY_TARGET_VERSION NPY_2_0_API_VERSION
5+
#define NO_IMPORT_ARRAY
6+
#define NO_IMPORT_UFUNC
7+
8+
9+
#include <Python.h>
10+
#include <cstdio>
11+
12+
#include "numpy/arrayobject.h"
13+
#include "numpy/ufuncobject.h"
14+
#include "numpy/dtype_api.h"
15+
#include "numpy/ndarraytypes.h"
16+
17+
#include "../quad_common.h"
18+
#include "../scalar.h"
19+
#include "../dtype.h"
20+
#include "../ops.hpp"
21+
#include "promoters.hpp"
22+
#include "binary_ops.h"
23+
24+
/*
 * Descriptor resolution shared by the QuadPrecision binary ufunc loops.
 *
 * Chooses the backend the loop will run with:
 *   - both inputs use the same backend: that backend is kept and
 *     NPY_NO_CASTING is reported;
 *   - the inputs disagree: every operand is moved to BACKEND_LONGDOUBLE and
 *     NPY_SAFE_CASTING is reported.
 *
 * given_descrs[0] and given_descrs[1] are the input descriptors and are
 * dereferenced unconditionally.  given_descrs[2] is the requested output
 * descriptor and may be NULL; it is reused only when its backend already
 * equals the chosen one, otherwise a fresh descriptor is created.
 *
 * On success loop_descrs[0..2] each hold a new reference and the casting
 * level is returned.  On failure (NPY_CASTING)-1 is returned and every
 * reference that was already stored in loop_descrs is released and the slot
 * reset to NULL, so no descriptor leaks on the error path.
 */
static NPY_CASTING
quad_binary_op_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtypes[],
                                   PyArray_Descr *const given_descrs[],
                                   PyArray_Descr *loop_descrs[], npy_intp *NPY_UNUSED(view_offset))
{
    QuadPrecDTypeObject *descr_in1 = (QuadPrecDTypeObject *)given_descrs[0];
    QuadPrecDTypeObject *descr_in2 = (QuadPrecDTypeObject *)given_descrs[1];
    QuadBackendType target_backend;

    // Determine target backend and if casting is needed
    NPY_CASTING casting = NPY_NO_CASTING;
    if (descr_in1->backend != descr_in2->backend) {
        target_backend = BACKEND_LONGDOUBLE;
        casting = NPY_SAFE_CASTING;
    }
    else {
        target_backend = descr_in1->backend;
    }

    // Operands 0 and 1 are the inputs, operand 2 is the output.  All three
    // follow the same rule: reuse the given descriptor when it already has
    // the target backend, otherwise create a new one.
    for (int i = 0; i < 3; i++) {
        PyArray_Descr *given = given_descrs[i];

        if (given != NULL && ((QuadPrecDTypeObject *)given)->backend == target_backend) {
            Py_INCREF(given);
            loop_descrs[i] = given;
            continue;
        }

        loop_descrs[i] = (PyArray_Descr *)new_quaddtype_instance(target_backend);
        if (loop_descrs[i] == NULL) {
            // Undo the references stored for the earlier operands; the
            // caller does not own anything after a failed resolution.
            for (int j = 0; j < i; j++) {
                Py_DECREF(loop_descrs[j]);
                loop_descrs[j] = NULL;
            }
            return (NPY_CASTING)-1;
        }
    }
    return casting;
}
79+
80+
template <binary_op_quad_def sleef_op, binary_op_longdouble_def longdouble_op>
81+
int
82+
quad_generic_binop_strided_loop_unaligned(PyArrayMethod_Context *context, char *const data[],
83+
npy_intp const dimensions[], npy_intp const strides[],
84+
NpyAuxData *auxdata)
85+
{
86+
npy_intp N = dimensions[0];
87+
char *in1_ptr = data[0], *in2_ptr = data[1];
88+
char *out_ptr = data[2];
89+
npy_intp in1_stride = strides[0];
90+
npy_intp in2_stride = strides[1];
91+
npy_intp out_stride = strides[2];
92+
93+
QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
94+
QuadBackendType backend = descr->backend;
95+
size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);
96+
97+
quad_value in1, in2, out;
98+
while (N--) {
99+
memcpy(&in1, in1_ptr, elem_size);
100+
memcpy(&in2, in2_ptr, elem_size);
101+
if (backend == BACKEND_SLEEF) {
102+
out.sleef_value = sleef_op(&in1.sleef_value, &in2.sleef_value);
103+
}
104+
else {
105+
out.longdouble_value = longdouble_op(&in1.longdouble_value, &in2.longdouble_value);
106+
}
107+
memcpy(out_ptr, &out, elem_size);
108+
109+
in1_ptr += in1_stride;
110+
in2_ptr += in2_stride;
111+
out_ptr += out_stride;
112+
}
113+
return 0;
114+
}
115+
116+
template <binary_op_quad_def sleef_op, binary_op_longdouble_def longdouble_op>
117+
int
118+
quad_generic_binop_strided_loop_aligned(PyArrayMethod_Context *context, char *const data[],
119+
npy_intp const dimensions[], npy_intp const strides[],
120+
NpyAuxData *auxdata)
121+
{
122+
npy_intp N = dimensions[0];
123+
char *in1_ptr = data[0], *in2_ptr = data[1];
124+
char *out_ptr = data[2];
125+
npy_intp in1_stride = strides[0];
126+
npy_intp in2_stride = strides[1];
127+
npy_intp out_stride = strides[2];
128+
129+
QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
130+
QuadBackendType backend = descr->backend;
131+
132+
while (N--) {
133+
if (backend == BACKEND_SLEEF) {
134+
*(Sleef_quad *)out_ptr = sleef_op((Sleef_quad *)in1_ptr, (Sleef_quad *)in2_ptr);
135+
}
136+
else {
137+
*(long double *)out_ptr = longdouble_op((long double *)in1_ptr, (long double *)in2_ptr);
138+
}
139+
140+
in1_ptr += in1_stride;
141+
in2_ptr += in2_stride;
142+
out_ptr += out_stride;
143+
}
144+
return 0;
145+
}
146+
147+
148+
template <binary_op_quad_def sleef_op, binary_op_longdouble_def longdouble_op>
149+
int
150+
create_quad_binary_ufunc(PyObject *numpy, const char *ufunc_name)
151+
{
152+
PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
153+
if (ufunc == NULL) {
154+
return -1;
155+
}
156+
157+
PyArray_DTypeMeta *dtypes[3] = {&QuadPrecDType, &QuadPrecDType, &QuadPrecDType};
158+
159+
PyType_Slot slots[] = {
160+
{NPY_METH_resolve_descriptors, (void *)&quad_binary_op_resolve_descriptors},
161+
{NPY_METH_strided_loop,
162+
(void *)&quad_generic_binop_strided_loop_aligned<sleef_op, longdouble_op>},
163+
{NPY_METH_unaligned_strided_loop,
164+
(void *)&quad_generic_binop_strided_loop_unaligned<sleef_op, longdouble_op>},
165+
{0, NULL}};
166+
167+
PyArrayMethod_Spec Spec = {
168+
.name = "quad_binop",
169+
.nin = 2,
170+
.nout = 1,
171+
.casting = NPY_NO_CASTING,
172+
.flags = (NPY_ARRAYMETHOD_FLAGS)(NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_IS_REORDERABLE),
173+
.dtypes = dtypes,
174+
.slots = slots,
175+
};
176+
177+
if (PyUFunc_AddLoopFromSpec(ufunc, &Spec) < 0) {
178+
return -1;
179+
}
180+
181+
PyObject *promoter_capsule =
182+
PyCapsule_New((void *)&quad_ufunc_promoter, "numpy._ufunc_promoter", NULL);
183+
if (promoter_capsule == NULL) {
184+
return -1;
185+
}
186+
187+
PyObject *DTypes = PyTuple_Pack(3, &PyArrayDescr_Type, &PyArrayDescr_Type, &PyArrayDescr_Type);
188+
if (DTypes == 0) {
189+
Py_DECREF(promoter_capsule);
190+
return -1;
191+
}
192+
193+
if (PyUFunc_AddPromoter(ufunc, DTypes, promoter_capsule) < 0) {
194+
Py_DECREF(promoter_capsule);
195+
Py_DECREF(DTypes);
196+
return -1;
197+
}
198+
Py_DECREF(promoter_capsule);
199+
Py_DECREF(DTypes);
200+
return 0;
201+
}
202+
203+
204+
int
205+
init_quad_binary_ops(PyObject *numpy)
206+
{
207+
if (create_quad_binary_ufunc<quad_add, ld_add>(numpy, "add") < 0) {
208+
return -1;
209+
}
210+
if (create_quad_binary_ufunc<quad_sub, ld_sub>(numpy, "subtract") < 0) {
211+
return -1;
212+
}
213+
if (create_quad_binary_ufunc<quad_mul, ld_mul>(numpy, "multiply") < 0) {
214+
return -1;
215+
}
216+
if (create_quad_binary_ufunc<quad_div, ld_div>(numpy, "divide") < 0) {
217+
return -1;
218+
}
219+
if (create_quad_binary_ufunc<quad_pow, ld_pow>(numpy, "power") < 0) {
220+
return -1;
221+
}
222+
if (create_quad_binary_ufunc<quad_mod, ld_mod>(numpy, "mod") < 0) {
223+
return -1;
224+
}
225+
if (create_quad_binary_ufunc<quad_minimum, ld_minimum>(numpy, "minimum") < 0) {
226+
return -1;
227+
}
228+
if (create_quad_binary_ufunc<quad_maximum, ld_maximum>(numpy, "maximum") < 0) {
229+
return -1;
230+
}
231+
if (create_quad_binary_ufunc<quad_atan2, ld_atan2>(numpy, "arctan2") < 0) {
232+
return -1;
233+
}
234+
return 0;
235+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#ifndef _QUADDTYPE_BINARY_OPS_H
2+
#define _QUADDTYPE_BINARY_OPS_H
3+
4+
#include <Python.h>
5+
6+
int
7+
init_quad_binary_ops(PyObject *numpy);
8+
9+
#endif

0 commit comments

Comments
 (0)