Skip to content

Commit 14d2dcd

Browse files
float32 slinalg
1 parent 8e8311e commit 14d2dcd

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tests/tensor/test_slinalg.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,10 @@ def test_lu_solve(self, b_shape: tuple[int], trans):
698698
A = pt.tensor("A", shape=(5, 5))
699699
b = pt.tensor("b", shape=b_shape)
700700

701-
A_val = rng.normal(size=(5, 5)).astype(config.floatX) + np.eye(5) * 0.5
701+
A_val = (
702+
rng.normal(size=(5, 5)).astype(config.floatX)
703+
+ np.eye(5, dtype=config.floatX) * 0.5
704+
)
702705
b_val = rng.normal(size=b_shape).astype(config.floatX)
703706

704707
x = self.factor_and_solve(A, b, trans=trans, sum=False)

0 commit comments

Comments
 (0)