Skip to content

Commit 742ce64

Browse files
committed
matmul ufunc completed, naive plugged, qblas experimental
1 parent 6800a90 commit 742ce64

File tree

3 files changed

+116
-77
lines changed

3 files changed

+116
-77
lines changed

quaddtype/numpy_quaddtype/src/quadblas_interface.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ qblas_gemm(char layout, char transa, char transb, size_t m, size_t n, size_t k,
7575
}
7676

7777
try {
78-
// Convert layout
7978
QuadBLAS::Layout qblas_layout;
8079
if (layout == 'R' || layout == 'r') {
8180
qblas_layout = QuadBLAS::Layout::RowMajor;
@@ -93,7 +92,6 @@ qblas_gemm(char layout, char transa, char transb, size_t m, size_t n, size_t k,
9392
return -1; // Transpose not implemented yet
9493
}
9594

96-
// Call QBLAS GEMM
9795
QuadBLAS::gemm(qblas_layout, m, n, k, *alpha, A, lda, B, ldb, *beta, C, ldc);
9896

9997
return 0;

quaddtype/numpy_quaddtype/src/umath/matmul.cpp

Lines changed: 115 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,6 @@ extern "C" {
2424
#include "promoters.hpp"
2525
#include "../quadblas_interface.h"
2626

27-
/**
28-
* Resolve descriptors for matmul operation.
29-
* Only supports SLEEF backend when QBLAS is enabled.
30-
*/
3127
static NPY_CASTING
3228
quad_matmul_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtypes[],
3329
PyArray_Descr *const given_descrs[], PyArray_Descr *loop_descrs[],
@@ -76,23 +72,15 @@ quad_matmul_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtypes[
7672
return casting;
7773
}
7874

79-
/**
80-
* Determine the type of operation based on input dimensions
81-
*/
8275
enum MatmulOperationType {
83-
MATMUL_DOT, // 1D x 1D -> scalar
84-
MATMUL_GEMV, // 2D x 1D -> 1D
85-
MATMUL_GEMM // 2D x 2D -> 2D
76+
MATMUL_DOT,
77+
MATMUL_GEMV,
78+
MATMUL_GEMM
8679
};
8780

8881
static MatmulOperationType
8982
determine_operation_type(npy_intp m, npy_intp n, npy_intp p)
9083
{
91-
// For matmul signature (m?,n),(n,p?)->(m?,p?):
92-
// - If m=1 and p=1: vector dot product (1D x 1D)
93-
// - If p=1: matrix-vector multiplication (2D x 1D)
94-
// - Otherwise: matrix-matrix multiplication (2D x 2D)
95-
9684
if (m == 1 && p == 1) {
9785
return MATMUL_DOT;
9886
}
@@ -104,10 +92,6 @@ determine_operation_type(npy_intp m, npy_intp n, npy_intp p)
10492
}
10593
}
10694

107-
/**
108-
* Matrix multiplication strided loop using QBLAS.
109-
* Automatically selects the appropriate QBLAS operation based on input dimensions.
110-
*/
11195
static int
11296
quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
11397
npy_intp const dimensions[], npy_intp const strides[], NpyAuxData *auxdata)
@@ -118,43 +102,29 @@ quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
118102
npy_intp n = dimensions[2]; // Cols of first matrix / rows of second matrix
119103
npy_intp p = dimensions[3]; // Cols of second matrix
120104

121-
// Extract batch strides
105+
// batch strides
122106
npy_intp A_stride = strides[0];
123107
npy_intp B_stride = strides[1];
124108
npy_intp C_stride = strides[2];
125109

126-
// Extract core strides for matrix dimensions
110+
// core strides for matrix dimensions
127111
npy_intp A_row_stride = strides[3];
128112
npy_intp A_col_stride = strides[4];
129113
npy_intp B_row_stride = strides[5];
130114
npy_intp B_col_stride = strides[6];
131115
npy_intp C_row_stride = strides[7];
132116
npy_intp C_col_stride = strides[8];
133117

134-
// Note: B_col_stride and C_col_stride not needed for row-major QBLAS calls
135-
136-
// Get backend from descriptor (should be SLEEF only)
137118
QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
138119
if (descr->backend != BACKEND_SLEEF) {
139120
PyErr_SetString(PyExc_RuntimeError, "Internal error: non-SLEEF backend in QBLAS matmul");
140121
return -1;
141122
}
142123

143-
// Determine operation type
144124
MatmulOperationType op_type = determine_operation_type(m, n, p);
145-
146-
// Constants for QBLAS
147125
Sleef_quad alpha = Sleef_cast_from_doubleq1(1.0);
148126
Sleef_quad beta = Sleef_cast_from_doubleq1(0.0);
149127

150-
// print all information for debugging
151-
printf("DEBUG: Performing %ld batch operations with dimensions (%ld, %ld, %ld)\n", (long)N,
152-
(long)m, (long)n, (long)p);
153-
printf("DEBUG: Strides - A: (%ld, %ld), B: (%ld, %ld), C: (%ld, %ld)\n", (long)A_row_stride,
154-
(long)A_col_stride, (long)B_row_stride, (long)B_col_stride, (long)C_row_stride,
155-
(long)C_col_stride);
156-
printf("DEBUG: Operation type: %d\n", op_type);
157-
158128
char *A = data[0];
159129
char *B = data[1];
160130
char *C = data[2];
@@ -167,13 +137,6 @@ quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
167137

168138
switch (op_type) {
169139
case MATMUL_DOT: {
170-
// Vector dot product: C = A^T * B (both are vectors)
171-
// A has shape (1, n), B has shape (n, 1), C has shape (1, 1)
172-
173-
printf("DEBUG: Using QBLAS dot product for %ld elements\n", (long)n);
174-
175-
// A is effectively a vector of length n
176-
// B is effectively a vector of length n
177140
size_t incx = A_col_stride / sizeof(Sleef_quad);
178141
size_t incy = B_row_stride / sizeof(Sleef_quad);
179142

@@ -182,12 +145,6 @@ quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
182145
}
183146

184147
case MATMUL_GEMV: {
185-
// Matrix-vector multiplication: C = A * B
186-
// A has shape (m, n), B has shape (n, 1), C has shape (m, 1)
187-
188-
printf("DEBUG: Using QBLAS GEMV for %ldx%ld matrix times %ld vector\n", (long)m,
189-
(long)n, (long)n);
190-
191148
size_t lda = A_row_stride / sizeof(Sleef_quad);
192149
size_t incx = B_row_stride / sizeof(Sleef_quad);
193150
size_t incy = C_row_stride / sizeof(Sleef_quad);
@@ -198,17 +155,46 @@ quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
198155
}
199156

200157
case MATMUL_GEMM: {
201-
// Matrix-matrix multiplication: C = A * B
202-
// A has shape (m, n), B has shape (n, p), C has shape (m, p)
203-
204-
printf("DEBUG: Using QBLAS GEMM for %ldx%ldx%ld matrices\n", (long)m, (long)n, (long)p);
205-
206158
size_t lda = A_row_stride / sizeof(Sleef_quad);
207159
size_t ldb = B_row_stride / sizeof(Sleef_quad);
208-
size_t ldc = C_row_stride / sizeof(Sleef_quad);
160+
size_t ldc_numpy = C_row_stride / sizeof(Sleef_quad);
161+
162+
Sleef_quad *temp_A_buffer = new Sleef_quad[m * n];
163+
if (!temp_A_buffer) {
164+
PyErr_SetString(PyExc_MemoryError, "Failed to allocate temporary buffer for GEMM");
165+
delete[] temp_A_buffer;
166+
return -1;
167+
}
168+
Sleef_quad *temp_B_buffer = new Sleef_quad[n * p];
169+
if (!temp_B_buffer) {
170+
PyErr_SetString(PyExc_MemoryError, "Failed to allocate temporary buffer for GEMM");
171+
delete[] temp_A_buffer;
172+
return -1;
173+
}
174+
memcpy(temp_A_buffer, A_ptr, m * n * sizeof(Sleef_quad));
175+
memcpy(temp_B_buffer, B_ptr, n * p * sizeof(Sleef_quad));
176+
A_ptr = temp_A_buffer;
177+
B_ptr = temp_B_buffer;
178+
179+
Sleef_quad *temp_C_buffer = new Sleef_quad[m * p];
180+
if (!temp_C_buffer) {
181+
PyErr_SetString(PyExc_MemoryError,
182+
"Failed to allocate temporary buffer for GEMM result");
183+
return -1;
184+
}
185+
186+
size_t ldc_temp = p;
209187

210188
result = qblas_gemm('R', 'N', 'N', m, p, n, &alpha, A_ptr, lda, B_ptr, ldb, &beta,
211-
C_ptr, ldc);
189+
temp_C_buffer, ldc_temp);
190+
191+
if (result == 0) {
192+
memcpy(C_ptr, temp_C_buffer, m * p * sizeof(Sleef_quad));
193+
}
194+
195+
delete[] temp_C_buffer;
196+
delete[] temp_A_buffer;
197+
delete[] temp_B_buffer;
212198
break;
213199
}
214200
}
@@ -221,27 +207,91 @@ quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
221207
return 0;
222208
}
223209

224-
/**
225-
* Register matmul support with QBLAS acceleration
226-
*/
210+
static int
211+
naive_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
212+
npy_intp const dimensions[], npy_intp const strides[],
213+
NpyAuxData *auxdata)
214+
{
215+
npy_intp N = dimensions[0];
216+
npy_intp m = dimensions[1];
217+
npy_intp n = dimensions[2];
218+
npy_intp p = dimensions[3];
219+
220+
npy_intp A_batch_stride = strides[0];
221+
npy_intp B_batch_stride = strides[1];
222+
npy_intp C_batch_stride = strides[2];
223+
224+
npy_intp A_row_stride = strides[3];
225+
npy_intp A_col_stride = strides[4];
226+
npy_intp B_row_stride = strides[5];
227+
npy_intp B_col_stride = strides[6];
228+
npy_intp C_row_stride = strides[7];
229+
npy_intp C_col_stride = strides[8];
230+
231+
QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
232+
QuadBackendType backend = descr->backend;
233+
size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);
234+
235+
for (npy_intp batch = 0; batch < N; batch++) {
236+
char *A_batch = data[0] + batch * A_batch_stride;
237+
char *B_batch = data[1] + batch * B_batch_stride;
238+
char *C_batch = data[2] + batch * C_batch_stride;
239+
240+
for (npy_intp i = 0; i < m; i++) {
241+
for (npy_intp j = 0; j < p; j++) {
242+
char *C_ij = C_batch + i * C_row_stride + j * C_col_stride;
243+
244+
if (backend == BACKEND_SLEEF) {
245+
Sleef_quad sum = Sleef_cast_from_doubleq1(0.0);
246+
247+
for (npy_intp k = 0; k < n; k++) {
248+
char *A_ik = A_batch + i * A_row_stride + k * A_col_stride;
249+
char *B_kj = B_batch + k * B_row_stride + j * B_col_stride;
250+
251+
Sleef_quad a_val, b_val;
252+
memcpy(&a_val, A_ik, sizeof(Sleef_quad));
253+
memcpy(&b_val, B_kj, sizeof(Sleef_quad));
254+
sum = Sleef_fmaq1_u05(a_val, b_val, sum);
255+
}
256+
257+
memcpy(C_ij, &sum, sizeof(Sleef_quad));
258+
}
259+
else {
260+
long double sum = 0.0L;
261+
262+
for (npy_intp k = 0; k < n; k++) {
263+
char *A_ik = A_batch + i * A_row_stride + k * A_col_stride;
264+
char *B_kj = B_batch + k * B_row_stride + j * B_col_stride;
265+
266+
long double a_val, b_val;
267+
memcpy(&a_val, A_ik, sizeof(long double));
268+
memcpy(&b_val, B_kj, sizeof(long double));
269+
270+
sum += a_val * b_val;
271+
}
272+
273+
memcpy(C_ij, &sum, sizeof(long double));
274+
}
275+
}
276+
}
277+
}
278+
279+
return 0;
280+
}
281+
227282
int
228283
init_matmul_ops(PyObject *numpy)
229284
{
230-
printf("DEBUG: init_matmul_ops - registering QBLAS-accelerated matmul\n");
231-
232-
// Get the existing matmul ufunc
233285
PyObject *ufunc = PyObject_GetAttrString(numpy, "matmul");
234286
if (ufunc == NULL) {
235-
printf("DEBUG: Failed to get numpy.matmul\n");
236287
return -1;
237288
}
238289

239-
// Setup method specification for QBLAS-accelerated matmul
240290
PyArray_DTypeMeta *dtypes[3] = {&QuadPrecDType, &QuadPrecDType, &QuadPrecDType};
241291

242292
PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
243-
{NPY_METH_strided_loop, (void *)&quad_matmul_strided_loop},
244-
{NPY_METH_unaligned_strided_loop, (void *)&quad_matmul_strided_loop},
293+
{NPY_METH_strided_loop, (void *)&naive_matmul_strided_loop},
294+
{NPY_METH_unaligned_strided_loop, (void *)&naive_matmul_strided_loop},
245295
{0, NULL}};
246296

247297
PyArrayMethod_Spec Spec = {
@@ -254,17 +304,11 @@ init_matmul_ops(PyObject *numpy)
254304
.slots = slots,
255305
};
256306

257-
printf("DEBUG: About to add QBLAS loop to matmul ufunc...\n");
258-
259307
if (PyUFunc_AddLoopFromSpec(ufunc, &Spec) < 0) {
260-
printf("DEBUG: Failed to add QBLAS loop to matmul ufunc\n");
261308
Py_DECREF(ufunc);
262309
return -1;
263310
}
264311

265-
printf("DEBUG: Successfully added QBLAS matmul loop!\n");
266-
267-
// Add promoter
268312
PyObject *promoter_capsule =
269313
PyCapsule_New((void *)&quad_ufunc_promoter, "numpy._ufunc_promoter", NULL);
270314
if (promoter_capsule == NULL) {
@@ -280,17 +324,14 @@ init_matmul_ops(PyObject *numpy)
280324
}
281325

282326
if (PyUFunc_AddPromoter(ufunc, DTypes, promoter_capsule) < 0) {
283-
printf("DEBUG: Failed to add promoter (continuing anyway)\n");
284327
PyErr_Clear(); // Don't fail if promoter fails
285328
}
286329
else {
287-
printf("DEBUG: Successfully added promoter\n");
288330
}
289331

290332
Py_DECREF(DTypes);
291333
Py_DECREF(promoter_capsule);
292334
Py_DECREF(ufunc);
293335

294-
printf("DEBUG: init_matmul_ops completed successfully with QBLAS acceleration\n");
295336
return 0;
296337
}

0 commit comments

Comments
 (0)