#define NO_IMPORT_ARRAY
#define NO_IMPORT_UFUNC

+ extern "C" {
#include <Python.h>
#include <cstdio>
+ #include <string.h>

#include "numpy/arrayobject.h"
+ #include "numpy/ndarraytypes.h"
#include "numpy/ufuncobject.h"
#include "numpy/dtype_api.h"
- #include "numpy/ndarraytypes.h"
+ }

#include "../quad_common.h"
#include "../scalar.h"
#include "../dtype.h"
#include "../ops.hpp"
- #include "binary_ops.h"
#include "matmul.h"
+ #include "promoters.hpp"

- #include <iostream>
-
+ /**
+  * Resolve descriptors for matmul operation.
+  * Follows the same pattern as binary_ops.cpp
+  */
static NPY_CASTING
quad_matmul_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtypes[],
                                PyArray_Descr *const given_descrs[], PyArray_Descr *loop_descrs[],
                                npy_intp *NPY_UNUSED(view_offset))
{
-     NPY_CASTING casting = NPY_NO_CASTING;
-     std::cout << "exiting the descriptor";
-     return casting;
- }
+     // Follow the exact same pattern as quad_binary_op_resolve_descriptors
+     QuadPrecDTypeObject *descr_in1 = (QuadPrecDTypeObject *)given_descrs[0];
+     QuadPrecDTypeObject *descr_in2 = (QuadPrecDTypeObject *)given_descrs[1];
+     QuadBackendType target_backend;

- template <binary_op_quad_def sleef_op, binary_op_longdouble_def longdouble_op>
- int
- quad_generic_matmul_strided_loop_unaligned(PyArrayMethod_Context *context, char *const data[],
-                                            npy_intp const dimensions[], npy_intp const strides[],
-                                            NpyAuxData *auxdata)
- {
-     npy_intp N = dimensions[0];
-     char *in1_ptr = data[0], *in2_ptr = data[1];
-     char *out_ptr = data[2];
-     npy_intp in1_stride = strides[0];
-     npy_intp in2_stride = strides[1];
-     npy_intp out_stride = strides[2];
-
-     QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
-     QuadBackendType backend = descr->backend;
-     size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);
+     // Determine target backend and if casting is needed
+     NPY_CASTING casting = NPY_NO_CASTING;
+     if (descr_in1->backend != descr_in2->backend) {
+         target_backend = BACKEND_LONGDOUBLE;
+         casting = NPY_SAFE_CASTING;
+     }
+     else {
+         target_backend = descr_in1->backend;
+     }

-     quad_value in1, in2, out;
-     while (N--) {
-         memcpy(&in1, in1_ptr, elem_size);
-         memcpy(&in2, in2_ptr, elem_size);
-         if (backend == BACKEND_SLEEF) {
-             out.sleef_value = sleef_op(&in1.sleef_value, &in2.sleef_value);
+     // Set up input descriptors, casting if necessary
+     for (int i = 0; i < 2; i++) {
+         if (((QuadPrecDTypeObject *)given_descrs[i])->backend != target_backend) {
+             loop_descrs[i] = (PyArray_Descr *)new_quaddtype_instance(target_backend);
+             if (!loop_descrs[i]) {
+                 return (NPY_CASTING)-1;
+             }
        }
        else {
-             out.longdouble_value = longdouble_op(&in1.longdouble_value, &in2.longdouble_value);
+             Py_INCREF(given_descrs[i]);
+             loop_descrs[i] = given_descrs[i];
        }
-         memcpy(out_ptr, &out, elem_size);
+     }

-         in1_ptr += in1_stride;
-         in2_ptr += in2_stride;
-         out_ptr += out_stride;
+     // Set up output descriptor
+     if (given_descrs[2] == NULL) {
+         loop_descrs[2] = (PyArray_Descr *)new_quaddtype_instance(target_backend);
+         if (!loop_descrs[2]) {
+             return (NPY_CASTING)-1;
+         }
    }
-     return 0;
+     else {
+         QuadPrecDTypeObject *descr_out = (QuadPrecDTypeObject *)given_descrs[2];
+         if (descr_out->backend != target_backend) {
+             loop_descrs[2] = (PyArray_Descr *)new_quaddtype_instance(target_backend);
+             if (!loop_descrs[2]) {
+                 return (NPY_CASTING)-1;
+             }
+         }
+         else {
+             Py_INCREF(given_descrs[2]);
+             loop_descrs[2] = given_descrs[2];
+         }
+     }
+     return casting;
}

- template <binary_op_quad_def sleef_op, binary_op_longdouble_def longdouble_op>
- int
- quad_generic_matmul_strided_loop_aligned(PyArrayMethod_Context *context, char *const data[],
-                                          npy_intp const dimensions[], npy_intp const strides[],
-                                          NpyAuxData *auxdata)
+ /**
+  * Matrix multiplication strided loop using NumPy 2.0 API.
+  * Implements general matrix multiplication for arbitrary dimensions.
+  *
+  * For matmul with signature (m?,n),(n,p?)->(m?,p?):
+  * - dimensions[0] = N (loop dimension, number of batch operations)
+  * - dimensions[1] = m (rows of first matrix)
+  * - dimensions[2] = n (cols of first matrix / rows of second matrix)
+  * - dimensions[3] = p (cols of second matrix)
+  *
+  * - strides[0], strides[1], strides[2] = batch strides for A, B, C
+  * - strides[3], strides[4] = row stride, col stride for A (m, n)
+  * - strides[5], strides[6] = row stride, col stride for B (n, p)
+  * - strides[7], strides[8] = row stride, col stride for C (m, p)
+  */
+ static int
+ quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
+                          npy_intp const dimensions[], npy_intp const strides[], NpyAuxData *auxdata)
{
-     npy_intp N = dimensions[0];
-     char *in1_ptr = data[0], *in2_ptr = data[1];
-     char *out_ptr = data[2];
-     npy_intp in1_stride = strides[0];
-     npy_intp in2_stride = strides[1];
-     npy_intp out_stride = strides[2];
-
+     // Extract dimensions
+     npy_intp N = dimensions[0];  // Number of batch operations
+     npy_intp m = dimensions[1];  // Rows of first matrix
+     npy_intp n = dimensions[2];  // Cols of first matrix / rows of second matrix
+     npy_intp p = dimensions[3];  // Cols of second matrix
+
+     // Extract batch strides
+     npy_intp A_batch_stride = strides[0];
+     npy_intp B_batch_stride = strides[1];
+     npy_intp C_batch_stride = strides[2];
+
+     // Extract core strides for matrix dimensions
+     npy_intp A_row_stride = strides[3];  // Stride along m dimension of A
+     npy_intp A_col_stride = strides[4];  // Stride along n dimension of A
+     npy_intp B_row_stride = strides[5];  // Stride along n dimension of B
+     npy_intp B_col_stride = strides[6];  // Stride along p dimension of B
+     npy_intp C_row_stride = strides[7];  // Stride along m dimension of C
+     npy_intp C_col_stride = strides[8];  // Stride along p dimension of C
+
+     // Get backend from descriptor
    QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
    QuadBackendType backend = descr->backend;
+     size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);

-     while (N--) {
-         if (backend == BACKEND_SLEEF) {
-             *(Sleef_quad *)out_ptr = sleef_op((Sleef_quad *)in1_ptr, (Sleef_quad *)in2_ptr);
-         }
-         else {
-             *(long double *)out_ptr = longdouble_op((long double *)in1_ptr, (long double *)in2_ptr);
+     // Process each batch
+     for (npy_intp batch = 0; batch < N; batch++) {
+         char *A_batch = data[0] + batch * A_batch_stride;
+         char *B_batch = data[1] + batch * B_batch_stride;
+         char *C_batch = data[2] + batch * C_batch_stride;
+
+         // Perform matrix multiplication: C = A @ B
+         // C[i,j] = sum_k(A[i,k] * B[k,j])
+         for (npy_intp i = 0; i < m; i++) {
+             for (npy_intp j = 0; j < p; j++) {
+                 char *C_ij = C_batch + i * C_row_stride + j * C_col_stride;
+
+                 if (backend == BACKEND_SLEEF) {
+                     Sleef_quad sum = Sleef_cast_from_doubleq1(0.0);  // Initialize to 0
+
+                     for (npy_intp k = 0; k < n; k++) {
+                         char *A_ik = A_batch + i * A_row_stride + k * A_col_stride;
+                         char *B_kj = B_batch + k * B_row_stride + j * B_col_stride;
+
+                         Sleef_quad a_val, b_val;
+                         memcpy(&a_val, A_ik, sizeof(Sleef_quad));
+                         memcpy(&b_val, B_kj, sizeof(Sleef_quad));
+
+                         // sum += A[i,k] * B[k,j]
+                         sum = Sleef_addq1_u05(sum, Sleef_mulq1_u05(a_val, b_val));
+                     }
+
+                     memcpy(C_ij, &sum, sizeof(Sleef_quad));
+                 }
+                 else {
+                     // Long double backend
+                     long double sum = 0.0L;
+
+                     for (npy_intp k = 0; k < n; k++) {
+                         char *A_ik = A_batch + i * A_row_stride + k * A_col_stride;
+                         char *B_kj = B_batch + k * B_row_stride + j * B_col_stride;
+
+                         long double a_val, b_val;
+                         memcpy(&a_val, A_ik, sizeof(long double));
+                         memcpy(&b_val, B_kj, sizeof(long double));
+
+                         sum += a_val * b_val;
+                     }
+
+                     memcpy(C_ij, &sum, sizeof(long double));
+                 }
+             }
        }
-
-         in1_ptr += in1_stride;
-         in2_ptr += in2_stride;
-         out_ptr += out_stride;
    }
+
    return 0;
}

- template <binary_op_quad_def sleef_op, binary_op_longdouble_def longdouble_op>
+ /**
+  * Register matmul support following the exact same pattern as binary_ops.cpp
+  */
int
- create_matmul_ufunc(PyObject *numpy, const char *ufunc_name)
+ init_matmul_ops(PyObject *numpy)
{
-     PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
+     printf("DEBUG: init_matmul_ops - registering matmul using NumPy 2.0 API\n");
+
+     // Get the existing matmul ufunc - same pattern as binary_ops
+     PyObject *ufunc = PyObject_GetAttrString(numpy, "matmul");
    if (ufunc == NULL) {
+         printf("DEBUG: Failed to get numpy.matmul\n");
        return -1;
    }

+     // Use the same pattern as binary_ops.cpp
    PyArray_DTypeMeta *dtypes[3] = {&QuadPrecDType, &QuadPrecDType, &QuadPrecDType};

-     PyType_Slot slots[] = {
-         {NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
-         {NPY_METH_strided_loop,
-          (void *)&quad_generic_matmul_strided_loop_aligned<sleef_op, longdouble_op>},
-         {NPY_METH_unaligned_strided_loop,
-          (void *)&quad_generic_matmul_strided_loop_unaligned<sleef_op, longdouble_op>},
-         {0, NULL}};
+     PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
+                            {NPY_METH_strided_loop, (void *)&quad_matmul_strided_loop},
+                            {NPY_METH_unaligned_strided_loop, (void *)&quad_matmul_strided_loop},
+                            {0, NULL}};

    PyArrayMethod_Spec Spec = {
        .name = "quad_matmul",
        .nin = 2,
        .nout = 1,
        .casting = NPY_NO_CASTING,
-         .flags = (NPY_ARRAYMETHOD_FLAGS)(NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_IS_REORDERABLE),
+         .flags = NPY_METH_SUPPORTS_UNALIGNED,
        .dtypes = dtypes,
        .slots = slots,
    };

+     printf("DEBUG: About to add loop to matmul ufunc...\n");
+
    if (PyUFunc_AddLoopFromSpec(ufunc, &Spec) < 0) {
+         printf("DEBUG: Failed to add loop to matmul ufunc\n");
+         Py_DECREF(ufunc);
        return -1;
    }
-     // my guess we don't need any promoter here as of now, since matmul is quad specific
-     return 0;
- }

- int
- init_matmul_ops(PyObject *numpy)
- {
-     if (create_matmul_ufunc<quad_add, ld_add>(numpy, "matmul") < 0) {
+     printf("DEBUG: Successfully added matmul loop!\n");
+
+     // Add promoter following binary_ops pattern
+     PyObject *promoter_capsule =
+             PyCapsule_New((void *)&quad_ufunc_promoter, "numpy._ufunc_promoter", NULL);
+     if (promoter_capsule == NULL) {
+         Py_DECREF(ufunc);
+         return -1;
+     }
+
+     PyObject *DTypes = PyTuple_Pack(3, &PyArrayDescr_Type, &PyArrayDescr_Type, &PyArrayDescr_Type);
+     if (DTypes == NULL) {
+         Py_DECREF(promoter_capsule);
+         Py_DECREF(ufunc);
        return -1;
    }
+
+     if (PyUFunc_AddPromoter(ufunc, DTypes, promoter_capsule) < 0) {
+         printf("DEBUG: Failed to add promoter (continuing anyway)\n");
+         PyErr_Clear();  // Don't fail if promoter fails
+     }
+     else {
+         printf("DEBUG: Successfully added promoter\n");
+     }
+
+     Py_DECREF(DTypes);
+     Py_DECREF(promoter_capsule);
+     Py_DECREF(ufunc);
+
+     printf("DEBUG: init_matmul_ops completed successfully\n");
    return 0;
- }
+ }
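
For reference, a minimal sketch (not part of the patch) of how the dimensions/strides layout documented above quad_matmul_strided_loop would look for a single C-contiguous 2x3 @ 3x4 product, assuming the SLEEF backend and sizeof(Sleef_quad) == 16; the batch strides are shown as 0 since only one batch is processed:

    /* Hypothetical argument layout for one 2x3 @ 3x4 matmul over C-contiguous Sleef_quad data. */
    npy_intp dimensions[4] = {1, 2, 3, 4};  /* N = 1 batch, m = 2, n = 3, p = 4 */
    npy_intp strides[9] = {
        0, 0, 0,         /* batch strides for A, B, C (single batch)                */
        3 * 16, 16,      /* A: row stride = n * elem_size, col stride = elem_size   */
        4 * 16, 16,      /* B: row stride = p * elem_size, col stride = elem_size   */
        4 * 16, 16,      /* C: row stride = p * elem_size, col stride = elem_size   */
    };
    /* With these inputs the loop computes C[i][j] = sum_k A[i][k] * B[k][j]. */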