Commit cec5ace

committed: "this should fix them all"
1 parent 5e5fa65 commit cec5ace

2 files changed: +160 -58 lines

quaddtype/numpy_quaddtype/src/umath/matmul.cpp

Lines changed: 159 additions & 57 deletions
@@ -93,8 +93,9 @@ determine_operation_type(npy_intp m, npy_intp n, npy_intp p)
 }

 static int
-quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
-                         npy_intp const dimensions[], npy_intp const strides[], NpyAuxData *auxdata)
+quad_matmul_strided_loop_aligned(PyArrayMethod_Context *context, char *const data[],
+                                 npy_intp const dimensions[], npy_intp const strides[],
+                                 NpyAuxData *auxdata)
 {
     // Extract dimensions
     npy_intp N = dimensions[0];  // Batch size, this remains always 1 for matmul afaik
@@ -149,6 +150,8 @@ quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
             size_t incx = B_row_stride / sizeof(Sleef_quad);
             size_t incy = C_row_stride / sizeof(Sleef_quad);

+            memset(C_ptr, 0, m * p * sizeof(Sleef_quad));
+
             result =
                     qblas_gemv('R', 'N', m, n, &alpha, A_ptr, lda, B_ptr, incx, &beta, C_ptr, incy);
             break;
@@ -159,32 +162,132 @@ quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
             size_t ldb = B_row_stride / sizeof(Sleef_quad);
             size_t ldc_numpy = C_row_stride / sizeof(Sleef_quad);

+            memset(C_ptr, 0, m * p * sizeof(Sleef_quad));
+
+            size_t ldc_temp = p;
+
+            result = qblas_gemm('R', 'N', 'N', m, p, n, &alpha, A_ptr, lda, B_ptr, ldb, &beta,
+                                C_ptr, ldc_numpy);
+            break;
+        }
+    }
+
+    if (result != 0) {
+        PyErr_SetString(PyExc_RuntimeError, "QBLAS operation failed");
+        return -1;
+    }
+
+    return 0;
+}
+
+static int
+quad_matmul_strided_loop_unaligned(PyArrayMethod_Context *context, char *const data[],
+                                   npy_intp const dimensions[], npy_intp const strides[],
+                                   NpyAuxData *auxdata)
+{
+    // Extract dimensions
+    npy_intp N = dimensions[0];  // Batch size, this remains always 1 for matmul afaik
+    npy_intp m = dimensions[1];  // Rows of first matrix
+    npy_intp n = dimensions[2];  // Cols of first matrix / rows of second matrix
+    npy_intp p = dimensions[3];  // Cols of second matrix
+
+    // batch strides
+    npy_intp A_stride = strides[0];
+    npy_intp B_stride = strides[1];
+    npy_intp C_stride = strides[2];
+
+    // core strides for matrix dimensions
+    npy_intp A_row_stride = strides[3];
+    npy_intp A_col_stride = strides[4];
+    npy_intp B_row_stride = strides[5];
+    npy_intp B_col_stride = strides[6];
+    npy_intp C_row_stride = strides[7];
+    npy_intp C_col_stride = strides[8];
+
+    QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
+    if (descr->backend != BACKEND_SLEEF) {
+        PyErr_SetString(PyExc_RuntimeError, "Internal error: non-SLEEF backend in QBLAS matmul");
+        return -1;
+    }
+
+    MatmulOperationType op_type = determine_operation_type(m, n, p);
+    Sleef_quad alpha = Sleef_cast_from_doubleq1(1.0);
+    Sleef_quad beta = Sleef_cast_from_doubleq1(0.0);
+
+    char *A = data[0];
+    char *B = data[1];
+    char *C = data[2];
+
+    Sleef_quad *A_ptr = (Sleef_quad *)A;
+    Sleef_quad *B_ptr = (Sleef_quad *)B;
+    Sleef_quad *C_ptr = (Sleef_quad *)C;
+
+    int result = -1;
+
+    switch (op_type) {
+        case MATMUL_DOT: {
+            Sleef_quad *temp_A_buffer = new Sleef_quad[n];
+            Sleef_quad *temp_B_buffer = new Sleef_quad[n];
+
+            memcpy(temp_A_buffer, A_ptr, n * sizeof(Sleef_quad));
+            memcpy(temp_B_buffer, B_ptr, n * sizeof(Sleef_quad));
+
+            size_t incx = 1;
+            size_t incy = 1;
+
+            result = qblas_dot(n, temp_A_buffer, incx, temp_B_buffer, incy, C_ptr);
+
+            delete[] temp_A_buffer;
+            delete[] temp_B_buffer;
+            break;
+        }
+
+        case MATMUL_GEMV: {
+            size_t lda = A_row_stride / sizeof(Sleef_quad);
+            size_t incx = B_row_stride / sizeof(Sleef_quad);
+            size_t incy = C_row_stride / sizeof(Sleef_quad);
+
             Sleef_quad *temp_A_buffer = new Sleef_quad[m * n];
-            if (!temp_A_buffer) {
-                PyErr_SetString(PyExc_MemoryError, "Failed to allocate temporary buffer for GEMM");
-                delete[] temp_A_buffer;
-                return -1;
-            }
             Sleef_quad *temp_B_buffer = new Sleef_quad[n * p];
-            if (!temp_B_buffer) {
-                PyErr_SetString(PyExc_MemoryError, "Failed to allocate temporary buffer for GEMM");
-                delete[] temp_A_buffer;
-                return -1;
-            }
             memcpy(temp_A_buffer, A_ptr, m * n * sizeof(Sleef_quad));
             memcpy(temp_B_buffer, B_ptr, n * p * sizeof(Sleef_quad));
             A_ptr = temp_A_buffer;
             B_ptr = temp_B_buffer;

+            // Use temp_C_buffer to avoid unaligned writes
             Sleef_quad *temp_C_buffer = new Sleef_quad[m * p];
-            if (!temp_C_buffer) {
-                PyErr_SetString(PyExc_MemoryError,
-                                "Failed to allocate temporary buffer for GEMM result");
-                return -1;
-            }

+            lda = n;
+            incx = 1;
+            incy = 1;
+
+            memset(temp_C_buffer, 0, m * p * sizeof(Sleef_quad));
+
+            result = qblas_gemv('R', 'N', m, n, &alpha, A_ptr, lda, B_ptr, incx, &beta,
+                                temp_C_buffer, incy);
+            break;
+        }
+
+        case MATMUL_GEMM: {
+            size_t lda = A_row_stride / sizeof(Sleef_quad);
+            size_t ldb = B_row_stride / sizeof(Sleef_quad);
+            size_t ldc_numpy = C_row_stride / sizeof(Sleef_quad);
+
+            Sleef_quad *temp_A_buffer = new Sleef_quad[m * n];
+            Sleef_quad *temp_B_buffer = new Sleef_quad[n * p];
+            memcpy(temp_A_buffer, A_ptr, m * n * sizeof(Sleef_quad));
+            memcpy(temp_B_buffer, B_ptr, n * p * sizeof(Sleef_quad));
+            A_ptr = temp_A_buffer;
+            B_ptr = temp_B_buffer;
+
+            // since these are now contiguous so,
+            lda = n;
+            ldb = p;
             size_t ldc_temp = p;

+            Sleef_quad *temp_C_buffer = new Sleef_quad[m * p];
+            memset(temp_C_buffer, 0, m * p * sizeof(Sleef_quad));
+
             result = qblas_gemm('R', 'N', 'N', m, p, n, &alpha, A_ptr, lda, B_ptr, ldb, &beta,
                                 temp_C_buffer, ldc_temp);

@@ -218,8 +321,8 @@ naive_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
     npy_intp p = dimensions[3];

     npy_intp A_batch_stride = strides[0];
-    npy_intp B_batch_stride = strides[1];
-    npy_intp C_batch_stride = strides[2];
+    npy_intp B_stride = strides[1];
+    npy_intp C_stride = strides[2];

     npy_intp A_row_stride = strides[3];
     npy_intp A_col_stride = strides[4];
@@ -232,46 +335,44 @@ naive_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
     QuadBackendType backend = descr->backend;
     size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);

-    for (npy_intp batch = 0; batch < N; batch++) {
-        char *A_batch = data[0] + batch * A_batch_stride;
-        char *B_batch = data[1] + batch * B_batch_stride;
-        char *C_batch = data[2] + batch * C_batch_stride;
-
-        for (npy_intp i = 0; i < m; i++) {
-            for (npy_intp j = 0; j < p; j++) {
-                char *C_ij = C_batch + i * C_row_stride + j * C_col_stride;
+    char *A = data[0];
+    char *B = data[1];
+    char *C = data[2];

-                if (backend == BACKEND_SLEEF) {
-                    Sleef_quad sum = Sleef_cast_from_doubleq1(0.0);
+    for (npy_intp i = 0; i < m; i++) {
+        for (npy_intp j = 0; j < p; j++) {
+            char *C_ij = C + i * C_row_stride + j * C_col_stride;

-                    for (npy_intp k = 0; k < n; k++) {
-                        char *A_ik = A_batch + i * A_row_stride + k * A_col_stride;
-                        char *B_kj = B_batch + k * B_row_stride + j * B_col_stride;
+            if (backend == BACKEND_SLEEF) {
+                Sleef_quad sum = Sleef_cast_from_doubleq1(0.0);

-                        Sleef_quad a_val, b_val;
-                        memcpy(&a_val, A_ik, sizeof(Sleef_quad));
-                        memcpy(&b_val, B_kj, sizeof(Sleef_quad));
-                        sum = Sleef_fmaq1_u05(a_val, b_val, sum);
-                    }
+                for (npy_intp k = 0; k < n; k++) {
+                    char *A_ik = A + i * A_row_stride + k * A_col_stride;
+                    char *B_kj = B + k * B_row_stride + j * B_col_stride;

-                    memcpy(C_ij, &sum, sizeof(Sleef_quad));
+                    Sleef_quad a_val, b_val;
+                    memcpy(&a_val, A_ik, sizeof(Sleef_quad));
+                    memcpy(&b_val, B_kj, sizeof(Sleef_quad));
+                    sum = Sleef_fmaq1_u05(a_val, b_val, sum);
                 }
-                else {
-                    long double sum = 0.0L;

-                    for (npy_intp k = 0; k < n; k++) {
-                        char *A_ik = A_batch + i * A_row_stride + k * A_col_stride;
-                        char *B_kj = B_batch + k * B_row_stride + j * B_col_stride;
+                memcpy(C_ij, &sum, sizeof(Sleef_quad));
+            }
+            else {
+                long double sum = 0.0L;

-                        long double a_val, b_val;
-                        memcpy(&a_val, A_ik, sizeof(long double));
-                        memcpy(&b_val, B_kj, sizeof(long double));
+                for (npy_intp k = 0; k < n; k++) {
+                    char *A_ik = A + i * A_row_stride + k * A_col_stride;
+                    char *B_kj = B + k * B_row_stride + j * B_col_stride;

-                        sum += a_val * b_val;
-                    }
+                    long double a_val, b_val;
+                    memcpy(&a_val, A_ik, sizeof(long double));
+                    memcpy(&b_val, B_kj, sizeof(long double));

-                    memcpy(C_ij, &sum, sizeof(long double));
+                    sum += a_val * b_val;
                 }
+
+                memcpy(C_ij, &sum, sizeof(long double));
             }
         }
     }
@@ -289,21 +390,22 @@ init_matmul_ops(PyObject *numpy)

     PyArray_DTypeMeta *dtypes[3] = {&QuadPrecDType, &QuadPrecDType, &QuadPrecDType};

-#ifndef DISABLE_QUADBLAS
+#ifndef DISABLE_QUADBLAS
     // set threading to max
     int num_threads = _quadblas_get_num_threads();
     _quadblas_set_num_threads(num_threads);

-    PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
-                           {NPY_METH_strided_loop, (void *)&quad_matmul_strided_loop},
-                           {NPY_METH_unaligned_strided_loop, (void *)&naive_matmul_strided_loop},
-                           {0, NULL}};
-#else
+    PyType_Slot slots[] = {
+            {NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
+            {NPY_METH_strided_loop, (void *)&quad_matmul_strided_loop_aligned},
+            {NPY_METH_unaligned_strided_loop, (void *)&quad_matmul_strided_loop_unaligned},
+            {0, NULL}};
+#else
     PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
                            {NPY_METH_strided_loop, (void *)&naive_matmul_strided_loop},
                            {NPY_METH_unaligned_strided_loop, (void *)&naive_matmul_strided_loop},
                            {0, NULL}};
-#endif // DISABLE_QUADBLAS
+#endif // DISABLE_QUADBLAS

     PyArrayMethod_Spec Spec = {
             .name = "quad_matmul_qblas",
@@ -335,7 +437,7 @@ init_matmul_ops(PyObject *numpy)
     }

     if (PyUFunc_AddPromoter(ufunc, DTypes, promoter_capsule) < 0) {
-        PyErr_Clear();  // Don't fail if promoter fails
+        PyErr_Clear();
     }
     else {
     }
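
The pattern introduced by quad_matmul_strided_loop_unaligned above is staging: inputs whose storage may not be aligned for Sleef_quad are copied with memcpy into freshly allocated (and therefore aligned) temporaries before the qblas_dot/qblas_gemv/qblas_gemm calls, and the GEMV/GEMM results are accumulated in temp_C_buffer rather than written through a possibly unaligned output pointer. A minimal, self-contained sketch of that staging idea follows; Quad and kernel_dot are hypothetical stand-ins for Sleef_quad and qblas_dot, not part of this codebase.

    // Sketch only: stage possibly-unaligned bytes through aligned temporaries
    // before calling a kernel that assumes aligned pointers, then store the
    // result back with memcpy. Quad and kernel_dot are illustrative stand-ins.
    #include <cstring>
    #include <iostream>
    #include <vector>

    using Quad = long double;  // stand-in for Sleef_quad

    static Quad kernel_dot(const Quad *a, const Quad *b, std::size_t n)
    {
        Quad acc = 0;
        for (std::size_t i = 0; i < n; ++i) {
            acc += a[i] * b[i];
        }
        return acc;
    }

    int main()
    {
        const std::size_t n = 4;
        // Byte storage deliberately offset by one to simulate an unaligned buffer.
        std::vector<unsigned char> raw(1 + (2 * n + 1) * sizeof(Quad));
        unsigned char *A_bytes = raw.data() + 1;
        unsigned char *B_bytes = A_bytes + n * sizeof(Quad);
        unsigned char *C_bytes = B_bytes + n * sizeof(Quad);

        for (std::size_t i = 0; i < n; ++i) {
            Quad v = static_cast<Quad>(i + 1);
            std::memcpy(A_bytes + i * sizeof(Quad), &v, sizeof(Quad));
            std::memcpy(B_bytes + i * sizeof(Quad), &v, sizeof(Quad));
        }

        // Stage the inputs through aligned heap temporaries; this mirrors the
        // temp_A_buffer/temp_B_buffer copies in the diff above.
        std::vector<Quad> temp_a(n), temp_b(n);
        std::memcpy(temp_a.data(), A_bytes, n * sizeof(Quad));
        std::memcpy(temp_b.data(), B_bytes, n * sizeof(Quad));

        Quad result = kernel_dot(temp_a.data(), temp_b.data(), n);
        std::memcpy(C_bytes, &result, sizeof(Quad));  // unaligned store via memcpy

        Quad check;
        std::memcpy(&check, C_bytes, sizeof(Quad));
        std::cout << static_cast<double>(check) << "\n";  // prints 30 (1+4+9+16)
        return 0;
    }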
