Skip to content

Commit e201b90

Browse files
committed
initial matmul ufunc setup
1 parent e467f4b commit e201b90

File tree

5 files changed

+254
-4
lines changed

5 files changed

+254
-4
lines changed

quaddtype/meson.build

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ srcs = [
6363
'numpy_quaddtype/src/umath/comparison_ops.h',
6464
'numpy_quaddtype/src/umath/comparison_ops.cpp',
6565
'numpy_quaddtype/src/umath/promoters.hpp',
66+
'numpy_quaddtype/src/umath/matmul.h',
67+
'numpy_quaddtype/src/umath/matmul.cpp',
6668
]
6769

6870
py.install_sources(

quaddtype/numpy_quaddtype/src/quadblas_interface.cpp

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,4 +786,95 @@ py_quadblas_get_version(PyObject *self, PyObject *args)
786786
return PyUnicode_FromString(QuadBLAS::VERSION);
787787
}
788788

789+
void matmul_op(Sleef_quad * inp1, Sleef_quad *inp2, Sleef_quad *out)
790+
{
791+
Sleef_quad *data_a, *data_b;
792+
QuadBackendType backend_a, backend_b;
793+
QuadBLAS::Layout layout_a, layout_b;
794+
795+
if (!extract_quad_array_info(a, &data_a, &backend_a, &layout_a) ||
796+
!extract_quad_array_info(b, &data_b, &backend_b, &layout_b)) {
797+
return nullptr;
798+
}
799+
800+
Sleef_quad *temp_a = nullptr, *temp_b = nullptr;
801+
Sleef_quad *sleef_a = ensure_sleef_backend(a, backend_a, &temp_a);
802+
Sleef_quad *sleef_b = ensure_sleef_backend(b, backend_b, &temp_b);
803+
804+
if (!sleef_a || !sleef_b) {
805+
QuadBLAS::aligned_free(temp_a);
806+
QuadBLAS::aligned_free(temp_b);
807+
return nullptr;
808+
}
809+
810+
QuadBackendType result_backend = BACKEND_SLEEF;
811+
if (backend_a == BACKEND_LONGDOUBLE && backend_b == BACKEND_LONGDOUBLE) {
812+
result_backend = BACKEND_LONGDOUBLE;
813+
}
814+
815+
npy_intp result_dims[2] = {m, n};
816+
QuadPrecDTypeObject *result_dtype = new_quaddtype_instance(result_backend);
817+
if (!result_dtype) {
818+
QuadBLAS::aligned_free(temp_a);
819+
QuadBLAS::aligned_free(temp_b);
820+
return nullptr;
821+
}
822+
823+
PyArrayObject *result =
824+
(PyArrayObject *)PyArray_Empty(2, result_dims, (PyArray_Descr *)result_dtype, 0);
825+
if (!result) {
826+
QuadBLAS::aligned_free(temp_a);
827+
QuadBLAS::aligned_free(temp_b);
828+
Py_DECREF(result_dtype);
829+
return nullptr;
830+
}
831+
832+
Sleef_quad *result_data = (Sleef_quad *)PyArray_DATA(result);
833+
for (npy_intp i = 0; i < m * n; i++) {
834+
result_data[i] = Sleef_cast_from_doubleq1(0.0);
835+
}
836+
837+
npy_intp lda, ldb, ldc;
838+
839+
if (layout_a == QuadBLAS::Layout::RowMajor) {
840+
lda = k;
841+
}
842+
else {
843+
lda = m;
844+
}
845+
846+
if (layout_b == QuadBLAS::Layout::RowMajor) {
847+
ldb = n;
848+
}
849+
else {
850+
ldb = k;
851+
}
852+
853+
QuadBLAS::Layout result_layout = layout_a;
854+
if (result_layout == QuadBLAS::Layout::RowMajor) {
855+
ldc = n;
856+
}
857+
else {
858+
ldc = m;
859+
}
860+
861+
Sleef_quad alpha = Sleef_cast_from_doubleq1(1.0);
862+
Sleef_quad beta = Sleef_cast_from_doubleq1(0.0);
863+
864+
QuadBLAS::gemm(result_layout, m, n, k, alpha, sleef_a, lda, sleef_b, ldb, beta, result_data,
865+
ldc);
866+
867+
if (result_backend == BACKEND_LONGDOUBLE) {
868+
long double *ld_result = (long double *)PyArray_DATA(result);
869+
for (npy_intp i = 0; i < m * n; i++) {
870+
ld_result[i] = (long double)Sleef_cast_to_doubleq1(result_data[i]);
871+
}
872+
}
873+
874+
QuadBLAS::aligned_free(temp_a);
875+
QuadBLAS::aligned_free(temp_b);
876+
877+
return (PyObject *)result;
878+
}
879+
789880
#endif // DISABLE_QUADBLAS
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
#define PY_ARRAY_UNIQUE_SYMBOL QuadPrecType_ARRAY_API
2+
#define PY_UFUNC_UNIQUE_SYMBOL QuadPrecType_UFUNC_API
3+
#define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION
4+
#define NPY_TARGET_VERSION NPY_2_0_API_VERSION
5+
#define NO_IMPORT_ARRAY
6+
#define NO_IMPORT_UFUNC
7+
8+
9+
#include <Python.h>
10+
#include <cstdio>
11+
12+
#include "numpy/arrayobject.h"
13+
#include "numpy/ufuncobject.h"
14+
#include "numpy/dtype_api.h"
15+
#include "numpy/ndarraytypes.h"
16+
17+
#include "../quad_common.h"
18+
#include "../scalar.h"
19+
#include "../dtype.h"
20+
#include "../ops.hpp"
21+
#include "binary_ops.h"
22+
#include "matmul.h"
23+
24+
#include <iostream>
25+
26+
/**
 * Choose the descriptors the matmul loop will run with.
 *
 * The strided loops read the backend from a single descriptor and apply it to
 * all three operands, so every entry of `loop_descrs` must share one backend.
 * The common backend is long double only when both inputs are long double;
 * otherwise it is SLEEF.
 *
 * @param given_descrs  descriptors of the two inputs and the output; the
 *                      output entry may be NULL when no `out=` was supplied
 * @param loop_descrs   filled with one new reference per operand on success
 * @return NPY_NO_CASTING when every supplied descriptor is used unchanged,
 *         NPY_SAFE_CASTING when at least one supplied operand needs a
 *         descriptor with a different backend, or -1 with a Python error set
 *         if a descriptor could not be created.
 */
static NPY_CASTING
quad_matmul_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtypes[],
                                PyArray_Descr *const given_descrs[],
                                PyArray_Descr *loop_descrs[], npy_intp *NPY_UNUSED(view_offset))
{
    const QuadBackendType backend_in1 = ((QuadPrecDTypeObject *)given_descrs[0])->backend;
    const QuadBackendType backend_in2 = ((QuadPrecDTypeObject *)given_descrs[1])->backend;

    const QuadBackendType target_backend =
            (backend_in1 == BACKEND_LONGDOUBLE && backend_in2 == BACKEND_LONGDOUBLE)
                    ? BACKEND_LONGDOUBLE
                    : BACKEND_SLEEF;

    NPY_CASTING casting = NPY_NO_CASTING;

    for (int i = 0; i < 3; i++) {
        PyArray_Descr *given = given_descrs[i];

        if (given != NULL && ((QuadPrecDTypeObject *)given)->backend == target_backend) {
            // Already suitable: hand back a new reference to the same descriptor.
            Py_INCREF(given);
            loop_descrs[i] = given;
            continue;
        }

        loop_descrs[i] = (PyArray_Descr *)new_quaddtype_instance(target_backend);
        if (loop_descrs[i] == NULL) {
            // Release what was already handed out so nothing leaks on failure.
            for (int j = 0; j < i; j++) {
                Py_CLEAR(loop_descrs[j]);
            }
            return (NPY_CASTING)-1;
        }

        if (given != NULL) {
            // A user-supplied operand needs a backend conversion.
            casting = NPY_SAFE_CASTING;
        }
    }

    return casting;
}
36+
37+
template <binary_op_quad_def sleef_op, binary_op_longdouble_def longdouble_op>
38+
int
39+
quad_generic_matmul_strided_loop_unaligned(PyArrayMethod_Context *context, char *const data[],
40+
npy_intp const dimensions[], npy_intp const strides[],
41+
NpyAuxData *auxdata)
42+
{
43+
npy_intp N = dimensions[0];
44+
char *in1_ptr = data[0], *in2_ptr = data[1];
45+
char *out_ptr = data[2];
46+
npy_intp in1_stride = strides[0];
47+
npy_intp in2_stride = strides[1];
48+
npy_intp out_stride = strides[2];
49+
50+
QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
51+
QuadBackendType backend = descr->backend;
52+
size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);
53+
54+
quad_value in1, in2, out;
55+
while (N--) {
56+
memcpy(&in1, in1_ptr, elem_size);
57+
memcpy(&in2, in2_ptr, elem_size);
58+
if (backend == BACKEND_SLEEF) {
59+
out.sleef_value = sleef_op(&in1.sleef_value, &in2.sleef_value);
60+
}
61+
else {
62+
out.longdouble_value = longdouble_op(&in1.longdouble_value, &in2.longdouble_value);
63+
}
64+
memcpy(out_ptr, &out, elem_size);
65+
66+
in1_ptr += in1_stride;
67+
in2_ptr += in2_stride;
68+
out_ptr += out_stride;
69+
}
70+
return 0;
71+
}
72+
73+
template <binary_op_quad_def sleef_op, binary_op_longdouble_def longdouble_op>
74+
int
75+
quad_generic_matmul_strided_loop_aligned(PyArrayMethod_Context *context, char *const data[],
76+
npy_intp const dimensions[], npy_intp const strides[],
77+
NpyAuxData *auxdata)
78+
{
79+
npy_intp N = dimensions[0];
80+
char *in1_ptr = data[0], *in2_ptr = data[1];
81+
char *out_ptr = data[2];
82+
npy_intp in1_stride = strides[0];
83+
npy_intp in2_stride = strides[1];
84+
npy_intp out_stride = strides[2];
85+
86+
QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
87+
QuadBackendType backend = descr->backend;
88+
89+
while (N--) {
90+
if (backend == BACKEND_SLEEF) {
91+
*(Sleef_quad *)out_ptr = sleef_op((Sleef_quad *)in1_ptr, (Sleef_quad *)in2_ptr);
92+
}
93+
else {
94+
*(long double *)out_ptr = longdouble_op((long double *)in1_ptr, (long double *)in2_ptr);
95+
}
96+
97+
in1_ptr += in1_stride;
98+
in2_ptr += in2_stride;
99+
out_ptr += out_stride;
100+
}
101+
return 0;
102+
}
103+
104+
int
105+
create_matmul_ufunc(PyObject *numpy, const char *ufunc_name)
106+
{
107+
PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
108+
if (ufunc == NULL) {
109+
return -1;
110+
}
111+
112+
PyArray_DTypeMeta *dtypes[3] = {&QuadPrecDType, &QuadPrecDType, &QuadPrecDType};
113+
114+
PyType_Slot slots[] = {
115+
{NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
116+
{NPY_METH_strided_loop,
117+
(void *)&quad_generic_matmul_strided_loop_aligned<sleef_op, longdouble_op>},
118+
{NPY_METH_unaligned_strided_loop,
119+
(void *)&quad_generic_matmul_strided_loop_unaligned<sleef_op, longdouble_op>},
120+
{0, NULL}};
121+
122+
PyArrayMethod_Spec Spec = {
123+
.name = "quad_matmul",
124+
.nin = 2,
125+
.nout = 1,
126+
.casting = NPY_NO_CASTING,
127+
.flags = (NPY_ARRAYMETHOD_FLAGS)(NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_IS_REORDERABLE),
128+
.dtypes = dtypes,
129+
.slots = slots,
130+
};
131+
132+
if (PyUFunc_AddLoopFromSpec(ufunc, &Spec) < 0) {
133+
return -1;
134+
}
135+
// my guess we don't need any promoter here as of now, since matmul is quad specific
136+
return 0;
137+
}
138+
139+
140+
int
141+
init_matmul_ops(PyObject *numpy)
142+
{
143+
if (create_matmul_ufunc<quad_add>(numpy, "matmul") < 0) {
144+
return -1;
145+
}
146+
return 0;
147+
}
148+
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// Include guard: identifiers starting with an underscore followed by an
// uppercase letter are reserved for the implementation, so none is used here.
#ifndef QUADDTYPE_MATMUL_OPS_H
#define QUADDTYPE_MATMUL_OPS_H

#include <Python.h>

/**
 * Register the QuadPrec loop for `numpy.matmul`.
 *
 * @param numpy the imported `numpy` module
 * @return 0 on success, a negative value on failure
 */
int
init_matmul_ops(PyObject *numpy);

#endif  // QUADDTYPE_MATMUL_OPS_H

quaddtype/numpy_quaddtype/src/umath/umath.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ extern "C" {
2222
#include "unary_ops.h"
2323
#include "binary_ops.h"
2424
#include "comparison_ops.h"
25+
#include "matmul.h"
2526

2627
// helper debugging function
2728
static const char *
@@ -101,10 +102,10 @@ init_quad_umath(void)
101102
goto err;
102103
}
103104

104-
// if (init_quad_matmul(numpy) < 0) {
105-
// PyErr_SetString(PyExc_RuntimeError, "Failed to initialize quad matrix multiplication operations");
106-
// goto err;
107-
// }
105+
if (init_matmul_ops(numpy) < 0) {
106+
PyErr_SetString(PyExc_RuntimeError, "Failed to initialize quad matrix multiplication operations");
107+
goto err;
108+
}
108109

109110
Py_DECREF(numpy);
110111
return 0;

0 commit comments

Comments
 (0)