Skip to content

Commit bbce2ac

Browse files
committed
ahh stupid me :), fallback to naive for MSVC
1 parent c518a29 commit bbce2ac

File tree

4 files changed

+110
-6
lines changed

4 files changed

+110
-6
lines changed

quaddtype/numpy_quaddtype/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,4 @@ def LongDoubleQuadPrecDType():
3939
ln10 = get_sleef_constant("ln10")
4040
max_value = get_sleef_constant("quad_max")
4141
min_value = get_sleef_constant("quad_min")
42-
epsilon = get_sleef_constant("epsilon")
43-
44-
num_cores = multiprocessing.cpu_count()
45-
# set default number of threads for QuadBLAS
46-
set_num_threads(num_cores)
42+
epsilon = get_sleef_constant("epsilon")

quaddtype/numpy_quaddtype/src/quadblas_interface.cpp

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
#include "quadblas_interface.h"
2-
#include "../QBLAS/include/quadblas/quadblas.hpp"
32
#include <cstring>
43
#include <algorithm>
54

5+
#ifndef DISABLE_QUADBLAS
6+
#include "../QBLAS/include/quadblas/quadblas.hpp"
7+
#endif // DISABLE_QUADBLAS
8+
69
extern "C" {
710

11+
12+
#ifndef DISABLE_QUADBLAS
13+
814
int
915
qblas_dot(size_t n, Sleef_quad *x, size_t incx, Sleef_quad *y, size_t incy, Sleef_quad *result)
1016
{
@@ -138,4 +144,90 @@ py_quadblas_get_version(PyObject *self, PyObject *args)
138144
return PyUnicode_FromString("QuadBLAS 1.0.0 - High Performance Quad Precision BLAS");
139145
}
140146

147+
int
148+
quadblas_set_num_threads(int num_threads)
149+
{
150+
QuadBLAS::set_num_threads(num_threads);
151+
return 0;
152+
}
153+
154+
int
155+
quadblas_get_num_threads(void)
156+
{
157+
int num_threads = QuadBLAS::get_num_threads();
158+
return num_threads;
159+
}
160+
161+
#else // DISABLE_QUADBLAS
162+
163+
164+
int
165+
qblas_dot(size_t n, Sleef_quad *x, size_t incx, Sleef_quad *y, size_t incy, Sleef_quad *result)
166+
{
167+
return -1; // QBLAS is disabled, dot product not available
168+
}
169+
170+
int
171+
qblas_gemv(char layout, char trans, size_t m, size_t n, Sleef_quad *alpha, Sleef_quad *A,
172+
size_t lda, Sleef_quad *x, size_t incx, Sleef_quad *beta, Sleef_quad *y, size_t incy)
173+
{
174+
return -1; // QBLAS is disabled, GEMV not available
175+
}
176+
177+
int
178+
qblas_gemm(char layout, char transa, char transb, size_t m, size_t n, size_t k, Sleef_quad *alpha,
179+
Sleef_quad *A, size_t lda, Sleef_quad *B, size_t ldb, Sleef_quad *beta, Sleef_quad *C,
180+
size_t ldc)
181+
{
182+
return -1; // QBLAS is disabled, GEMM not available
183+
}
184+
185+
int
186+
qblas_supports_backend(QuadBackendType backend)
187+
{
188+
return -1; // QBLAS is disabled, backend support not available
189+
}
190+
191+
PyObject *
192+
py_quadblas_set_num_threads(PyObject *self, PyObject *args)
193+
{
194+
// raise error
195+
PyErr_SetString(PyExc_NotImplementedError, "QuadBLAS is disabled");
196+
return NULL;
197+
}
198+
199+
PyObject *
200+
py_quadblas_get_num_threads(PyObject *self, PyObject *args)
201+
{
202+
// raise error
203+
PyErr_SetString(PyExc_NotImplementedError, "QuadBLAS is disabled");
204+
return NULL;
205+
}
206+
207+
PyObject *
208+
py_quadblas_get_version(PyObject *self, PyObject *args)
209+
{
210+
// raise error
211+
PyErr_SetString(PyExc_NotImplementedError, "QuadBLAS is disabled");
212+
return NULL;
213+
}
214+
215+
int
216+
quadblas_set_num_threads(int num_threads)
217+
{
218+
// raise error
219+
PyErr_SetString(PyExc_NotImplementedError, "QuadBLAS is disabled");
220+
return -1;
221+
}
222+
223+
int
224+
quadblas_get_num_threads(void)
225+
{
226+
// raise error
227+
PyErr_SetString(PyExc_NotImplementedError, "QuadBLAS is disabled");
228+
return -1;
229+
}
230+
231+
#endif // DISABLE_QUADBLAS
232+
141233
} // extern "C"

quaddtype/numpy_quaddtype/src/quadblas_interface.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ py_quadblas_get_num_threads(PyObject *self, PyObject *args);
3232
PyObject *
3333
py_quadblas_get_version(PyObject *self, PyObject *args);
3434

35+
int
36+
quadblas_set_num_threads(int num_threads);
37+
int
38+
quadblas_get_num_threads(void);
39+
3540
#ifdef __cplusplus
3641
}
3742
#endif

quaddtype/numpy_quaddtype/src/umath/matmul.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,10 +289,21 @@ init_matmul_ops(PyObject *numpy)
289289

290290
PyArray_DTypeMeta *dtypes[3] = {&QuadPrecDType, &QuadPrecDType, &QuadPrecDType};
291291

292+
#ifndef DISABLE_QUADBLAS
293+
// set threading to max
294+
int num_threads = quadblas_get_num_threads();
295+
quadblas_set_num_threads(num_threads);
296+
292297
PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
293298
{NPY_METH_strided_loop, (void *)&quad_matmul_strided_loop},
294299
{NPY_METH_unaligned_strided_loop, (void *)&naive_matmul_strided_loop},
295300
{0, NULL}};
301+
#else
302+
PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
303+
{NPY_METH_strided_loop, (void *)&naive_matmul_strided_loop},
304+
{NPY_METH_unaligned_strided_loop, (void *)&naive_matmul_strided_loop},
305+
{0, NULL}};
306+
#endif // DISABLE_QUADBLAS
296307

297308
PyArrayMethod_Spec Spec = {
298309
.name = "quad_matmul_qblas",

0 commit comments

Comments
 (0)