Skip to content

Commit 03ce970

Browse files
committed
Make GEMV more robust to zero strided inputs
1 parent bf628c9 commit 03ce970

File tree

2 files changed

+141
-97
lines changed

2 files changed

+141
-97
lines changed

pytensor/tensor/blas_c.py

Lines changed: 73 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -413,15 +413,15 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
413413
char NOTRANS = 'N';
414414
int NA0 = PyArray_DIMS(%(A)s)[0];
415415
int NA1 = PyArray_DIMS(%(A)s)[1];
416-
/* This formula is needed in the case where A is actually a row or
417-
* column matrix, because BLAS sometimes insists that the strides:
418-
* - are not smaller than the number of elements in the array
419-
* - are not 0.
416+
int Nx = PyArray_DIMS(%(x)s)[0];
417+
/* If A or x have length 1 dimensions, the respective strides don't matter
418+
* However, BLAS often insists that the strides be not zero nor smaller than
419+
* the number of elements in the array. We set them to 1 arbitrarily;
420420
*/
421-
int SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
422-
int SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
421+
int SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : 1;
422+
int SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : 1;
423+
int Sx = (Nx > 1) ? PyArray_STRIDES(%(x)s)[0] / elemsize: 1;
423424
int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
424-
int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
425425
426426
dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
427427
dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
@@ -435,62 +435,49 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
435435
436436
if (NA0 * NA1)
437437
{
438-
// If A is neither C- nor F-contiguous, we make a copy.
439-
// TODO:
440-
// - if one stride is equal to "- elemsize", we can still call
441-
// gemv on reversed matrix and vectors
442-
// - if the copy is too long, maybe call vector/vector dot on
443-
// each row instead
444-
if ((PyArray_STRIDES(%(A)s)[0] < 0)
445-
|| (PyArray_STRIDES(%(A)s)[1] < 0)
446-
|| ((PyArray_STRIDES(%(A)s)[0] != elemsize)
447-
&& (PyArray_STRIDES(%(A)s)[1] != elemsize)))
438+
// Non-empty branch
439+
440+
if (Sx == 0)
448441
{
449-
npy_intp dims[2];
450-
dims[0] = NA0;
451-
dims[1] = NA1;
442+
// This is a broadcasted vector with length > 1 and a stride of 0.
443+
// We need to make a full copy of it.
452444
453-
PyArrayObject * A_copy = (PyArrayObject *) PyArray_Copy(
454-
%(A)s);
455-
if (!A_copy)
445+
PyArrayObject * x_copy = (PyArrayObject *) PyArray_Copy(%(x)s);
446+
if (!x_copy)
456447
%(fail)s
457-
Py_XDECREF(%(A)s);
458-
%(A)s = A_copy;
459-
SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
460-
SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
448+
Py_XDECREF(%(x)s);
449+
%(x)s = x_copy;
450+
x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
451+
Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
461452
}
462453
463-
if (PyArray_STRIDES(%(A)s)[0] == elemsize)
454+
if (
455+
(PyArray_STRIDES(%(A)s)[0] < 0) || (PyArray_STRIDES(%(A)s)[1] < 0)
456+
|| ((PyArray_STRIDES(%(A)s)[0] != elemsize) && (PyArray_STRIDES(%(A)s)[1] != elemsize))
457+
)
464458
{
465-
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
466-
{
467-
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
468-
sgemv_(&NOTRANS, &NA0, &NA1,
469-
&alpha,
470-
(float*)(PyArray_DATA(%(A)s)), &SA1,
471-
(float*)x_data, &Sx,
472-
&fbeta,
473-
(float*)z_data, &Sz);
474-
}
475-
else if (PyArray_DESCR(%(A)s)->type_num == NPY_DOUBLE)
476-
{
477-
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
478-
dgemv_(&NOTRANS, &NA0, &NA1,
479-
&alpha,
480-
(double*)(PyArray_DATA(%(A)s)), &SA1,
481-
(double*)x_data, &Sx,
482-
&dbeta,
483-
(double*)z_data, &Sz);
484-
}
485-
else
486-
{
487-
PyErr_SetString(PyExc_AssertionError,
488-
"neither float nor double dtype");
459+
// If A is neither C- nor F-contiguous, we make a copy.
460+
// TODO:
461+
// - if one stride is equal to "- elemsize", we can still call
462+
// gemv on reversed matrix and vectors
463+
// - if the copy is too long, maybe call vector/vector dot on
464+
// each row instead
465+
PyArrayObject * A_copy = (PyArrayObject *) PyArray_Copy(%(A)s);
466+
if (!A_copy)
489467
%(fail)s
490-
}
468+
Py_XDECREF(%(A)s);
469+
%(A)s = A_copy;
470+
SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : 1;
471+
SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : 1;
491472
}
492-
else if (PyArray_STRIDES(%(A)s)[1] == elemsize)
473+
474+
475+
if (PyArray_STRIDES(%(A)s)[1] == elemsize)
493476
{
477+
// C-contiguous branch
478+
// May also be F-contiguous, but we give preference to it,
479+
// because it has special handling for the A row/col matrix
480+
494481
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
495482
{
496483
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
@@ -554,10 +541,40 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
554541
%(fail)s
555542
}
556543
}
544+
else if (PyArray_STRIDES(%(A)s)[0] == elemsize)
545+
{
546+
// Fortran order branch
547+
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
548+
{
549+
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
550+
sgemv_(&NOTRANS, &NA0, &NA1,
551+
&alpha,
552+
(float*)(PyArray_DATA(%(A)s)), &SA1,
553+
(float*)x_data, &Sx,
554+
&fbeta,
555+
(float*)z_data, &Sz);
556+
}
557+
else if (PyArray_DESCR(%(A)s)->type_num == NPY_DOUBLE)
558+
{
559+
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
560+
dgemv_(&NOTRANS, &NA0, &NA1,
561+
&alpha,
562+
(double*)(PyArray_DATA(%(A)s)), &SA1,
563+
(double*)x_data, &Sx,
564+
&dbeta,
565+
(double*)z_data, &Sz);
566+
}
567+
else
568+
{
569+
PyErr_SetString(PyExc_AssertionError,
570+
"neither float nor double dtype");
571+
%(fail)s
572+
}
573+
}
557574
else
558575
{
559576
PyErr_SetString(PyExc_AssertionError,
560-
"xx is a double-strided matrix, and should have been "
577+
"A is a double-strided matrix, and should have been "
561578
"copied into a memory-contiguous one.");
562579
%(fail)s
563580
}
@@ -603,6 +620,7 @@ def c_code(self, node, name, inp, out, sub):
603620
return code
604621

605622
def c_code_cache_version(self):
623+
return None
606624
return (14, blas_header_version(), check_force_gemv_init())
607625

608626

tests/tensor/test_blas_c.py

Lines changed: 68 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
import numpy as np
44
import pytest
5+
from numpy.lib.stride_tricks import as_strided
56

67
import pytensor
78
import pytensor.tensor as pt
9+
from pytensor import config
810
from pytensor.tensor.basic import AllocEmpty
911
from pytensor.tensor.blas import Ger
1012
from pytensor.tensor.blas_c import CGemv, CGer, check_force_gemv_init
@@ -199,53 +201,77 @@ def test_force_gemv_init(self):
199201
" degradation in performance for such calls."
200202
)
201203

202-
def t_gemv1(self, m_shp):
203-
"""test vector2 + dot(matrix, vector1)"""
204+
@pytest.mark.skipif(config.blas__ldflags == "", reason="No blas")
205+
@pytest.mark.parametrize(
206+
"A_shape",
207+
[(3, 2), (1, 2), (0, 2), (3, 1), (3, 0), (1, 0), (1, 1), (0, 1), (0, 0)],
208+
ids=str,
209+
)
210+
@pytest.mark.parametrize("inplace", [True, False])
211+
def test_gemv1(self, A_shape, inplace: bool):
212+
"""test y + dot(A, x)"""
204213
rng = np.random.default_rng(unittest_tools.fetch_seed())
205-
v1 = pytensor.shared(np.array(rng.uniform(size=(m_shp[1],)), dtype="float32"))
206-
v2_orig = np.array(rng.uniform(size=(m_shp[0],)), dtype="float32")
207-
v2 = pytensor.shared(v2_orig)
208-
m = pytensor.shared(np.array(rng.uniform(size=m_shp), dtype="float32"))
209214

210-
f = pytensor.function([], v2 + pt.dot(m, v1), mode=self.mode)
211-
212-
# Assert they produce the same output
213-
assert np.allclose(f(), np.dot(m.get_value(), v1.get_value()) + v2_orig)
214-
topo = [n.op for n in f.maker.fgraph.toposort()]
215-
assert topo == [CGemv(inplace=False)], topo
216-
217-
# test the inplace version
218-
g = pytensor.function(
219-
[], [], updates=[(v2, v2 + pt.dot(m, v1))], mode=self.mode
215+
y = pt.vector("y", dtype="float32")
216+
x = pt.vector("x", dtype="float32")
217+
A = pt.matrix("A", dtype="float32")
218+
alpha = beta = 1.0
219+
220+
out = CGemv(inplace=inplace)(y, alpha, A, x, beta)
221+
f = pytensor.function([y, A, x], out, mode=self.mode, accept_inplace=inplace)
222+
f.dprint()
223+
assert [node.op for node in f.maker.fgraph.toposort()] == [
224+
CGemv(inplace=inplace)
225+
]
226+
227+
def assert_expected_output(inplace, f, y_test, A_test, x_test):
228+
# Copy y with the same strides as the original one
229+
y_test_copy = y_test.copy()
230+
y_test_copy = as_strided(
231+
y_test_copy, shape=y_test.shape, strides=y_test.strides
232+
)
233+
res = f(y_test_copy, A_test, x_test)
234+
if inplace:
235+
res = y_test_copy
236+
else:
237+
np.testing.assert_array_equal(y_test, y_test_copy)
238+
np.testing.assert_allclose(res, y_test + A_test @ x_test)
239+
240+
y_test = rng.uniform(size=A_shape[0]).astype("float32")
241+
A_test = rng.uniform(size=A_shape).astype("float32")
242+
x_test = rng.uniform(size=A_shape[1]).astype("float32")
243+
assert_expected_output(inplace, f, y_test, A_test, x_test)
244+
245+
## Fortran order
246+
y_test_fortran = np.asfortranarray(y_test)
247+
A_test_fortran = np.asfortranarray(A_test)
248+
x_test_fortran = np.asfortranarray(x_test)
249+
assert_expected_output(
250+
inplace, f, y_test_fortran, A_test_fortran, x_test_fortran
220251
)
221252

222-
# Assert they produce the same output
223-
g()
224-
assert np.allclose(
225-
v2.get_value(), np.dot(m.get_value(), v1.get_value()) + v2_orig
226-
)
227-
topo = [n.op for n in g.maker.fgraph.toposort()]
228-
assert topo == [CGemv(inplace=True)]
229-
230-
# Do the same tests with a matrix with strides in both dimensions
231-
m.set_value(m.get_value(borrow=True)[::-1, ::-1], borrow=True)
232-
v2.set_value(v2_orig)
233-
assert np.allclose(f(), np.dot(m.get_value(), v1.get_value()) + v2_orig)
234-
g()
235-
assert np.allclose(
236-
v2.get_value(), np.dot(m.get_value(), v1.get_value()) + v2_orig
253+
## Negative strides (or zero when size is zero)
254+
y_test_neg_strides = y_test[::-1]
255+
assert y_test_neg_strides.strides[0] in (-4, 0)
256+
A_test_neg_strides = A_test[::-1, ::-1]
257+
assert A_test_neg_strides.strides[1] in (-4, 0)
258+
x_test_neg_strides = x_test[::-1]
259+
assert x_test_neg_strides.strides[0] in (-4, 0)
260+
# assert_expected_output(inplace, f, y_test_neg_strides, A_test_neg_strides, x_test_neg_strides)
261+
262+
# Zero strides (by broadcasting)
263+
y_test_0_strides = np.broadcast_to(np.array(np.pi, dtype="float32"), A_shape[0])
264+
assert y_test_0_strides.strides == (0,)
265+
A_test_0_strides = np.broadcast_to(np.array(np.e, dtype="float32"), A_shape)
266+
assert A_test_0_strides.strides == (0, 0)
267+
x_test_0_strides = np.broadcast_to(
268+
np.array(np.euler_gamma, dtype="float32"), A_shape[1]
237269
)
238-
239-
def test_gemv1(self):
240-
skip_if_blas_ldflags_empty()
241-
self.t_gemv1((3, 2))
242-
self.t_gemv1((1, 2))
243-
self.t_gemv1((0, 2))
244-
self.t_gemv1((3, 1))
245-
self.t_gemv1((3, 0))
246-
self.t_gemv1((1, 0))
247-
self.t_gemv1((0, 1))
248-
self.t_gemv1((0, 0))
270+
assert x_test_0_strides.strides == (0,)
271+
# Test one input at a time so the outputs are unique
272+
assert_expected_output(inplace, f, y_test, A_test, x_test_0_strides)
273+
assert_expected_output(inplace, f, y_test, A_test_0_strides, x_test)
274+
# assert_expected_output(inplace, f, y_test_0_strides, A_test, x_test)
249275

250276
def test_gemv_dimensions(self, dtype="float32"):
251277
alpha = pytensor.shared(np.asarray(1.0, dtype=dtype), name="alpha")

0 commit comments

Comments
 (0)