Skip to content

Commit aeb35c5

Browse files
Remove jax tests
1 parent eeaf7dd commit aeb35c5

File tree

1 file changed

+11
-42
lines changed

1 file changed

+11
-42
lines changed

tests/link/jax/test_slinalg.py

Lines changed: 11 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -104,33 +104,16 @@ def test_jax_basic():
104104
)
105105

106106

107-
@pytest.mark.parametrize(
108-
"b_shape",
109-
[(5, 1), (5, 5), (5,)],
110-
ids=["b_col_vec", "b_matrix", "b_vec"],
111-
)
112-
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
113-
@pytest.mark.parametrize("lower", [False, True])
114-
@pytest.mark.parametrize("transposed", [False, True])
115-
def test_jax_solve(b_shape: tuple[int], assume_a, lower, transposed):
107+
def test_jax_solve():
116108
rng = np.random.default_rng(utt.fetch_seed())
117109

118110
A = pt.tensor("A", shape=(5, 5))
119-
b = pt.tensor("B", shape=b_shape)
120-
121-
def A_func(x):
122-
if assume_a == "sym":
123-
return (x + x.T) / 2
124-
if assume_a == "pos":
125-
return x @ x.T
126-
return x
111+
b = pt.tensor("B", shape=(5, 5))
127112

128-
out = pt_slinalg.solve(
129-
A_func(A), b, assume_a=assume_a, lower=lower, transposed=transposed
130-
)
113+
out = pt_slinalg.solve(A, b, lower=False, transposed=False)
131114

132115
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
133-
b_val = rng.normal(size=b_shape).astype(config.floatX)
116+
b_val = rng.normal(size=(5, 5)).astype(config.floatX)
134117

135118
compare_jax_and_py(
136119
[A, b],
@@ -139,35 +122,21 @@ def A_func(x):
139122
)
140123

141124

142-
@pytest.mark.parametrize(
143-
"b_shape", [(5, 1), (5, 5), (5,)], ids=["b_col_vec", "b_matrix", "b_vec"]
144-
)
145-
@pytest.mark.parametrize("lower", [False, True])
146-
@pytest.mark.parametrize("trans", [0, 1, 2])
147-
@pytest.mark.parametrize("unit_diagonal", [False, True])
148-
def test_jax_SolveTriangular(b_shape: tuple[int], lower, trans, unit_diagonal):
125+
def test_jax_SolveTriangular():
149126
rng = np.random.default_rng(utt.fetch_seed())
150127

151128
A = pt.tensor("A", shape=(5, 5))
152-
b = pt.tensor("B", shape=b_shape)
153-
154-
def A_func(x):
155-
x = x @ x.T
156-
x = pt.linalg.cholesky(x, lower=lower)
157-
if unit_diagonal:
158-
x = pt.fill_diagonal(x, 1.0)
159-
160-
return x
129+
b = pt.tensor("B", shape=(5, 5))
161130

162131
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
163-
b_val = rng.normal(size=b_shape).astype(config.floatX)
132+
b_val = rng.normal(size=(5, 5)).astype(config.floatX)
164133

165134
out = pt_slinalg.solve_triangular(
166-
A_func(A),
135+
A,
167136
b,
168-
trans=trans,
169-
lower=lower,
170-
unit_diagonal=unit_diagonal,
137+
trans=0,
138+
lower=True,
139+
unit_diagonal=False,
171140
)
172141
compare_jax_and_py([A, b], [out], [A_val, b_val])
173142

0 commit comments

Comments
 (0)