|
2 | 2 |
|
3 | 3 | import numpy as np |
4 | 4 | import pytest |
| 5 | +from numpy.lib.stride_tricks import as_strided |
5 | 6 |
|
6 | 7 | import pytensor |
7 | 8 | import pytensor.tensor as pt |
| 9 | +from pytensor import config |
8 | 10 | from pytensor.tensor.basic import AllocEmpty |
9 | 11 | from pytensor.tensor.blas import Ger |
10 | 12 | from pytensor.tensor.blas_c import CGemv, CGer, check_force_gemv_init |
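For reference, `CGemv` is the C implementation of the BLAS GEMV update, and the rewritten test below drives it directly rather than relying on graph rewrites: called as `CGemv(inplace=...)(y, alpha, A, x, beta)` it computes `beta * y + alpha * (A @ x)`, overwriting `y` when `inplace=True`. A minimal NumPy sketch of that contract (the `gemv_reference` helper is illustrative, not part of this diff):

    import numpy as np

    def gemv_reference(y, alpha, A, x, beta):
        # Out-of-place GEMV update: beta * y + alpha * (A @ x)
        return beta * y + alpha * (A @ x)

    rng = np.random.default_rng(0)
    A = rng.uniform(size=(3, 2)).astype("float32")
    x = rng.uniform(size=2).astype("float32")
    y = rng.uniform(size=3).astype("float32")
    # With alpha = beta = 1.0, as in the test, this reduces to y + A @ x
    np.testing.assert_allclose(gemv_reference(y, 1.0, A, x, 1.0), y + A @ x)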
@@ -199,53 +201,77 @@ def test_force_gemv_init(self): |
199 | 201 | " degradation in performance for such calls." |
200 | 202 | ) |
201 | 203 |
|
202 | | - def t_gemv1(self, m_shp): |
203 | | - """test vector2 + dot(matrix, vector1)""" |
| 204 | + @pytest.mark.skipif(config.blas__ldflags == "", reason="No blas") |
| 205 | + @pytest.mark.parametrize( |
| 206 | + "A_shape", |
| 207 | + [(3, 2), (1, 2), (0, 2), (3, 1), (3, 0), (1, 0), (1, 1), (0, 1), (0, 0)], |
| 208 | + ids=str, |
| 209 | + ) |
| 210 | + @pytest.mark.parametrize("inplace", [True, False]) |
| 211 | + def test_gemv1(self, A_shape, inplace: bool): |
| 212 | + """test y + dot(A, x)""" |
204 | 213 | rng = np.random.default_rng(unittest_tools.fetch_seed()) |
205 | | - v1 = pytensor.shared(np.array(rng.uniform(size=(m_shp[1],)), dtype="float32")) |
206 | | - v2_orig = np.array(rng.uniform(size=(m_shp[0],)), dtype="float32") |
207 | | - v2 = pytensor.shared(v2_orig) |
208 | | - m = pytensor.shared(np.array(rng.uniform(size=m_shp), dtype="float32")) |
209 | 214 |
|
210 | | - f = pytensor.function([], v2 + pt.dot(m, v1), mode=self.mode) |
211 | | - |
212 | | - # Assert they produce the same output |
213 | | - assert np.allclose(f(), np.dot(m.get_value(), v1.get_value()) + v2_orig) |
214 | | - topo = [n.op for n in f.maker.fgraph.toposort()] |
215 | | - assert topo == [CGemv(inplace=False)], topo |
216 | | - |
217 | | - # test the inplace version |
218 | | - g = pytensor.function( |
219 | | - [], [], updates=[(v2, v2 + pt.dot(m, v1))], mode=self.mode |
| 215 | + y = pt.vector("y", dtype="float32") |
| 216 | + x = pt.vector("x", dtype="float32") |
| 217 | + A = pt.matrix("A", dtype="float32") |
| 218 | + alpha = beta = 1.0 |
| 219 | + |
| 220 | + out = CGemv(inplace=inplace)(y, alpha, A, x, beta) |
| 221 | + f = pytensor.function([y, A, x], out, mode=self.mode, accept_inplace=inplace) |
| 223 | + assert [node.op for node in f.maker.fgraph.toposort()] == [ |
| 224 | + CGemv(inplace=inplace) |
| 225 | + ] |
| 226 | + |
| 227 | + def assert_expected_output(inplace, f, y_test, A_test, x_test): |
| 228 | + # Copy y with the same strides as the original one |
| 229 | + y_test_copy = y_test.copy() |
| 230 | + y_test_copy = as_strided( |
| 231 | + y_test_copy, shape=y_test.shape, strides=y_test.strides |
| 232 | + ) |
| 233 | + res = f(y_test_copy, A_test, x_test) |
| 234 | + if inplace: |
| 235 | + res = y_test_copy |
| 236 | + else: |
| 237 | + np.testing.assert_array_equal(y_test, y_test_copy) |
| 238 | + np.testing.assert_allclose(res, y_test + A_test @ x_test) |
| 239 | + |
| 240 | + y_test = rng.uniform(size=A_shape[0]).astype("float32") |
| 241 | + A_test = rng.uniform(size=A_shape).astype("float32") |
| 242 | + x_test = rng.uniform(size=A_shape[1]).astype("float32") |
| 243 | + assert_expected_output(inplace, f, y_test, A_test, x_test) |
| 244 | + |
|   | 246 | + # Fortran order |
| 246 | + y_test_fortran = np.asfortranarray(y_test) |
| 247 | + A_test_fortran = np.asfortranarray(A_test) |
| 248 | + x_test_fortran = np.asfortranarray(x_test) |
| 249 | + assert_expected_output( |
| 250 | + inplace, f, y_test_fortran, A_test_fortran, x_test_fortran |
220 | 251 | ) |
221 | 252 |
|
222 | | - # Assert they produce the same output |
223 | | - g() |
224 | | - assert np.allclose( |
225 | | - v2.get_value(), np.dot(m.get_value(), v1.get_value()) + v2_orig |
226 | | - ) |
227 | | - topo = [n.op for n in g.maker.fgraph.toposort()] |
228 | | - assert topo == [CGemv(inplace=True)] |
229 | | - |
230 | | - # Do the same tests with a matrix with strides in both dimensions |
231 | | - m.set_value(m.get_value(borrow=True)[::-1, ::-1], borrow=True) |
232 | | - v2.set_value(v2_orig) |
233 | | - assert np.allclose(f(), np.dot(m.get_value(), v1.get_value()) + v2_orig) |
234 | | - g() |
235 | | - assert np.allclose( |
236 | | - v2.get_value(), np.dot(m.get_value(), v1.get_value()) + v2_orig |
|   | 253 | + # Negative strides (or zero when size is zero) |
| 254 | + y_test_neg_strides = y_test[::-1] |
| 255 | + assert y_test_neg_strides.strides[0] in (-4, 0) |
| 256 | + A_test_neg_strides = A_test[::-1, ::-1] |
| 257 | + assert A_test_neg_strides.strides[1] in (-4, 0) |
| 258 | + x_test_neg_strides = x_test[::-1] |
| 259 | + assert x_test_neg_strides.strides[0] in (-4, 0) |
| 260 | + # assert_expected_output(inplace, f, y_test_neg_strides, A_test_neg_strides, x_test_neg_strides) |
| 261 | + |
| 262 | + # Zero strides (by broadcasting) |
| 263 | + y_test_0_strides = np.broadcast_to(np.array(np.pi, dtype="float32"), A_shape[0]) |
| 264 | + assert y_test_0_strides.strides == (0,) |
| 265 | + A_test_0_strides = np.broadcast_to(np.array(np.e, dtype="float32"), A_shape) |
| 266 | + assert A_test_0_strides.strides == (0, 0) |
| 267 | + x_test_0_strides = np.broadcast_to( |
| 268 | + np.array(np.euler_gamma, dtype="float32"), A_shape[1] |
237 | 269 | ) |
238 | | - |
239 | | - def test_gemv1(self): |
240 | | - skip_if_blas_ldflags_empty() |
241 | | - self.t_gemv1((3, 2)) |
242 | | - self.t_gemv1((1, 2)) |
243 | | - self.t_gemv1((0, 2)) |
244 | | - self.t_gemv1((3, 1)) |
245 | | - self.t_gemv1((3, 0)) |
246 | | - self.t_gemv1((1, 0)) |
247 | | - self.t_gemv1((0, 1)) |
248 | | - self.t_gemv1((0, 0)) |
| 270 | + assert x_test_0_strides.strides == (0,) |
| 271 | + # Test one input at a time so the outputs are unique |
| 272 | + assert_expected_output(inplace, f, y_test, A_test, x_test_0_strides) |
| 273 | + assert_expected_output(inplace, f, y_test, A_test_0_strides, x_test) |
| 274 | + # assert_expected_output(inplace, f, y_test_0_strides, A_test, x_test) |
249 | 275 |
|
250 | 276 | def test_gemv_dimensions(self, dtype="float32"): |
251 | 277 | alpha = pytensor.shared(np.asarray(1.0, dtype=dtype), name="alpha") |
|
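The stride layouts exercised above are plain NumPy constructions; a short sketch of each (float32 elements, so an element stride is 4 bytes), including the `as_strided` trick the `assert_expected_output` helper uses to copy `y` without losing its layout:

    import numpy as np
    from numpy.lib.stride_tricks import as_strided

    y = np.arange(3, dtype="float32")
    assert y.strides == (4,)

    # A reversed view flips the sign of the stride
    y_rev = y[::-1]
    assert y_rev.strides == (-4,)

    # Broadcasting a scalar yields a zero-stride array: every element
    # aliases the same 4 bytes
    y0 = np.broadcast_to(np.float32(np.pi), 3)
    assert y0.strides == (0,)

    # .copy() always returns a contiguous buffer; as_strided re-imposes the
    # original strides so the copy presents the same memory layout
    y0_copy = as_strided(y0.copy(), shape=y0.shape, strides=y0.strides)
    assert y0_copy.strides == (0,)

Note that re-imposing negative strides on a fresh contiguous copy this way would read before the start of the buffer, which may be why the negative-stride assertion remains commented out in the test above.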