@@ -104,33 +104,16 @@ def test_jax_basic():
104104 )
105105
106106
107- @pytest .mark .parametrize (
108- "b_shape" ,
109- [(5 , 1 ), (5 , 5 ), (5 ,)],
110- ids = ["b_col_vec" , "b_matrix" , "b_vec" ],
111- )
112- @pytest .mark .parametrize ("assume_a" , ["gen" , "sym" , "pos" ], ids = str )
113- @pytest .mark .parametrize ("lower" , [False , True ])
114- @pytest .mark .parametrize ("transposed" , [False , True ])
115- def test_jax_solve (b_shape : tuple [int ], assume_a , lower , transposed ):
107+ def test_jax_solve ():
116108 rng = np .random .default_rng (utt .fetch_seed ())
117109
118110 A = pt .tensor ("A" , shape = (5 , 5 ))
119- b = pt .tensor ("B" , shape = b_shape )
120-
121- def A_func (x ):
122- if assume_a == "sym" :
123- return (x + x .T ) / 2
124- if assume_a == "pos" :
125- return x @ x .T
126- return x
111+ b = pt .tensor ("B" , shape = (5 , 5 ))
127112
128- out = pt_slinalg .solve (
129- A_func (A ), b , assume_a = assume_a , lower = lower , transposed = transposed
130- )
113+ out = pt_slinalg .solve (A , b , lower = False , transposed = False )
131114
132115 A_val = rng .normal (size = (5 , 5 )).astype (config .floatX )
133- b_val = rng .normal (size = b_shape ).astype (config .floatX )
116+ b_val = rng .normal (size = ( 5 , 5 ) ).astype (config .floatX )
134117
135118 compare_jax_and_py (
136119 [A , b ],
@@ -139,35 +122,21 @@ def A_func(x):
139122 )
140123
141124
142- @pytest .mark .parametrize (
143- "b_shape" , [(5 , 1 ), (5 , 5 ), (5 ,)], ids = ["b_col_vec" , "b_matrix" , "b_vec" ]
144- )
145- @pytest .mark .parametrize ("lower" , [False , True ])
146- @pytest .mark .parametrize ("trans" , [0 , 1 , 2 ])
147- @pytest .mark .parametrize ("unit_diagonal" , [False , True ])
148- def test_jax_SolveTriangular (b_shape : tuple [int ], lower , trans , unit_diagonal ):
125+ def test_jax_SolveTriangular ():
149126 rng = np .random .default_rng (utt .fetch_seed ())
150127
151128 A = pt .tensor ("A" , shape = (5 , 5 ))
152- b = pt .tensor ("B" , shape = b_shape )
153-
154- def A_func (x ):
155- x = x @ x .T
156- x = pt .linalg .cholesky (x , lower = lower )
157- if unit_diagonal :
158- x = pt .fill_diagonal (x , 1.0 )
159-
160- return x
129+ b = pt .tensor ("B" , shape = (5 , 5 ))
161130
162131 A_val = rng .normal (size = (5 , 5 )).astype (config .floatX )
163- b_val = rng .normal (size = b_shape ).astype (config .floatX )
132+ b_val = rng .normal (size = ( 5 , 5 ) ).astype (config .floatX )
164133
165134 out = pt_slinalg .solve_triangular (
166- A_func ( A ) ,
135+ A ,
167136 b ,
168- trans = trans ,
169- lower = lower ,
170- unit_diagonal = unit_diagonal ,
137+ trans = 0 ,
138+ lower = True ,
139+ unit_diagonal = False ,
171140 )
172141 compare_jax_and_py ([A , b ], [out ], [A_val , b_val ])
173142
0 commit comments