Skip to content

Commit 4db2a33

Browse files
float32 compat in tests
1 parent c687856 commit 4db2a33

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tests/tensor/test_slinalg.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,11 +1073,11 @@ def _make_banded_A(A, kl, ku):
10731073
def test_banded_dot(A_shape, kl, ku):
10741074
rng = np.random.default_rng()
10751075

1076-
A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku)
1077-
b_val = rng.normal(size=(A_shape[-1],))
1076+
A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku).astype(config.floatX)
1077+
b_val = rng.normal(size=(A_shape[-1],)).astype(config.floatX)
10781078

1079-
A = pt.tensor("A", shape=A_val.shape)
1080-
b = pt.tensor("b", shape=b_val.shape)
1079+
A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype)
1080+
b = pt.tensor("b", shape=b_val.shape, dtype=b_val.dtype)
10811081
res = banded_dot(A, b, kl, ku)
10821082
res_2 = A @ b
10831083

@@ -1098,11 +1098,11 @@ def test_banded_dot(A_shape, kl, ku):
10981098
def test_banded_dot_perf(op, A_shape, benchmark):
10991099
rng = np.random.default_rng()
11001100

1101-
A_val = _make_banded_A(rng.normal(size=A_shape), kl=1, ku=1)
1102-
b_val = rng.normal(size=(A_shape[-1],))
1101+
A_val = _make_banded_A(rng.normal(size=A_shape), kl=1, ku=1).astype(config.floatX)
1102+
b_val = rng.normal(size=(A_shape[-1],)).astype(config.floatX)
11031103

1104-
A = pt.tensor("A", shape=A_val.shape)
1105-
b = pt.tensor("b", shape=b_val.shape)
1104+
A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype)
1105+
b = pt.tensor("b", shape=b_val.shape, dtype=A_val.dtype)
11061106

11071107
if op == "dot":
11081108
f = pt.dot

0 commit comments

Comments
 (0)