Skip to content

Commit f0fb328

Browse files
committed
Benchmark GER graph
1 parent 12213d0 commit f0fb328

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

tests/tensor/test_blas_c.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,3 +486,26 @@ def test_gemv_negative_strides_perf(neg_stride0, neg_stride1, F_layout, benchmar
486486
np.testing.assert_allclose(res, fn(test_A.copy(), test_x, test_y))
487487

488488
benchmark(fn, test_A, test_x, test_y)
489+
490+
491+
@pytest.mark.parametrize("inplace", (True, False), ids=["inplace", "no_inplace"])
492+
@pytest.mark.parametrize("n", [2**7, 2**9, 2**13])
493+
def test_ger_benchmark(n, inplace, benchmark):
494+
alpha = pt.dscalar("alpha")
495+
x = pt.dvector("x")
496+
y = pt.dvector("y")
497+
A = pt.dmatrix("A")
498+
499+
out = alpha * pt.outer(x, y) + A
500+
501+
fn = pytensor.function(
502+
[alpha, x, y, pytensor.In(A, mutable=inplace)], out, trust_input=True
503+
)
504+
505+
rng = np.random.default_rng([2274, n])
506+
alpha_test = rng.normal(size=())
507+
x_test = rng.normal(size=(n,))
508+
y_test = rng.normal(size=(n,))
509+
A_test = rng.normal(size=(n, n))
510+
511+
benchmark(fn, alpha_test, x_test, y_test, A_test)

0 commit comments

Comments
 (0)