@@ -24,10 +24,6 @@ extern "C" {
 #include "promoters.hpp"
 #include "../quadblas_interface.h"
+#include <new>  // std::nothrow, used by the GEMM scratch allocations below
 
-/**
- * Resolve descriptors for matmul operation.
- * Only supports SLEEF backend when QBLAS is enabled.
- */
 static NPY_CASTING
 quad_matmul_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtypes[],
                                 PyArray_Descr *const given_descrs[], PyArray_Descr *loop_descrs[],
@@ -76,23 +72,15 @@ quad_matmul_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtypes[
     return casting;
 }
 
-/**
- * Determine the type of operation based on input dimensions
- */
 enum MatmulOperationType {
-    MATMUL_DOT,   // 1D x 1D -> scalar
-    MATMUL_GEMV,  // 2D x 1D -> 1D
-    MATMUL_GEMM   // 2D x 2D -> 2D
+    MATMUL_DOT,
+    MATMUL_GEMV,
+    MATMUL_GEMM
 };
 
 static MatmulOperationType
 determine_operation_type(npy_intp m, npy_intp n, npy_intp p)
 {
-    // For matmul signature (m?,n),(n,p?)->(m?,p?):
-    // - If m=1 and p=1: vector dot product (1D x 1D)
-    // - If p=1: matrix-vector multiplication (2D x 1D)
-    // - Otherwise: matrix-matrix multiplication (2D x 2D)
-
     if (m == 1 && p == 1) {
         return MATMUL_DOT;
     }
@@ -104,10 +92,6 @@ determine_operation_type(npy_intp m, npy_intp n, npy_intp p)
     }
 }
 
-/**
- * Matrix multiplication strided loop using QBLAS.
- * Automatically selects the appropriate QBLAS operation based on input dimensions.
- */
 static int
 quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
                          npy_intp const dimensions[], npy_intp const strides[], NpyAuxData *auxdata)
@@ -118,43 +102,29 @@ quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
     npy_intp n = dimensions[2];  // Cols of first matrix / rows of second matrix
     npy_intp p = dimensions[3];  // Cols of second matrix
 
-    // Extract batch strides
+    // batch strides
     npy_intp A_stride = strides[0];
     npy_intp B_stride = strides[1];
     npy_intp C_stride = strides[2];
 
-    // Extract core strides for matrix dimensions
+    // core strides for matrix dimensions
     npy_intp A_row_stride = strides[3];
     npy_intp A_col_stride = strides[4];
     npy_intp B_row_stride = strides[5];
     npy_intp B_col_stride = strides[6];
     npy_intp C_row_stride = strides[7];
     npy_intp C_col_stride = strides[8];
 
-    // Note: B_col_stride and C_col_stride not needed for row-major QBLAS calls
-
-    // Get backend from descriptor (should be SLEEF only)
     QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
     if (descr->backend != BACKEND_SLEEF) {
         PyErr_SetString(PyExc_RuntimeError, "Internal error: non-SLEEF backend in QBLAS matmul");
         return -1;
     }
 
-    // Determine operation type
     MatmulOperationType op_type = determine_operation_type(m, n, p);
-
-    // Constants for QBLAS
     Sleef_quad alpha = Sleef_cast_from_doubleq1(1.0);
     Sleef_quad beta = Sleef_cast_from_doubleq1(0.0);
 
-    // print all information for debugging
-    printf("DEBUG: Performing %ld batch operations with dimensions (%ld, %ld, %ld)\n", (long)N,
-           (long)m, (long)n, (long)p);
-    printf("DEBUG: Strides - A: (%ld, %ld), B: (%ld, %ld), C: (%ld, %ld)\n", (long)A_row_stride,
-           (long)A_col_stride, (long)B_row_stride, (long)B_col_stride, (long)C_row_stride,
-           (long)C_col_stride);
-    printf("DEBUG: Operation type: %d\n", op_type);
-
     char *A = data[0];
     char *B = data[1];
     char *C = data[2];
@@ -167,13 +137,6 @@ quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
 
     switch (op_type) {
         case MATMUL_DOT: {
-            // Vector dot product: C = A^T * B (both are vectors)
-            // A has shape (1, n), B has shape (n, 1), C has shape (1, 1)
-
-            printf("DEBUG: Using QBLAS dot product for %ld elements\n", (long)n);
-
-            // A is effectively a vector of length n
-            // B is effectively a vector of length n
             size_t incx = A_col_stride / sizeof(Sleef_quad);
             size_t incy = B_row_stride / sizeof(Sleef_quad);
 
@@ -182,12 +145,6 @@ quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
         }
 
         case MATMUL_GEMV: {
-            // Matrix-vector multiplication: C = A * B
-            // A has shape (m, n), B has shape (n, 1), C has shape (m, 1)
-
-            printf("DEBUG: Using QBLAS GEMV for %ldx%ld matrix times %ld vector\n", (long)m,
-                   (long)n, (long)n);
-
             size_t lda = A_row_stride / sizeof(Sleef_quad);
             size_t incx = B_row_stride / sizeof(Sleef_quad);
             size_t incy = C_row_stride / sizeof(Sleef_quad);
@@ -198,17 +155,46 @@ quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
         }
 
         case MATMUL_GEMM: {
-            // Matrix-matrix multiplication: C = A * B
-            // A has shape (m, n), B has shape (n, p), C has shape (m, p)
-
-            printf("DEBUG: Using QBLAS GEMM for %ldx%ldx%ld matrices\n", (long)m, (long)n, (long)p);
-
             size_t lda = A_row_stride / sizeof(Sleef_quad);
             size_t ldb = B_row_stride / sizeof(Sleef_quad);
-            size_t ldc = C_row_stride / sizeof(Sleef_quad);
+            size_t ldc_numpy = C_row_stride / sizeof(Sleef_quad);
+
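+            // Stage the inputs through contiguous scratch buffers before calling
+            // qblas_gemm; note the flat memcpy below assumes A and B are dense,
+            // row-major blocks.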
+            Sleef_quad *temp_A_buffer = new (std::nothrow) Sleef_quad[m * n];
+            if (!temp_A_buffer) {
+                PyErr_SetString(PyExc_MemoryError, "Failed to allocate temporary buffer for GEMM");
+                return -1;
+            }
+            Sleef_quad *temp_B_buffer = new (std::nothrow) Sleef_quad[n * p];
+            if (!temp_B_buffer) {
+                PyErr_SetString(PyExc_MemoryError, "Failed to allocate temporary buffer for GEMM");
+                delete[] temp_A_buffer;
+                return -1;
+            }
+            memcpy(temp_A_buffer, A_ptr, m * n * sizeof(Sleef_quad));
+            memcpy(temp_B_buffer, B_ptr, n * p * sizeof(Sleef_quad));
+            A_ptr = temp_A_buffer;
+            B_ptr = temp_B_buffer;
+
+            Sleef_quad *temp_C_buffer = new (std::nothrow) Sleef_quad[m * p];
+            if (!temp_C_buffer) {
+                PyErr_SetString(PyExc_MemoryError,
+                                "Failed to allocate temporary buffer for GEMM result");
+                delete[] temp_A_buffer;
+                delete[] temp_B_buffer;
+                return -1;
+            }
+
+            size_t ldc_temp = p;
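+            // Run GEMM into the contiguous temporary (leading dimension p),
+            // then copy the finished rows back out into C.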
 
             result = qblas_gemm('R', 'N', 'N', m, p, n, &alpha, A_ptr, lda, B_ptr, ldb, &beta,
-                                C_ptr, ldc);
+                                temp_C_buffer, ldc_temp);
+
+            if (result == 0) {
+                // Copy back row by row so an output with ldc_numpy != p stays correct.
+                for (npy_intp i = 0; i < m; i++) {
+                    memcpy(C_ptr + i * ldc_numpy, temp_C_buffer + i * p,
+                           p * sizeof(Sleef_quad));
+                }
+            }
+
+            delete[] temp_C_buffer;
+            delete[] temp_A_buffer;
+            delete[] temp_B_buffer;
             break;
         }
     }
@@ -221,27 +207,91 @@ quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
     return 0;
 }
 
-/**
- * Register matmul support with QBLAS acceleration
- */
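+/**
+ * Stride-aware reference matmul for both backends: a plain triple loop over
+ * (m, n, p) that loads and stores every element through memcpy.
+ */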
+static int
+naive_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
+                          npy_intp const dimensions[], npy_intp const strides[],
+                          NpyAuxData *auxdata)
+{
+    npy_intp N = dimensions[0];
+    npy_intp m = dimensions[1];
+    npy_intp n = dimensions[2];
+    npy_intp p = dimensions[3];
+
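+    // strides[0..2] are the batch strides; strides[3..8] are the row/column
+    // strides for the (m?,n),(n,p?)->(m?,p?) core.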
+    npy_intp A_batch_stride = strides[0];
+    npy_intp B_batch_stride = strides[1];
+    npy_intp C_batch_stride = strides[2];
+
+    npy_intp A_row_stride = strides[3];
+    npy_intp A_col_stride = strides[4];
+    npy_intp B_row_stride = strides[5];
+    npy_intp B_col_stride = strides[6];
+    npy_intp C_row_stride = strides[7];
+    npy_intp C_col_stride = strides[8];
+
+    QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
+    QuadBackendType backend = descr->backend;
+
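+    // Textbook triple loop: C[i, j] = sum over k of A[i, k] * B[k, j].
+    // Every load and store goes through memcpy, so the same loop is safe to
+    // register as the unaligned strided loop.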
+    for (npy_intp batch = 0; batch < N; batch++) {
+        char *A_batch = data[0] + batch * A_batch_stride;
+        char *B_batch = data[1] + batch * B_batch_stride;
+        char *C_batch = data[2] + batch * C_batch_stride;
+
+        for (npy_intp i = 0; i < m; i++) {
+            for (npy_intp j = 0; j < p; j++) {
+                char *C_ij = C_batch + i * C_row_stride + j * C_col_stride;
+
+                if (backend == BACKEND_SLEEF) {
+                    Sleef_quad sum = Sleef_cast_from_doubleq1(0.0);
+
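+                    // Accumulate with fused multiply-add: a single rounding per
+                    // term keeps the quad-precision dot product tight.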
+                    for (npy_intp k = 0; k < n; k++) {
+                        char *A_ik = A_batch + i * A_row_stride + k * A_col_stride;
+                        char *B_kj = B_batch + k * B_row_stride + j * B_col_stride;
+
+                        Sleef_quad a_val, b_val;
+                        memcpy(&a_val, A_ik, sizeof(Sleef_quad));
+                        memcpy(&b_val, B_kj, sizeof(Sleef_quad));
+                        sum = Sleef_fmaq1_u05(a_val, b_val, sum);
+                    }
+
+                    memcpy(C_ij, &sum, sizeof(Sleef_quad));
+                }
+                else {
+                    long double sum = 0.0L;
+
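+                    // longdouble backend: plain multiply-accumulate at native
+                    // long double precision.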
+                    for (npy_intp k = 0; k < n; k++) {
+                        char *A_ik = A_batch + i * A_row_stride + k * A_col_stride;
+                        char *B_kj = B_batch + k * B_row_stride + j * B_col_stride;
+
+                        long double a_val, b_val;
+                        memcpy(&a_val, A_ik, sizeof(long double));
+                        memcpy(&b_val, B_kj, sizeof(long double));
+
+                        sum += a_val * b_val;
+                    }
+
+                    memcpy(C_ij, &sum, sizeof(long double));
+                }
+            }
+        }
+    }
+
+    return 0;
+}
+
 int
 init_matmul_ops(PyObject *numpy)
 {
-    printf("DEBUG: init_matmul_ops - registering QBLAS-accelerated matmul\n");
-
-    // Get the existing matmul ufunc
     PyObject *ufunc = PyObject_GetAttrString(numpy, "matmul");
     if (ufunc == NULL) {
-        printf("DEBUG: Failed to get numpy.matmul\n");
         return -1;
     }
 
-    // Setup method specification for QBLAS-accelerated matmul
    PyArray_DTypeMeta *dtypes[3] = {&QuadPrecDType, &QuadPrecDType, &QuadPrecDType};
 
     PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
-                           {NPY_METH_strided_loop, (void *)&quad_matmul_strided_loop},
-                           {NPY_METH_unaligned_strided_loop, (void *)&quad_matmul_strided_loop},
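+                           // The naive loop reads and writes through memcpy, so one
+                           // implementation serves both the aligned and unaligned slots.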
+                           {NPY_METH_strided_loop, (void *)&naive_matmul_strided_loop},
+                           {NPY_METH_unaligned_strided_loop, (void *)&naive_matmul_strided_loop},
                            {0, NULL}};
 
     PyArrayMethod_Spec Spec = {
@@ -254,17 +304,11 @@ init_matmul_ops(PyObject *numpy)
         .slots = slots,
     };
 
-    printf("DEBUG: About to add QBLAS loop to matmul ufunc...\n");
-
     if (PyUFunc_AddLoopFromSpec(ufunc, &Spec) < 0) {
-        printf("DEBUG: Failed to add QBLAS loop to matmul ufunc\n");
         Py_DECREF(ufunc);
         return -1;
     }
 
-    printf("DEBUG: Successfully added QBLAS matmul loop!\n");
-
-    // Add promoter
     PyObject *promoter_capsule =
             PyCapsule_New((void *)&quad_ufunc_promoter, "numpy._ufunc_promoter", NULL);
     if (promoter_capsule == NULL) {
@@ -280,17 +324,14 @@ init_matmul_ops(PyObject *numpy)
     }
 
     if (PyUFunc_AddPromoter(ufunc, DTypes, promoter_capsule) < 0) {
-        printf("DEBUG: Failed to add promoter (continuing anyway)\n");
         PyErr_Clear();  // Don't fail if promoter fails
     }
     else {
-        printf("DEBUG: Successfully added promoter\n");
     }
 
     Py_DECREF(DTypes);
     Py_DECREF(promoter_capsule);
     Py_DECREF(ufunc);
 
-    printf("DEBUG: init_matmul_ops completed successfully with QBLAS acceleration\n");
     return 0;
 }