Skip to content

Commit 501ae60

Browse files
benmaierjessegrabowski
authored andcommitted
added tests for tridiagonal solve
1 parent 032ffa2 commit 501ae60

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

tests/link/jax/test_slinalg.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,38 @@ def test_jax_solve():
122122
)
123123

124124

125+
def test_jax_tridiagonal_solve():
126+
N = 10
127+
A = pt.matrix("A", shape=(N, N))
128+
b = pt.vector("b", shape=(N,))
129+
130+
out = pt.linalg.solve(A, b, assume_a="tridiagonal")
131+
132+
A_val = np.eye(N)
133+
for i in range(N - 1):
134+
A_val[i, i + 1] = np.random.randn()
135+
A_val[i + 1, i] = np.random.randn()
136+
137+
b_val = np.random.randn(N)
138+
139+
compare_jax_and_py(
140+
[A, b],
141+
[out],
142+
[A_val, b_val],
143+
)
144+
145+
b_ = pt.matrix("b", shape=(N, 2))
146+
147+
out = pt.linalg.solve(A, b_, assume_a="tridiagonal")
148+
b_val = np.random.randn(N, 2)
149+
150+
compare_jax_and_py(
151+
[A, b_],
152+
[out],
153+
[A_val, b_val],
154+
)
155+
156+
125157
def test_jax_SolveTriangular():
126158
rng = np.random.default_rng(utt.fetch_seed())
127159

0 commit comments

Comments
 (0)