We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 12213d0 commit f0fb328Copy full SHA for f0fb328
tests/tensor/test_blas_c.py
@@ -486,3 +486,26 @@ def test_gemv_negative_strides_perf(neg_stride0, neg_stride1, F_layout, benchmar
486
np.testing.assert_allclose(res, fn(test_A.copy(), test_x, test_y))
487
488
benchmark(fn, test_A, test_x, test_y)
489
+
490
491
+@pytest.mark.parametrize("inplace", (True, False), ids=["inplace", "no_inplace"])
492
+@pytest.mark.parametrize("n", [2**7, 2**9, 2**13])
493
+def test_ger_benchmark(n, inplace, benchmark):
494
+ alpha = pt.dscalar("alpha")
495
+ x = pt.dvector("x")
496
+ y = pt.dvector("y")
497
+ A = pt.dmatrix("A")
498
499
+ out = alpha * pt.outer(x, y) + A
500
501
+ fn = pytensor.function(
502
+ [alpha, x, y, pytensor.In(A, mutable=inplace)], out, trust_input=True
503
+ )
504
505
+ rng = np.random.default_rng([2274, n])
506
+ alpha_test = rng.normal(size=())
507
+ x_test = rng.normal(size=(n,))
508
+ y_test = rng.normal(size=(n,))
509
+ A_test = rng.normal(size=(n, n))
510
511
+ benchmark(fn, alpha_test, x_test, y_test, A_test)
0 commit comments