@@ -2135,6 +2135,215 @@ def test_check_finite_raises(self):
2135
2135
assert_raises (ValueError , dpnp .linalg .lu_factor , a , check_finite = True )
2136
2136
2137
2137
2138
+ class TestLuSolve :
2139
+ @staticmethod
2140
+ def _make_nonsingular_np (shape , dtype , order ):
2141
+ A = generate_random_numpy_array (shape , dtype , order )
2142
+ m , n = shape
2143
+ k = min (m , n )
2144
+ for i in range (k ):
2145
+ off = numpy .sum (numpy .abs (A [i , :n ])) - numpy .abs (A [i , i ])
2146
+ A [i , i ] = A .dtype .type (off + 1.0 )
2147
+ return A
2148
+
2149
+ @pytest .mark .parametrize ("shape" , [(1 , 1 ), (2 , 2 ), (3 , 3 ), (5 , 5 )])
2150
+ @pytest .mark .parametrize ("rhs_cols" , [None , 1 , 3 ])
2151
+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
2152
+ @pytest .mark .parametrize (
2153
+ "dtype" , get_all_dtypes (no_bool = True , no_none = True )
2154
+ )
2155
+ def test_lu_solve (self , shape , rhs_cols , order , dtype ):
2156
+ a_np = self ._make_nonsingular_np (shape , dtype , order )
2157
+ a_dp = dpnp .array (a_np , order = order )
2158
+
2159
+ n = shape [0 ]
2160
+ if rhs_cols is None :
2161
+ b_np = generate_random_numpy_array ((n ,), dtype , order )
2162
+ else :
2163
+ b_np = generate_random_numpy_array ((n , rhs_cols ), dtype , order )
2164
+ b_dp = dpnp .array (b_np , order = order )
2165
+
2166
+ lu , piv = dpnp .linalg .lu_factor (a_dp , check_finite = False )
2167
+ x = dpnp .linalg .lu_solve (
2168
+ (lu , piv ), b_dp , trans = 0 , overwrite_b = False , check_finite = False
2169
+ )
2170
+
2171
+ # check A @ x = b
2172
+ Ax = a_dp @ x
2173
+ assert dpnp .allclose (Ax , b_dp , rtol = 1e-6 , atol = 1e-6 )
2174
+
2175
+ @pytest .mark .parametrize ("trans" , [0 , 1 , 2 ])
2176
+ @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
2177
+ def test_trans (self , trans , dtype ):
2178
+ n = 4
2179
+ a_np = self ._make_nonsingular_np ((n , n ), dtype , order = "F" )
2180
+ a_dp = dpnp .array (a_np , order = "F" )
2181
+ b_dp = dpnp .array (generate_random_numpy_array ((n , 2 ), dtype , "F" ))
2182
+
2183
+ lu , piv = dpnp .linalg .lu_factor (a_dp , check_finite = False )
2184
+ x = dpnp .linalg .lu_solve (
2185
+ (lu , piv ), b_dp , trans = trans , overwrite_b = False , check_finite = False
2186
+ )
2187
+
2188
+ if trans == 0 :
2189
+ lhs = a_dp @ x
2190
+ elif trans == 1 :
2191
+ lhs = a_dp .T @ x
2192
+ else : # trans == 2
2193
+ lhs = a_dp .conj ().T @ x
2194
+
2195
+ assert dpnp .allclose (lhs , b_dp , rtol = 1e-6 , atol = 1e-6 )
2196
+
2197
+ @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
2198
+ def test_overwrite_inplace (self , dtype ):
2199
+ a_dp = dpnp .array ([[4 , 3 ], [6 , 3 ]], dtype = dtype , order = "F" )
2200
+ b_dp = dpnp .array ([1 , 0 ], dtype = dtype , order = "F" )
2201
+ b_orig = b_dp .copy ()
2202
+
2203
+ lu , piv = dpnp .linalg .lu_factor (
2204
+ a_dp , overwrite_a = False , check_finite = False
2205
+ )
2206
+ x = dpnp .linalg .lu_solve (
2207
+ (lu , piv ), b_dp , trans = 0 , overwrite_b = True , check_finite = False
2208
+ )
2209
+
2210
+ assert x is b_dp
2211
+ assert dpnp .allclose (a_dp @ x , b_orig , rtol = 1e-6 , atol = 1e-6 )
2212
+
2213
+ @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
2214
+ def test_overwrite_copy_special (self , dtype ):
2215
+ a_dp = dpnp .array ([[4 , 3 ], [6 , 3 ]], dtype = dtype , order = "F" )
2216
+ lu , piv = dpnp .linalg .lu_factor (a_dp , check_finite = False )
2217
+
2218
+ # F-contig but dtype != res_type
2219
+ b1 = dpnp .array ([1 , 0 ], dtype = dpnp .int32 , order = "F" )
2220
+ x1 = dpnp .linalg .lu_solve (
2221
+ (lu , piv ), b1 , overwrite_b = True , check_finite = False
2222
+ )
2223
+ assert x1 is not b1
2224
+
2225
+ # F-contig, match dtype but read-only input
2226
+ b2 = dpnp .array ([1 , 0 ], dtype = dtype , order = "F" )
2227
+ b2 .flags ["WRITABLE" ] = False
2228
+ x2 = dpnp .linalg .lu_solve (
2229
+ (lu , piv ), b2 , overwrite_b = True , check_finite = False
2230
+ )
2231
+ assert x2 is not b2
2232
+
2233
+ for x in (x1 , x2 ):
2234
+ assert dpnp .allclose (
2235
+ a_dp @ x ,
2236
+ dpnp .array ([1 , 0 ], dtype = x .dtype ),
2237
+ rtol = 1e-6 ,
2238
+ atol = 1e-6 ,
2239
+ )
2240
+
2241
+ @pytest .mark .parametrize (
2242
+ "dtype_a" , get_all_dtypes (no_bool = True , no_none = True )
2243
+ )
2244
+ @pytest .mark .parametrize (
2245
+ "dtype_b" , get_all_dtypes (no_bool = True , no_none = True )
2246
+ )
2247
+ def test_diff_type (self , dtype_a , dtype_b ):
2248
+ a_np = self ._make_nonsingular_np ((3 , 3 ), dtype_a , order = "F" )
2249
+ a_dp = dpnp .array (a_np , order = "F" )
2250
+
2251
+ b_np = generate_random_numpy_array ((3 ,), dtype_b , order = "F" )
2252
+ b_dp = dpnp .array (b_np , order = "F" )
2253
+
2254
+ lu , piv = dpnp .linalg .lu_factor (a_dp , check_finite = False )
2255
+ x = dpnp .linalg .lu_solve ((lu , piv ), b_dp , check_finite = False )
2256
+ assert dpnp .allclose (
2257
+ a_dp @ x , b_dp .astype (x .dtype , copy = False ), rtol = 1e-6 , atol = 1e-6
2258
+ )
2259
+
2260
+ def test_strided_rhs (self ):
2261
+ n = 7
2262
+ a_np = self ._make_nonsingular_np (
2263
+ (n , n ), dpnp .default_float_type (), order = "F"
2264
+ )
2265
+ a_dp = dpnp .array (a_np , order = "F" )
2266
+
2267
+ rhs_full = (
2268
+ dpnp .arange (n * n , dtype = dpnp .default_float_type ()).reshape (
2269
+ n , n , order = "F"
2270
+ )
2271
+ + 1.0
2272
+ )
2273
+ b_dp = rhs_full [:, ::2 ][:, :3 ]
2274
+
2275
+ lu , piv = dpnp .linalg .lu_factor (a_dp , check_finite = False )
2276
+ x = dpnp .linalg .lu_solve (
2277
+ (lu , piv ), b_dp , overwrite_b = False , check_finite = False
2278
+ )
2279
+
2280
+ assert dpnp .allclose (a_dp @ x , b_dp , rtol = 1e-6 , atol = 1e-6 )
2281
+
2282
+ @pytest .mark .skip ("Not implemented yet" )
2283
+ @pytest .mark .parametrize (
2284
+ "b_shape" ,
2285
+ [
2286
+ (4 ,),
2287
+ (4 , 1 ),
2288
+ (4 , 3 ),
2289
+ # (1, 4, 3),
2290
+ # (2, 4, 3),
2291
+ # (1, 1, 4, 3)
2292
+ ],
2293
+ )
2294
+ def test_broadcast_rhs (self , b_shape ):
2295
+ dtype = dpnp .default_float_type ()
2296
+
2297
+ a_np = self ._make_nonsingular_np ((4 , 4 ), dtype , order = "F" )
2298
+ a_dp = dpnp .array (a_np , order = "F" )
2299
+
2300
+ b_np = generate_random_numpy_array (b_shape , dtype , order = "F" )
2301
+ b_dp = dpnp .array (b_np , order = "F" )
2302
+
2303
+ lu , piv = dpnp .linalg .lu_factor (a_dp , check_finite = False )
2304
+ x = dpnp .linalg .lu_solve (
2305
+ (lu , piv ), b_dp , overwrite_b = True , check_finite = False
2306
+ )
2307
+
2308
+ assert x .shape == b_dp .shape
2309
+
2310
+ assert dpnp .allclose (a_dp @ x , b_dp , rtol = 1e-6 , atol = 1e-6 )
2311
+
2312
+ @pytest .mark .parametrize ("shape" , [(0 , 0 ), (0 , 5 ), (5 , 5 )])
2313
+ @pytest .mark .parametrize ("rhs_cols" , [None , 0 , 3 ])
2314
+ def test_empty_shapes (self , shape , rhs_cols ):
2315
+ a_dp = dpnp .empty (shape , dtype = dpnp .default_float_type (), order = "F" )
2316
+ if min (shape ) > 0 :
2317
+ for i in range (min (shape )):
2318
+ a_dp [i , i ] = a_dp .dtype .type (1.0 )
2319
+
2320
+ n = shape [0 ]
2321
+ if rhs_cols is None :
2322
+ b_shape = (n ,)
2323
+ else :
2324
+ b_shape = (n , rhs_cols )
2325
+ b_dp = dpnp .empty (b_shape , dtype = dpnp .default_float_type (), order = "F" )
2326
+
2327
+ lu , piv = dpnp .linalg .lu_factor (a_dp , check_finite = False )
2328
+ x = dpnp .linalg .lu_solve ((lu , piv ), b_dp , check_finite = False )
2329
+
2330
+ assert x .shape == b_shape
2331
+
2332
+ @pytest .mark .parametrize ("bad" , [numpy .inf , - numpy .inf , numpy .nan ])
2333
+ def test_check_finite_raises (self , bad ):
2334
+ a_dp = dpnp .array ([[1.0 , 0.0 ], [0.0 , 1.0 ]], order = "F" )
2335
+ lu , piv = dpnp .linalg .lu_factor (a_dp , check_finite = False )
2336
+
2337
+ b_bad = dpnp .array ([1.0 , bad ], order = "F" )
2338
+ assert_raises (
2339
+ ValueError ,
2340
+ dpnp .linalg .lu_solve ,
2341
+ (lu , piv ),
2342
+ b_bad ,
2343
+ check_finite = True ,
2344
+ )
2345
+
2346
+
2138
2347
class TestMatrixPower :
2139
2348
@pytest .mark .parametrize ("dtype" , get_all_dtypes ())
2140
2349
@pytest .mark .parametrize (
0 commit comments