File tree Expand file tree Collapse file tree 1 file changed +32
-0
lines changed Expand file tree Collapse file tree 1 file changed +32
-0
lines changed Original file line number Diff line number Diff line change @@ -122,6 +122,38 @@ def test_jax_solve():
122
122
)
123
123
124
124
125
+ def test_jax_tridiagonal_solve ():
126
+ N = 10
127
+ A = pt .matrix ("A" , shape = (N , N ))
128
+ b = pt .vector ("b" , shape = (N ,))
129
+
130
+ out = pt .linalg .solve (A , b , assume_a = "tridiagonal" )
131
+
132
+ A_val = np .eye (N )
133
+ for i in range (N - 1 ):
134
+ A_val [i , i + 1 ] = np .random .randn ()
135
+ A_val [i + 1 , i ] = np .random .randn ()
136
+
137
+ b_val = np .random .randn (N )
138
+
139
+ compare_jax_and_py (
140
+ [A , b ],
141
+ [out ],
142
+ [A_val , b_val ],
143
+ )
144
+
145
+ b_ = pt .matrix ("b" , shape = (N , 2 ))
146
+
147
+ out = pt .linalg .solve (A , b_ , assume_a = "tridiagonal" )
148
+ b_val = np .random .randn (N , 2 )
149
+
150
+ compare_jax_and_py (
151
+ [A , b_ ],
152
+ [out ],
153
+ [A_val , b_val ],
154
+ )
155
+
156
+
125
157
def test_jax_SolveTriangular ():
126
158
rng = np .random .default_rng (utt .fetch_seed ())
127
159
You can’t perform that action at this time.
0 commit comments