@@ -93,8 +93,9 @@ determine_operation_type(npy_intp m, npy_intp n, npy_intp p)
 }
 
 static int
-quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
-                         npy_intp const dimensions[], npy_intp const strides[], NpyAuxData *auxdata)
+quad_matmul_strided_loop_aligned(PyArrayMethod_Context *context, char *const data[],
+                                 npy_intp const dimensions[], npy_intp const strides[],
+                                 NpyAuxData *auxdata)
 {
     // Extract dimensions
     npy_intp N = dimensions[0];  // Batch size, this remains always 1 for matmul afaik
@@ -149,6 +150,8 @@ quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
             size_t incx = B_row_stride / sizeof(Sleef_quad);
             size_t incy = C_row_stride / sizeof(Sleef_quad);
 
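+            // presumably defensive: zero C up front in case qblas_gemv reads it despite beta == 0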
+            memset(C_ptr, 0, m * p * sizeof(Sleef_quad));
+
             result =
                     qblas_gemv('R', 'N', m, n, &alpha, A_ptr, lda, B_ptr, incx, &beta, C_ptr, incy);
             break;
@@ -159,32 +162,132 @@ quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
             size_t ldb = B_row_stride / sizeof(Sleef_quad);
             size_t ldc_numpy = C_row_stride / sizeof(Sleef_quad);
 
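+            // aligned path: write straight into NumPy's output buffer, so the
+            // leading dimension is C's real row stride (ldc_numpy), not the packed p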
+            memset(C_ptr, 0, m * p * sizeof(Sleef_quad));
+
+            size_t ldc_temp = p;
+
+            result = qblas_gemm('R', 'N', 'N', m, p, n, &alpha, A_ptr, lda, B_ptr, ldb, &beta,
+                                C_ptr, ldc_numpy);
+            break;
+        }
+    }
+
+    if (result != 0) {
+        PyErr_SetString(PyExc_RuntimeError, "QBLAS operation failed");
+        return -1;
+    }
+
+    return 0;
+}
+
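+// Unaligned variant: NumPy dispatches here when operands are not guaranteed to
+// be aligned for Sleef_quad. QBLAS expects aligned storage, so the operands are
+// staged through freshly allocated (hence suitably aligned) temporary buffers.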
+static int
+quad_matmul_strided_loop_unaligned(PyArrayMethod_Context *context, char *const data[],
+                                   npy_intp const dimensions[], npy_intp const strides[],
+                                   NpyAuxData *auxdata)
+{
+    // Extract dimensions
+    npy_intp N = dimensions[0];  // Batch size, this remains always 1 for matmul afaik
+    npy_intp m = dimensions[1];  // Rows of first matrix
+    npy_intp n = dimensions[2];  // Cols of first matrix / rows of second matrix
+    npy_intp p = dimensions[3];  // Cols of second matrix
+
+    // batch strides
+    npy_intp A_stride = strides[0];
+    npy_intp B_stride = strides[1];
+    npy_intp C_stride = strides[2];
+
+    // core strides for matrix dimensions
+    npy_intp A_row_stride = strides[3];
+    npy_intp A_col_stride = strides[4];
+    npy_intp B_row_stride = strides[5];
+    npy_intp B_col_stride = strides[6];
+    npy_intp C_row_stride = strides[7];
+    npy_intp C_col_stride = strides[8];
+
+    QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
+    if (descr->backend != BACKEND_SLEEF) {
+        PyErr_SetString(PyExc_RuntimeError, "Internal error: non-SLEEF backend in QBLAS matmul");
+        return -1;
+    }
+
+    MatmulOperationType op_type = determine_operation_type(m, n, p);
+    Sleef_quad alpha = Sleef_cast_from_doubleq1(1.0);
+    Sleef_quad beta = Sleef_cast_from_doubleq1(0.0);
+
+    char *A = data[0];
+    char *B = data[1];
+    char *C = data[2];
+
+    Sleef_quad *A_ptr = (Sleef_quad *)A;
+    Sleef_quad *B_ptr = (Sleef_quad *)B;
+    Sleef_quad *C_ptr = (Sleef_quad *)C;
+
+    int result = -1;
+
+    switch (op_type) {
+        case MATMUL_DOT: {
+            Sleef_quad *temp_A_buffer = new Sleef_quad[n];
+            Sleef_quad *temp_B_buffer = new Sleef_quad[n];
+
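+            // memcpy copies bytes, so unaligned sources are fine; the new[]
+            // buffers themselves are suitably aligned for Sleef_quad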
+            memcpy(temp_A_buffer, A_ptr, n * sizeof(Sleef_quad));
+            memcpy(temp_B_buffer, B_ptr, n * sizeof(Sleef_quad));
+
+            size_t incx = 1;
+            size_t incy = 1;
+
+            result = qblas_dot(n, temp_A_buffer, incx, temp_B_buffer, incy, C_ptr);
+
+            delete[] temp_A_buffer;
+            delete[] temp_B_buffer;
+            break;
+        }
+
+        case MATMUL_GEMV: {
+            size_t lda = A_row_stride / sizeof(Sleef_quad);
+            size_t incx = B_row_stride / sizeof(Sleef_quad);
+            size_t incy = C_row_stride / sizeof(Sleef_quad);
+
             Sleef_quad *temp_A_buffer = new Sleef_quad[m * n];
-            if (!temp_A_buffer) {
-                PyErr_SetString(PyExc_MemoryError, "Failed to allocate temporary buffer for GEMM");
-                delete[] temp_A_buffer;
-                return -1;
-            }
             Sleef_quad *temp_B_buffer = new Sleef_quad[n * p];
-            if (!temp_B_buffer) {
-                PyErr_SetString(PyExc_MemoryError, "Failed to allocate temporary buffer for GEMM");
-                delete[] temp_A_buffer;
-                return -1;
-            }
             memcpy(temp_A_buffer, A_ptr, m * n * sizeof(Sleef_quad));
             memcpy(temp_B_buffer, B_ptr, n * p * sizeof(Sleef_quad));
             A_ptr = temp_A_buffer;
             B_ptr = temp_B_buffer;
 
+            // Use temp_C_buffer to avoid unaligned writes
             Sleef_quad *temp_C_buffer = new Sleef_quad[m * p];
-            if (!temp_C_buffer) {
-                PyErr_SetString(PyExc_MemoryError,
-                                "Failed to allocate temporary buffer for GEMM result");
-                return -1;
-            }
 
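+            // the staged buffers are packed row-major: leading dimension n, unit increments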
+            lda = n;
+            incx = 1;
+            incy = 1;
+
+            memset(temp_C_buffer, 0, m * p * sizeof(Sleef_quad));
+
+            result = qblas_gemv('R', 'N', m, n, &alpha, A_ptr, lda, B_ptr, incx, &beta,
+                                temp_C_buffer, incy);
+            break;
+        }
+
+        case MATMUL_GEMM: {
+            size_t lda = A_row_stride / sizeof(Sleef_quad);
+            size_t ldb = B_row_stride / sizeof(Sleef_quad);
+            size_t ldc_numpy = C_row_stride / sizeof(Sleef_quad);
+
+            Sleef_quad *temp_A_buffer = new Sleef_quad[m * n];
+            Sleef_quad *temp_B_buffer = new Sleef_quad[n * p];
+            memcpy(temp_A_buffer, A_ptr, m * n * sizeof(Sleef_quad));
+            memcpy(temp_B_buffer, B_ptr, n * p * sizeof(Sleef_quad));
+            A_ptr = temp_A_buffer;
+            B_ptr = temp_B_buffer;
+
+            // the staged buffers are contiguous, so use packed leading dimensions
+            lda = n;
+            ldb = p;
             size_t ldc_temp = p;
 
+            Sleef_quad *temp_C_buffer = new Sleef_quad[m * p];
+            memset(temp_C_buffer, 0, m * p * sizeof(Sleef_quad));
+
             result = qblas_gemm('R', 'N', 'N', m, p, n, &alpha, A_ptr, lda, B_ptr, ldb, &beta,
                                 temp_C_buffer, ldc_temp);
 
@@ -218,8 +321,8 @@ naive_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
     npy_intp p = dimensions[3];
 
     npy_intp A_batch_stride = strides[0];
-    npy_intp B_batch_stride = strides[1];
-    npy_intp C_batch_stride = strides[2];
+    npy_intp B_stride = strides[1];
+    npy_intp C_stride = strides[2];
 
     npy_intp A_row_stride = strides[3];
     npy_intp A_col_stride = strides[4];
@@ -232,46 +335,44 @@ naive_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
     QuadBackendType backend = descr->backend;
     size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);
 
-    for (npy_intp batch = 0; batch < N; batch++) {
-        char *A_batch = data[0] + batch * A_batch_stride;
-        char *B_batch = data[1] + batch * B_batch_stride;
-        char *C_batch = data[2] + batch * C_batch_stride;
-
-        for (npy_intp i = 0; i < m; i++) {
-            for (npy_intp j = 0; j < p; j++) {
-                char *C_ij = C_batch + i * C_row_stride + j * C_col_stride;
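+    // assumption: the batch loop was dropped because N is always 1 for matmul
+    // (see the comment on N above), so data[] points at the single operand set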
+    char *A = data[0];
+    char *B = data[1];
+    char *C = data[2];
 
-                if (backend == BACKEND_SLEEF) {
-                    Sleef_quad sum = Sleef_cast_from_doubleq1(0.0);
+    for (npy_intp i = 0; i < m; i++) {
+        for (npy_intp j = 0; j < p; j++) {
+            char *C_ij = C + i * C_row_stride + j * C_col_stride;
 
-                    for (npy_intp k = 0; k < n; k++) {
-                        char *A_ik = A_batch + i * A_row_stride + k * A_col_stride;
-                        char *B_kj = B_batch + k * B_row_stride + j * B_col_stride;
+            if (backend == BACKEND_SLEEF) {
+                Sleef_quad sum = Sleef_cast_from_doubleq1(0.0);
 
-                        Sleef_quad a_val, b_val;
-                        memcpy(&a_val, A_ik, sizeof(Sleef_quad));
-                        memcpy(&b_val, B_kj, sizeof(Sleef_quad));
-                        sum = Sleef_fmaq1_u05(a_val, b_val, sum);
-                    }
+                for (npy_intp k = 0; k < n; k++) {
+                    char *A_ik = A + i * A_row_stride + k * A_col_stride;
+                    char *B_kj = B + k * B_row_stride + j * B_col_stride;
 
-                    memcpy(C_ij, &sum, sizeof(Sleef_quad));
+                    Sleef_quad a_val, b_val;
+                    memcpy(&a_val, A_ik, sizeof(Sleef_quad));
+                    memcpy(&b_val, B_kj, sizeof(Sleef_quad));
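+                    // fused multiply-add: one rounding per accumulation step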
+                    sum = Sleef_fmaq1_u05(a_val, b_val, sum);
                 }
-                else {
-                    long double sum = 0.0L;
 
-                    for (npy_intp k = 0; k < n; k++) {
-                        char *A_ik = A_batch + i * A_row_stride + k * A_col_stride;
-                        char *B_kj = B_batch + k * B_row_stride + j * B_col_stride;
+                memcpy(C_ij, &sum, sizeof(Sleef_quad));
+            }
+            else {
+                long double sum = 0.0L;
 
-                        long double a_val, b_val;
-                        memcpy(&a_val, A_ik, sizeof(long double));
-                        memcpy(&b_val, B_kj, sizeof(long double));
+                for (npy_intp k = 0; k < n; k++) {
+                    char *A_ik = A + i * A_row_stride + k * A_col_stride;
+                    char *B_kj = B + k * B_row_stride + j * B_col_stride;
 
-                        sum += a_val * b_val;
-                    }
+                    long double a_val, b_val;
+                    memcpy(&a_val, A_ik, sizeof(long double));
+                    memcpy(&b_val, B_kj, sizeof(long double));
 
-                        memcpy(C_ij, &sum, sizeof(long double));
+                    sum += a_val * b_val;
                 }
+
+                memcpy(C_ij, &sum, sizeof(long double));
             }
         }
     }
@@ -289,21 +390,22 @@ init_matmul_ops(PyObject *numpy)
 
     PyArray_DTypeMeta *dtypes[3] = {&QuadPrecDType, &QuadPrecDType, &QuadPrecDType};
 
-#ifndef DISABLE_QUADBLAS
+#ifndef DISABLE_QUADBLAS
     // set threading to max
     int num_threads = _quadblas_get_num_threads();
     _quadblas_set_num_threads(num_threads);
 
-    PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
-                           {NPY_METH_strided_loop, (void *)&quad_matmul_strided_loop},
-                           {NPY_METH_unaligned_strided_loop, (void *)&naive_matmul_strided_loop},
-                           {0, NULL}};
-#else
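+    // NPY_METH_strided_loop serves aligned operands; NumPy falls back to
+    // NPY_METH_unaligned_strided_loop when an operand is not element-aligned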
+    PyType_Slot slots[] = {
+            {NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
+            {NPY_METH_strided_loop, (void *)&quad_matmul_strided_loop_aligned},
+            {NPY_METH_unaligned_strided_loop, (void *)&quad_matmul_strided_loop_unaligned},
+            {0, NULL}};
+#else
     PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
                            {NPY_METH_strided_loop, (void *)&naive_matmul_strided_loop},
                            {NPY_METH_unaligned_strided_loop, (void *)&naive_matmul_strided_loop},
                            {0, NULL}};
-#endif  // DISABLE_QUADBLAS
+#endif  // DISABLE_QUADBLAS
 
     PyArrayMethod_Spec Spec = {
             .name = "quad_matmul_qblas",
@@ -335,7 +437,7 @@ init_matmul_ops(PyObject *numpy)
     }
 
     if (PyUFunc_AddPromoter(ufunc, DTypes, promoter_capsule) < 0) {
-        PyErr_Clear();  // Don't fail if promoter fails
+        PyErr_Clear();
     }
     else {
     }