Skip to content

Commit 70ca644

Browse files
committed
shifting to matmul ufunc
1 parent 09918a3 commit 70ca644

File tree

3 files changed

+66
-73
lines changed

3 files changed

+66
-73
lines changed

quaddtype/numpy_quaddtype/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
QuadPrecDType,
44
is_longdouble_128,
55
get_sleef_constant,
6-
qblas_dot as dot,
76
set_num_threads,
87
get_num_threads,
98
get_quadblas_version
@@ -17,7 +16,7 @@
1716
# Constants
1817
'pi', 'e', 'log2e', 'log10e', 'ln2', 'ln10', 'max_value', 'min_value', 'epsilon',
1918
# QuadBLAS related functions
20-
'dot', 'set_num_threads', 'get_num_threads', 'get_quadblas_version'
19+
'set_num_threads', 'get_num_threads', 'get_quadblas_version'
2120
]
2221

2322
def SleefQuadPrecision(value):

quaddtype/numpy_quaddtype/src/quaddtype_main.c

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,45 +19,55 @@
1919
#include "quadblas_interface.h"
2020
#include "float.h"
2121

22-
23-
static PyObject* py_is_longdouble_128(PyObject* self, PyObject* args) {
24-
if(sizeof(long double) == 16 &&
25-
LDBL_MANT_DIG == 113 &&
26-
LDBL_MAX_EXP == 16384) {
22+
static PyObject *
23+
py_is_longdouble_128(PyObject *self, PyObject *args)
24+
{
25+
if (sizeof(long double) == 16 && LDBL_MANT_DIG == 113 && LDBL_MAX_EXP == 16384) {
2726
Py_RETURN_TRUE;
28-
} else {
27+
}
28+
else {
2929
Py_RETURN_FALSE;
3030
}
3131
}
3232

33-
static PyObject* get_sleef_constant(PyObject* self, PyObject* args) {
34-
const char* constant_name;
33+
static PyObject *
34+
get_sleef_constant(PyObject *self, PyObject *args)
35+
{
36+
const char *constant_name;
3537
if (!PyArg_ParseTuple(args, "s", &constant_name)) {
3638
return NULL;
3739
}
3840

39-
QuadPrecisionObject* result = QuadPrecision_raw_new(BACKEND_SLEEF);
41+
QuadPrecisionObject *result = QuadPrecision_raw_new(BACKEND_SLEEF);
4042
if (result == NULL) {
4143
return NULL;
4244
}
4345

4446
if (strcmp(constant_name, "pi") == 0) {
4547
result->value.sleef_value = SLEEF_M_PIq;
46-
} else if (strcmp(constant_name, "e") == 0) {
48+
}
49+
else if (strcmp(constant_name, "e") == 0) {
4750
result->value.sleef_value = SLEEF_M_Eq;
48-
} else if (strcmp(constant_name, "log2e") == 0) {
51+
}
52+
else if (strcmp(constant_name, "log2e") == 0) {
4953
result->value.sleef_value = SLEEF_M_LOG2Eq;
50-
} else if (strcmp(constant_name, "log10e") == 0) {
54+
}
55+
else if (strcmp(constant_name, "log10e") == 0) {
5156
result->value.sleef_value = SLEEF_M_LOG10Eq;
52-
} else if (strcmp(constant_name, "ln2") == 0) {
57+
}
58+
else if (strcmp(constant_name, "ln2") == 0) {
5359
result->value.sleef_value = SLEEF_M_LN2q;
54-
} else if (strcmp(constant_name, "ln10") == 0) {
60+
}
61+
else if (strcmp(constant_name, "ln10") == 0) {
5562
result->value.sleef_value = SLEEF_M_LN10q;
56-
} else if (strcmp(constant_name, "quad_max") == 0) {
63+
}
64+
else if (strcmp(constant_name, "quad_max") == 0) {
5765
result->value.sleef_value = SLEEF_QUAD_MAX;
58-
} else if (strcmp(constant_name, "quad_min") == 0) {
66+
}
67+
else if (strcmp(constant_name, "quad_min") == 0) {
5968
result->value.sleef_value = SLEEF_QUAD_MIN;
60-
} else if (strcmp(constant_name, "epsilon") == 0) {
69+
}
70+
else if (strcmp(constant_name, "epsilon") == 0) {
6171
result->value.sleef_value = SLEEF_QUAD_EPSILON;
6272
}
6373
else {
@@ -66,26 +76,23 @@ static PyObject* get_sleef_constant(PyObject* self, PyObject* args) {
6676
return NULL;
6777
}
6878

69-
return (PyObject*)result;
79+
return (PyObject *)result;
7080
}
7181

7282
static PyMethodDef module_methods[] = {
73-
{"is_longdouble_128", py_is_longdouble_128, METH_NOARGS, "Check if long double is 128-bit"},
74-
{"get_sleef_constant", get_sleef_constant, METH_VARARGS, "Get Sleef constant by name"},
75-
{"qblas_dot", py_quadblas_dot, METH_VARARGS, "Optimized dot product using QuadBLAS"},
76-
{"set_num_threads", py_quadblas_set_num_threads, METH_VARARGS, "Set number of threads for QuadBLAS"},
77-
{"get_num_threads", py_quadblas_get_num_threads, METH_NOARGS, "Get number of threads for QuadBLAS"},
78-
{"get_quadblas_version", py_quadblas_get_version, METH_NOARGS, "Get QuadBLAS version"},
79-
{NULL, NULL, 0, NULL}
80-
};
83+
{"is_longdouble_128", py_is_longdouble_128, METH_NOARGS, "Check if long double is 128-bit"},
84+
{"get_sleef_constant", get_sleef_constant, METH_VARARGS, "Get Sleef constant by name"},
85+
{"set_num_threads", py_quadblas_set_num_threads, METH_VARARGS,
86+
"Set number of threads for QuadBLAS"},
87+
{"get_num_threads", py_quadblas_get_num_threads, METH_NOARGS,
88+
"Get number of threads for QuadBLAS"},
89+
{"get_quadblas_version", py_quadblas_get_version, METH_NOARGS, "Get QuadBLAS version"},
90+
{NULL, NULL, 0, NULL}};
8191

8292
static struct PyModuleDef moduledef = {
83-
PyModuleDef_HEAD_INIT,
84-
.m_name = "_quaddtype_main",
93+
PyModuleDef_HEAD_INIT, .m_name = "_quaddtype_main",
8594
.m_doc = "Quad (128-bit) floating point Data Type for NumPy with multiple backends",
86-
.m_size = -1,
87-
.m_methods = module_methods
88-
};
95+
.m_size = -1, .m_methods = module_methods};
8996

9097
PyMODINIT_FUNC
9198
PyInit__quaddtype_main(void)

quaddtype/tests/test_dot.py

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,6 @@
1-
"""
2-
Focused test suite for the dot function in numpy_quaddtype
3-
4-
This module tests the QuadBLAS dot function for:
5-
- Vector-vector dot products
6-
- Matrix-vector multiplication
7-
- Matrix-matrix multiplication
8-
- Small and large matrix operations
9-
- Basic correctness validation
10-
11-
Uses only the Sleef backend for simplicity.
12-
"""
13-
141
import pytest
152
import numpy as np
16-
from numpy_quaddtype import QuadPrecision, QuadPrecDType, dot
3+
from numpy_quaddtype import QuadPrecision, QuadPrecDType
174

185

196
# ================================================================================
@@ -81,14 +68,14 @@ def create_quad_array(values, shape=None):
8168
# ================================================================================
8269

8370
class TestVectorVectorDot:
84-
"""Test vector-vector dot products"""
71+
"""Test vector-vector np.matmul products"""
8572

8673
def test_simple_dot_product(self):
87-
"""Test basic vector dot product"""
74+
"""Test basic vector np.matmul product"""
8875
x = create_quad_array([1, 2, 3])
8976
y = create_quad_array([4, 5, 6])
9077

91-
result = dot(x, y)
78+
result = np.matmul(x, y)
9279
expected = 1*4 + 2*5 + 3*6 # = 32
9380

9481
assert isinstance(result, QuadPrecision)
@@ -99,14 +86,14 @@ def test_orthogonal_vectors(self):
9986
x = create_quad_array([1, 0, 0])
10087
y = create_quad_array([0, 1, 0])
10188

102-
result = dot(x, y)
89+
result = np.matmul(x, y)
10390
assert_quad_equal(result, 0.0)
10491

10592
def test_same_vector(self):
106-
"""Test dot product of vector with itself"""
93+
"""Test np.matmul product of vector with itself"""
10794
x = create_quad_array([2, 3, 4])
10895

109-
result = dot(x, x)
96+
result = np.matmul(x, x)
11097
expected = 2*2 + 3*3 + 4*4 # = 29
11198

11299
assert_quad_equal(result, expected)
@@ -121,7 +108,7 @@ def test_various_vector_sizes(self, size):
121108
x = create_quad_array(x_vals)
122109
y = create_quad_array(y_vals)
123110

124-
result = dot(x, y)
111+
result = np.matmul(x, y)
125112
expected = sum(x_vals[i] * y_vals[i] for i in range(size))
126113

127114
assert_quad_equal(result, expected)
@@ -131,7 +118,7 @@ def test_negative_and_fractional_values(self):
131118
x = create_quad_array([1.5, -2.5, 3.25])
132119
y = create_quad_array([-1.25, 2.75, -3.5])
133120

134-
result = dot(x, y)
121+
result = np.matmul(x, y)
135122
expected = 1.5*(-1.25) + (-2.5)*2.75 + 3.25*(-3.5)
136123

137124
assert_quad_equal(result, expected)
@@ -151,7 +138,7 @@ def test_simple_matrix_vector(self):
151138
# 3x1 vector
152139
x = create_quad_array([1, 1, 1])
153140

154-
result = dot(A, x)
141+
result = np.matmul(A, x)
155142
expected = [1+2+3, 4+5+6] # [6, 15]
156143

157144
assert result.shape == (2,)
@@ -164,7 +151,7 @@ def test_identity_matrix_vector(self):
164151
I = create_quad_array([1, 0, 0, 0, 1, 0, 0, 0, 1], shape=(3, 3))
165152
x = create_quad_array([2, 3, 4])
166153

167-
result = dot(I, x)
154+
result = np.matmul(I, x)
168155

169156
assert result.shape == (3,)
170157
for i in range(3):
@@ -181,7 +168,7 @@ def test_various_matrix_vector_sizes(self, m, n):
181168
x_vals = [i + 1 for i in range(n)]
182169
x = create_quad_array(x_vals)
183170

184-
result = dot(A, x)
171+
result = np.matmul(A, x)
185172

186173
assert result.shape == (m,)
187174

@@ -205,7 +192,7 @@ def test_simple_matrix_matrix(self):
205192
A = create_quad_array([1, 2, 3, 4], shape=(2, 2))
206193
B = create_quad_array([5, 6, 7, 8], shape=(2, 2))
207194

208-
result = dot(A, B)
195+
result = np.matmul(A, B)
209196

210197
# Expected: [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]] = [[19, 22], [43, 50]]
211198
expected = [[19, 22], [43, 50]]
@@ -221,11 +208,11 @@ def test_identity_matrix_multiplication(self):
221208
I = create_quad_array([1, 0, 0, 1], shape=(2, 2))
222209

223210
# A * I should equal A
224-
result1 = dot(A, I)
211+
result1 = np.matmul(A, I)
225212
assert_quad_array_equal(result1, A)
226213

227214
# I * A should equal A
228-
result2 = dot(I, A)
215+
result2 = np.matmul(I, A)
229216
assert_quad_array_equal(result2, A)
230217

231218
@pytest.mark.parametrize("m,n,k", [(2,2,2), (2,3,4), (3,2,5), (4,4,4), (5,6,7)])
@@ -239,7 +226,7 @@ def test_various_matrix_sizes(self, m, n, k):
239226
B_vals = [(i*n + j + 1) for i in range(k) for j in range(n)]
240227
B = create_quad_array(B_vals, shape=(k, n))
241228

242-
result = dot(A, B)
229+
result = np.matmul(A, B)
243230

244231
assert result.shape == (m, n)
245232

@@ -258,12 +245,12 @@ def test_associativity(self):
258245
C = create_quad_array([1, 1, 2, 1], shape=(2, 2))
259246

260247
# Compute (A*B)*C
261-
AB = dot(A, B)
262-
result1 = dot(AB, C)
248+
AB = np.matmul(A, B)
249+
result1 = np.matmul(AB, C)
263250

264251
# Compute A*(B*C)
265-
BC = dot(B, C)
266-
result2 = dot(A, BC)
252+
BC = np.matmul(B, C)
253+
result2 = np.matmul(A, BC)
267254

268255
assert_quad_array_equal(result1, result2, rtol=1e-25)
269256

@@ -285,7 +272,7 @@ def test_large_square_matrices(self, size):
285272
A = create_quad_array(A_vals, shape=(size, size))
286273
B = create_quad_array(B_vals, shape=(size, size))
287274

288-
result = dot(A, B)
275+
result = np.matmul(A, B)
289276

290277
assert result.shape == (size, size)
291278

@@ -303,7 +290,7 @@ def test_large_square_matrices(self, size):
303290
assert_quad_equal(result[size//2, size//2], expected_value, rtol=1e-15, atol=1e-15)
304291

305292
def test_large_vector_operations(self):
306-
"""Test large vector dot products"""
293+
"""Test large vector np.matmul products"""
307294
size = 1000
308295

309296
# Create vectors with known sum
@@ -313,7 +300,7 @@ def test_large_vector_operations(self):
313300
x = create_quad_array(x_vals)
314301
y = create_quad_array(y_vals)
315302

316-
result = dot(x, y)
303+
result = np.matmul(x, y)
317304
expected = size * 1.0 * 2.0 # = 2000.0
318305

319306
assert_quad_equal(result, expected)
@@ -329,7 +316,7 @@ def test_rectangular_large_matrices(self):
329316
A = create_quad_array(A_vals, shape=(m, k))
330317
B = create_quad_array(B_vals, shape=(k, n))
331318

332-
result = dot(A, B)
319+
result = np.matmul(A, B)
333320

334321
assert result.shape == (m, n)
335322

@@ -354,23 +341,23 @@ def test_dimension_mismatch_vectors(self):
354341
y = create_quad_array([1, 2, 3])
355342

356343
with pytest.raises(ValueError, match="same length"):
357-
dot(x, y)
344+
np.matmul(x, y)
358345

359346
def test_dimension_mismatch_matrix_vector(self):
360347
"""Test dimension mismatch in matrix-vector"""
361348
A = create_quad_array([1, 2, 3, 4], shape=(2, 2))
362349
x = create_quad_array([1, 2, 3]) # Wrong size
363350

364351
with pytest.raises(ValueError, match="columns must match"):
365-
dot(A, x)
352+
np.matmul(A, x)
366353

367354
def test_dimension_mismatch_matrices(self):
368355
"""Test dimension mismatch in matrix-matrix"""
369356
A = create_quad_array([1, 2, 3, 4], shape=(2, 2))
370357
B = create_quad_array([1, 2, 3, 4, 5, 6], shape=(3, 2)) # Wrong size
371358

372359
with pytest.raises(ValueError, match="Matrix inner dimensions must match"):
373-
dot(A, B)
360+
np.matmul(A, B)
374361

375362

376363
if __name__ == "__main__":

0 commit comments

Comments
 (0)