
Commit 894a84d

matmul registered with naive
1 parent f89c2e6 commit 894a84d

File tree

3 files changed (+225, −83 lines)


quaddtype/numpy_quaddtype/src/umath/matmul.cpp

Lines changed: 187 additions & 77 deletions
@@ -5,141 +5,251 @@
 #define NO_IMPORT_ARRAY
 #define NO_IMPORT_UFUNC
 
+extern "C" {
 #include <Python.h>
 #include <cstdio>
+#include <string.h>
 
 #include "numpy/arrayobject.h"
+#include "numpy/ndarraytypes.h"
 #include "numpy/ufuncobject.h"
 #include "numpy/dtype_api.h"
-#include "numpy/ndarraytypes.h"
+}
 
 #include "../quad_common.h"
 #include "../scalar.h"
 #include "../dtype.h"
 #include "../ops.hpp"
-#include "binary_ops.h"
 #include "matmul.h"
+#include "promoters.hpp"
 
-#include <iostream>
-
+/**
+ * Resolve descriptors for matmul operation.
+ * Follows the same pattern as binary_ops.cpp
+ */
 static NPY_CASTING
 quad_matmul_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtypes[],
                                 PyArray_Descr *const given_descrs[], PyArray_Descr *loop_descrs[],
                                 npy_intp *NPY_UNUSED(view_offset))
 {
-    NPY_CASTING casting = NPY_NO_CASTING;
-    std::cout << "exiting the descriptor";
-    return casting;
-}
+    // Follow the exact same pattern as quad_binary_op_resolve_descriptors
+    QuadPrecDTypeObject *descr_in1 = (QuadPrecDTypeObject *)given_descrs[0];
+    QuadPrecDTypeObject *descr_in2 = (QuadPrecDTypeObject *)given_descrs[1];
+    QuadBackendType target_backend;
 
-template <binary_op_quad_def sleef_op, binary_op_longdouble_def longdouble_op>
-int
-quad_generic_matmul_strided_loop_unaligned(PyArrayMethod_Context *context, char *const data[],
-                                           npy_intp const dimensions[], npy_intp const strides[],
-                                           NpyAuxData *auxdata)
-{
-    npy_intp N = dimensions[0];
-    char *in1_ptr = data[0], *in2_ptr = data[1];
-    char *out_ptr = data[2];
-    npy_intp in1_stride = strides[0];
-    npy_intp in2_stride = strides[1];
-    npy_intp out_stride = strides[2];
-
-    QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
-    QuadBackendType backend = descr->backend;
-    size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);
+    // Determine target backend and if casting is needed
+    NPY_CASTING casting = NPY_NO_CASTING;
+    if (descr_in1->backend != descr_in2->backend) {
+        target_backend = BACKEND_LONGDOUBLE;
+        casting = NPY_SAFE_CASTING;
+    }
+    else {
+        target_backend = descr_in1->backend;
+    }
 
-    quad_value in1, in2, out;
-    while (N--) {
-        memcpy(&in1, in1_ptr, elem_size);
-        memcpy(&in2, in2_ptr, elem_size);
-        if (backend == BACKEND_SLEEF) {
-            out.sleef_value = sleef_op(&in1.sleef_value, &in2.sleef_value);
+    // Set up input descriptors, casting if necessary
+    for (int i = 0; i < 2; i++) {
+        if (((QuadPrecDTypeObject *)given_descrs[i])->backend != target_backend) {
+            loop_descrs[i] = (PyArray_Descr *)new_quaddtype_instance(target_backend);
+            if (!loop_descrs[i]) {
+                return (NPY_CASTING)-1;
+            }
         }
         else {
-            out.longdouble_value = longdouble_op(&in1.longdouble_value, &in2.longdouble_value);
+            Py_INCREF(given_descrs[i]);
+            loop_descrs[i] = given_descrs[i];
         }
-        memcpy(out_ptr, &out, elem_size);
+    }
 
-        in1_ptr += in1_stride;
-        in2_ptr += in2_stride;
-        out_ptr += out_stride;
+    // Set up output descriptor
+    if (given_descrs[2] == NULL) {
+        loop_descrs[2] = (PyArray_Descr *)new_quaddtype_instance(target_backend);
+        if (!loop_descrs[2]) {
+            return (NPY_CASTING)-1;
+        }
     }
-    return 0;
+    else {
+        QuadPrecDTypeObject *descr_out = (QuadPrecDTypeObject *)given_descrs[2];
+        if (descr_out->backend != target_backend) {
+            loop_descrs[2] = (PyArray_Descr *)new_quaddtype_instance(target_backend);
+            if (!loop_descrs[2]) {
+                return (NPY_CASTING)-1;
+            }
+        }
+        else {
+            Py_INCREF(given_descrs[2]);
+            loop_descrs[2] = given_descrs[2];
+        }
+    }
+    return casting;
 }
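The resolution logic above reduces to a small decision table: matching backends pass through unchanged, while mixed SLEEF/long-double operands are promoted to the long-double backend and reported as a safe cast. A minimal standalone sketch of just that rule (the enum and function here are illustrative stand-ins, not part of the extension):

    #include <cstdio>

    // Illustrative stand-ins for the dtype's backend tags.
    enum Backend { SLEEF, LONGDOUBLE };

    // Mirrors the rule in quad_matmul_resolve_descriptors: equal backends
    // are kept; mixed backends fall back to long double with safe casting.
    static Backend resolve_target(Backend in1, Backend in2, bool *needs_cast)
    {
        *needs_cast = (in1 != in2);
        return (in1 == in2) ? in1 : LONGDOUBLE;
    }

    int main()
    {
        bool needs_cast;
        Backend t = resolve_target(SLEEF, LONGDOUBLE, &needs_cast);
        std::printf("target=%s casting=%s\n",
                    t == SLEEF ? "sleef" : "longdouble",
                    needs_cast ? "safe" : "none");  // target=longdouble casting=safe
        return 0;
    }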
 
-template <binary_op_quad_def sleef_op, binary_op_longdouble_def longdouble_op>
-int
-quad_generic_matmul_strided_loop_aligned(PyArrayMethod_Context *context, char *const data[],
-                                         npy_intp const dimensions[], npy_intp const strides[],
-                                         NpyAuxData *auxdata)
+/**
+ * Matrix multiplication strided loop using NumPy 2.0 API.
+ * Implements general matrix multiplication for arbitrary dimensions.
+ *
+ * For matmul with signature (m?,n),(n,p?)->(m?,p?):
+ * - dimensions[0] = N (loop dimension, number of batch operations)
+ * - dimensions[1] = m (rows of first matrix)
+ * - dimensions[2] = n (cols of first matrix / rows of second matrix)
+ * - dimensions[3] = p (cols of second matrix)
+ *
+ * - strides[0], strides[1], strides[2] = batch strides for A, B, C
+ * - strides[3], strides[4] = row stride, col stride for A (m, n)
+ * - strides[5], strides[6] = row stride, col stride for B (n, p)
+ * - strides[7], strides[8] = row stride, col stride for C (m, p)
+ */
+static int
+quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
+                         npy_intp const dimensions[], npy_intp const strides[], NpyAuxData *auxdata)
 {
-    npy_intp N = dimensions[0];
-    char *in1_ptr = data[0], *in2_ptr = data[1];
-    char *out_ptr = data[2];
-    npy_intp in1_stride = strides[0];
-    npy_intp in2_stride = strides[1];
-    npy_intp out_stride = strides[2];
-
+    // Extract dimensions
+    npy_intp N = dimensions[0];  // Number of batch operations
+    npy_intp m = dimensions[1];  // Rows of first matrix
+    npy_intp n = dimensions[2];  // Cols of first matrix / rows of second matrix
+    npy_intp p = dimensions[3];  // Cols of second matrix
+
+    // Extract batch strides
+    npy_intp A_batch_stride = strides[0];
+    npy_intp B_batch_stride = strides[1];
+    npy_intp C_batch_stride = strides[2];
+
+    // Extract core strides for matrix dimensions
+    npy_intp A_row_stride = strides[3];  // Stride along m dimension of A
+    npy_intp A_col_stride = strides[4];  // Stride along n dimension of A
+    npy_intp B_row_stride = strides[5];  // Stride along n dimension of B
+    npy_intp B_col_stride = strides[6];  // Stride along p dimension of B
+    npy_intp C_row_stride = strides[7];  // Stride along m dimension of C
+    npy_intp C_col_stride = strides[8];  // Stride along p dimension of C
+
+    // Get backend from descriptor
     QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
     QuadBackendType backend = descr->backend;
+    size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);
 
-    while (N--) {
-        if (backend == BACKEND_SLEEF) {
-            *(Sleef_quad *)out_ptr = sleef_op((Sleef_quad *)in1_ptr, (Sleef_quad *)in2_ptr);
-        }
-        else {
-            *(long double *)out_ptr = longdouble_op((long double *)in1_ptr, (long double *)in2_ptr);
+    // Process each batch
+    for (npy_intp batch = 0; batch < N; batch++) {
+        char *A_batch = data[0] + batch * A_batch_stride;
+        char *B_batch = data[1] + batch * B_batch_stride;
+        char *C_batch = data[2] + batch * C_batch_stride;
+
+        // Perform matrix multiplication: C = A @ B
+        // C[i,j] = sum_k(A[i,k] * B[k,j])
+        for (npy_intp i = 0; i < m; i++) {
+            for (npy_intp j = 0; j < p; j++) {
+                char *C_ij = C_batch + i * C_row_stride + j * C_col_stride;
+
+                if (backend == BACKEND_SLEEF) {
+                    Sleef_quad sum = Sleef_cast_from_doubleq1(0.0);  // Initialize to 0
+
+                    for (npy_intp k = 0; k < n; k++) {
+                        char *A_ik = A_batch + i * A_row_stride + k * A_col_stride;
+                        char *B_kj = B_batch + k * B_row_stride + j * B_col_stride;
+
+                        Sleef_quad a_val, b_val;
+                        memcpy(&a_val, A_ik, sizeof(Sleef_quad));
+                        memcpy(&b_val, B_kj, sizeof(Sleef_quad));
+
+                        // sum += A[i,k] * B[k,j]
+                        sum = Sleef_addq1_u05(sum, Sleef_mulq1_u05(a_val, b_val));
+                    }
+
+                    memcpy(C_ij, &sum, sizeof(Sleef_quad));
+                }
+                else {
+                    // Long double backend
+                    long double sum = 0.0L;
+
+                    for (npy_intp k = 0; k < n; k++) {
+                        char *A_ik = A_batch + i * A_row_stride + k * A_col_stride;
+                        char *B_kj = B_batch + k * B_row_stride + j * B_col_stride;
+
+                        long double a_val, b_val;
+                        memcpy(&a_val, A_ik, sizeof(long double));
+                        memcpy(&b_val, B_kj, sizeof(long double));
+
+                        sum += a_val * b_val;
+                    }
+
+                    memcpy(C_ij, &sum, sizeof(long double));
+                }
+            }
         }
-
-        in1_ptr += in1_stride;
-        in2_ptr += in2_stride;
-        out_ptr += out_stride;
     }
+
     return 0;
 }
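Because every element access in the loop above goes through explicit byte strides and memcpy, the same function can serve both the aligned and unaligned slots, and it handles transposed or otherwise non-contiguous operands. The same indexing scheme, reduced to plain double in a self-contained check (all names here are illustrative, not part of the extension):

    #include <cstdio>
    #include <cstring>

    // Naive C = A @ B over raw byte buffers, using the same byte-stride
    // indexing as quad_matmul_strided_loop (single batch, double elements).
    static void matmul_strided(const char *A, const char *B, char *C,
                               long m, long n, long p,
                               long A_rs, long A_cs, long B_rs, long B_cs,
                               long C_rs, long C_cs)
    {
        for (long i = 0; i < m; i++) {
            for (long j = 0; j < p; j++) {
                double sum = 0.0;
                for (long k = 0; k < n; k++) {
                    double a, b;
                    std::memcpy(&a, A + i * A_rs + k * A_cs, sizeof(double));
                    std::memcpy(&b, B + k * B_rs + j * B_cs, sizeof(double));
                    sum += a * b;  // sum += A[i,k] * B[k,j]
                }
                std::memcpy(C + i * C_rs + j * C_cs, &sum, sizeof(double));
            }
        }
    }

    int main()
    {
        const long S = (long)sizeof(double);
        double A[6] = {1, 2, 3, 4, 5, 6};     // 2x3, C-contiguous
        double B[6] = {7, 8, 9, 10, 11, 12};  // 3x2, C-contiguous
        double C[4];
        matmul_strided((const char *)A, (const char *)B, (char *)C, 2, 3, 2,
                       3 * S, S, 2 * S, S, 2 * S, S);
        std::printf("%g %g %g %g\n", C[0], C[1], C[2], C[3]);  // 58 64 139 154
        return 0;
    }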
 
-template <binary_op_quad_def sleef_op, binary_op_longdouble_def longdouble_op>
+/**
+ * Register matmul support following the exact same pattern as binary_ops.cpp
+ */
 int
-create_matmul_ufunc(PyObject *numpy, const char *ufunc_name)
+init_matmul_ops(PyObject *numpy)
 {
-    PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
+    printf("DEBUG: init_matmul_ops - registering matmul using NumPy 2.0 API\n");
+
+    // Get the existing matmul ufunc - same pattern as binary_ops
+    PyObject *ufunc = PyObject_GetAttrString(numpy, "matmul");
     if (ufunc == NULL) {
+        printf("DEBUG: Failed to get numpy.matmul\n");
         return -1;
     }
 
+    // Use the same pattern as binary_ops.cpp
    PyArray_DTypeMeta *dtypes[3] = {&QuadPrecDType, &QuadPrecDType, &QuadPrecDType};
 
-    PyType_Slot slots[] = {
-        {NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
-        {NPY_METH_strided_loop,
-         (void *)&quad_generic_matmul_strided_loop_aligned<sleef_op, longdouble_op>},
-        {NPY_METH_unaligned_strided_loop,
-         (void *)&quad_generic_matmul_strided_loop_unaligned<sleef_op, longdouble_op>},
-        {0, NULL}};
+    PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
+                           {NPY_METH_strided_loop, (void *)&quad_matmul_strided_loop},
+                           {NPY_METH_unaligned_strided_loop, (void *)&quad_matmul_strided_loop},
+                           {0, NULL}};
 
     PyArrayMethod_Spec Spec = {
         .name = "quad_matmul",
         .nin = 2,
         .nout = 1,
         .casting = NPY_NO_CASTING,
-        .flags = (NPY_ARRAYMETHOD_FLAGS)(NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_IS_REORDERABLE),
+        .flags = NPY_METH_SUPPORTS_UNALIGNED,
         .dtypes = dtypes,
         .slots = slots,
     };
 
+    printf("DEBUG: About to add loop to matmul ufunc...\n");
+
     if (PyUFunc_AddLoopFromSpec(ufunc, &Spec) < 0) {
+        printf("DEBUG: Failed to add loop to matmul ufunc\n");
+        Py_DECREF(ufunc);
         return -1;
     }
-    // my guess we don't need any promoter here as of now, since matmul is quad specific
-    return 0;
-}
 
-int
-init_matmul_ops(PyObject *numpy)
-{
-    if (create_matmul_ufunc<quad_add, ld_add>(numpy, "matmul") < 0) {
+    printf("DEBUG: Successfully added matmul loop!\n");
+
+    // Add promoter following binary_ops pattern
+    PyObject *promoter_capsule =
+            PyCapsule_New((void *)&quad_ufunc_promoter, "numpy._ufunc_promoter", NULL);
+    if (promoter_capsule == NULL) {
+        Py_DECREF(ufunc);
+        return -1;
+    }
+
+    PyObject *DTypes = PyTuple_Pack(3, &PyArrayDescr_Type, &PyArrayDescr_Type, &PyArrayDescr_Type);
+    if (DTypes == NULL) {
+        Py_DECREF(promoter_capsule);
+        Py_DECREF(ufunc);
         return -1;
     }
+
+    if (PyUFunc_AddPromoter(ufunc, DTypes, promoter_capsule) < 0) {
+        printf("DEBUG: Failed to add promoter (continuing anyway)\n");
+        PyErr_Clear();  // Don't fail if promoter fails
+    }
+    else {
+        printf("DEBUG: Successfully added promoter\n");
+    }
+
+    Py_DECREF(DTypes);
+    Py_DECREF(promoter_capsule);
+    Py_DECREF(ufunc);
+
+    printf("DEBUG: init_matmul_ops completed successfully\n");
     return 0;
-}
+}
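The capsule registered above wraps quad_ufunc_promoter from promoters.hpp, which tells NumPy which concrete loop dtypes to try when the operands are not already QuadPrecDType. As a rough sketch of the shape such a promoter takes under NumPy 2.x's dtype API (the exact parameter types and qualifiers should be checked against numpy/dtype_api.h and promoters.hpp; this simplified stand-in just proposes QuadPrecDType for all three operands):

    // Hedged sketch, not the real quad_ufunc_promoter.
    static int
    example_quad_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
                          PyArray_DTypeMeta *signature[],
                          PyArray_DTypeMeta *new_op_dtypes[])
    {
        for (int i = 0; i < 3; i++) {
            Py_INCREF(&QuadPrecDType);  // entries are handed back as new references
            new_op_dtypes[i] = &QuadPrecDType;
        }
        return 0;
    }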
quaddtype/numpy_quaddtype/src/umath/matmul.h

Lines changed: 35 additions & 3 deletions
@@ -1,8 +1,40 @@
-#ifndef _QUADDTYPE_MATMUL_OPS_H
-#define _QUADDTYPE_MATMUL_OPS_H
+#ifndef _QUADDTYPE_MATMUL_H
+#define _QUADDTYPE_MATMUL_H
+
+/**
+ * Quad Precision Matrix Multiplication for NumPy
+ *
+ * This module implements matrix multiplication functionality for the QuadPrecDType
+ * by registering custom loops with numpy's matmul generalized ufunc.
+ *
+ * Supports all matmul operation types:
+ * - Vector-vector (dot product): (n,) @ (n,) -> scalar
+ * - Matrix-vector: (m,n) @ (n,) -> (m,)
+ * - Vector-matrix: (n,) @ (n,p) -> (p,)
+ * - Matrix-matrix: (m,n) @ (n,p) -> (m,p)
+ *
+ * Uses naive algorithms optimized for correctness rather than performance.
+ * For production use, consider integration with QBLAS optimized routines.
+ */
 
 #include <Python.h>
 
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * Initialize the matmul operations for the quad precision dtype.
+ * This function registers the matmul generalized ufunc with numpy.
+ *
+ * @param numpy The numpy module object
+ * @return 0 on success, -1 on failure
+ */
 int
 init_matmul_ops(PyObject *numpy);
-#endif
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // _QUADDTYPE_MATMUL_H
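With the new C-linkage guards in place, init_matmul_ops can be called from either C or C++ translation units. A hedged sketch of what a call site in the extension's init path might look like (the wrapper function and its placement are hypothetical; only PyImport_ImportModule, Py_DECREF, and the contract documented above are assumed):

    #include <Python.h>
    #include "matmul.h"

    // Hypothetical call-site sketch: import numpy, register the quad matmul
    // loop, and propagate failure the way init_matmul_ops reports it.
    static int
    register_quad_matmul(void)
    {
        PyObject *numpy = PyImport_ImportModule("numpy");
        if (numpy == NULL) {
            return -1;  // ImportError is already set
        }
        int rc = init_matmul_ops(numpy);  // 0 on success, -1 on failure
        Py_DECREF(numpy);
        return rc;
    }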
