@@ -2144,6 +2144,215 @@ def test_check_finite_raises(self):
2144
2144
assert_raises (ValueError , dpnp .linalg .lu_factor , a , check_finite = True )
2145
2145
2146
2146
2147
+ class TestLuSolve :
2148
+ @staticmethod
2149
+ def _make_nonsingular_np (shape , dtype , order ):
2150
+ A = generate_random_numpy_array (shape , dtype , order )
2151
+ m , n = shape
2152
+ k = min (m , n )
2153
+ for i in range (k ):
2154
+ off = numpy .sum (numpy .abs (A [i , :n ])) - numpy .abs (A [i , i ])
2155
+ A [i , i ] = A .dtype .type (off + 1.0 )
2156
+ return A
2157
+
2158
+ @pytest .mark .parametrize ("shape" , [(1 , 1 ), (2 , 2 ), (3 , 3 ), (5 , 5 )])
2159
+ @pytest .mark .parametrize ("rhs_cols" , [None , 1 , 3 ])
2160
+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
2161
+ @pytest .mark .parametrize (
2162
+ "dtype" , get_all_dtypes (no_bool = True , no_none = True )
2163
+ )
2164
+ def test_lu_solve (self , shape , rhs_cols , order , dtype ):
2165
+ a_np = self ._make_nonsingular_np (shape , dtype , order )
2166
+ a_dp = dpnp .array (a_np , order = order )
2167
+
2168
+ n = shape [0 ]
2169
+ if rhs_cols is None :
2170
+ b_np = generate_random_numpy_array ((n ,), dtype , order )
2171
+ else :
2172
+ b_np = generate_random_numpy_array ((n , rhs_cols ), dtype , order )
2173
+ b_dp = dpnp .array (b_np , order = order )
2174
+
2175
+ lu , piv = dpnp .linalg .lu_factor (a_dp , check_finite = False )
2176
+ x = dpnp .linalg .lu_solve (
2177
+ (lu , piv ), b_dp , trans = 0 , overwrite_b = False , check_finite = False
2178
+ )
2179
+
2180
+ # check A @ x = b
2181
+ Ax = a_dp @ x
2182
+ assert dpnp .allclose (Ax , b_dp , rtol = 1e-6 , atol = 1e-6 )
2183
+
2184
+ @pytest .mark .parametrize ("trans" , [0 , 1 , 2 ])
2185
+ @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
2186
+ def test_trans (self , trans , dtype ):
2187
+ n = 4
2188
+ a_np = self ._make_nonsingular_np ((n , n ), dtype , order = "F" )
2189
+ a_dp = dpnp .array (a_np , order = "F" )
2190
+ b_dp = dpnp .array (generate_random_numpy_array ((n , 2 ), dtype , "F" ))
2191
+
2192
+ lu , piv = dpnp .linalg .lu_factor (a_dp , check_finite = False )
2193
+ x = dpnp .linalg .lu_solve (
2194
+ (lu , piv ), b_dp , trans = trans , overwrite_b = False , check_finite = False
2195
+ )
2196
+
2197
+ if trans == 0 :
2198
+ lhs = a_dp @ x
2199
+ elif trans == 1 :
2200
+ lhs = a_dp .T @ x
2201
+ else : # trans == 2
2202
+ lhs = a_dp .conj ().T @ x
2203
+
2204
+ assert dpnp .allclose (lhs , b_dp , rtol = 1e-6 , atol = 1e-6 )
2205
+
2206
+ @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
2207
+ def test_overwrite_inplace (self , dtype ):
2208
+ a_dp = dpnp .array ([[4 , 3 ], [6 , 3 ]], dtype = dtype , order = "F" )
2209
+ b_dp = dpnp .array ([1 , 0 ], dtype = dtype , order = "F" )
2210
+ b_orig = b_dp .copy ()
2211
+
2212
+ lu , piv = dpnp .linalg .lu_factor (
2213
+ a_dp , overwrite_a = False , check_finite = False
2214
+ )
2215
+ x = dpnp .linalg .lu_solve (
2216
+ (lu , piv ), b_dp , trans = 0 , overwrite_b = True , check_finite = False
2217
+ )
2218
+
2219
+ assert x is b_dp
2220
+ assert dpnp .allclose (a_dp @ x , b_orig , rtol = 1e-6 , atol = 1e-6 )
2221
+
2222
+ @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
2223
+ def test_overwrite_copy_special (self , dtype ):
2224
+ a_dp = dpnp .array ([[4 , 3 ], [6 , 3 ]], dtype = dtype , order = "F" )
2225
+ lu , piv = dpnp .linalg .lu_factor (a_dp , check_finite = False )
2226
+
2227
+ # F-contig but dtype != res_type
2228
+ b1 = dpnp .array ([1 , 0 ], dtype = dpnp .int32 , order = "F" )
2229
+ x1 = dpnp .linalg .lu_solve (
2230
+ (lu , piv ), b1 , overwrite_b = True , check_finite = False
2231
+ )
2232
+ assert x1 is not b1
2233
+
2234
+ # F-contig, match dtype but read-only input
2235
+ b2 = dpnp .array ([1 , 0 ], dtype = dtype , order = "F" )
2236
+ b2 .flags ["WRITABLE" ] = False
2237
+ x2 = dpnp .linalg .lu_solve (
2238
+ (lu , piv ), b2 , overwrite_b = True , check_finite = False
2239
+ )
2240
+ assert x2 is not b2
2241
+
2242
+ for x in (x1 , x2 ):
2243
+ assert dpnp .allclose (
2244
+ a_dp @ x ,
2245
+ dpnp .array ([1 , 0 ], dtype = x .dtype ),
2246
+ rtol = 1e-6 ,
2247
+ atol = 1e-6 ,
2248
+ )
2249
+
2250
+ @pytest .mark .parametrize (
2251
+ "dtype_a" , get_all_dtypes (no_bool = True , no_none = True )
2252
+ )
2253
+ @pytest .mark .parametrize (
2254
+ "dtype_b" , get_all_dtypes (no_bool = True , no_none = True )
2255
+ )
2256
+ def test_diff_type (self , dtype_a , dtype_b ):
2257
+ a_np = self ._make_nonsingular_np ((3 , 3 ), dtype_a , order = "F" )
2258
+ a_dp = dpnp .array (a_np , order = "F" )
2259
+
2260
+ b_np = generate_random_numpy_array ((3 ,), dtype_b , order = "F" )
2261
+ b_dp = dpnp .array (b_np , order = "F" )
2262
+
2263
+ lu , piv = dpnp .linalg .lu_factor (a_dp , check_finite = False )
2264
+ x = dpnp .linalg .lu_solve ((lu , piv ), b_dp , check_finite = False )
2265
+ assert dpnp .allclose (
2266
+ a_dp @ x , b_dp .astype (x .dtype , copy = False ), rtol = 1e-6 , atol = 1e-6
2267
+ )
2268
+
2269
+ def test_strided_rhs (self ):
2270
+ n = 7
2271
+ a_np = self ._make_nonsingular_np (
2272
+ (n , n ), dpnp .default_float_type (), order = "F"
2273
+ )
2274
+ a_dp = dpnp .array (a_np , order = "F" )
2275
+
2276
+ rhs_full = (
2277
+ dpnp .arange (n * n , dtype = dpnp .default_float_type ()).reshape (
2278
+ n , n , order = "F"
2279
+ )
2280
+ + 1.0
2281
+ )
2282
+ b_dp = rhs_full [:, ::2 ][:, :3 ]
2283
+
2284
+ lu , piv = dpnp .linalg .lu_factor (a_dp , check_finite = False )
2285
+ x = dpnp .linalg .lu_solve (
2286
+ (lu , piv ), b_dp , overwrite_b = False , check_finite = False
2287
+ )
2288
+
2289
+ assert dpnp .allclose (a_dp @ x , b_dp , rtol = 1e-6 , atol = 1e-6 )
2290
+
2291
+ @pytest .mark .skip ("Not implemented yet" )
2292
+ @pytest .mark .parametrize (
2293
+ "b_shape" ,
2294
+ [
2295
+ (4 ,),
2296
+ (4 , 1 ),
2297
+ (4 , 3 ),
2298
+ # (1, 4, 3),
2299
+ # (2, 4, 3),
2300
+ # (1, 1, 4, 3)
2301
+ ],
2302
+ )
2303
+ def test_broadcast_rhs (self , b_shape ):
2304
+ dtype = dpnp .default_float_type ()
2305
+
2306
+ a_np = self ._make_nonsingular_np ((4 , 4 ), dtype , order = "F" )
2307
+ a_dp = dpnp .array (a_np , order = "F" )
2308
+
2309
+ b_np = generate_random_numpy_array (b_shape , dtype , order = "F" )
2310
+ b_dp = dpnp .array (b_np , order = "F" )
2311
+
2312
+ lu , piv = dpnp .linalg .lu_factor (a_dp , check_finite = False )
2313
+ x = dpnp .linalg .lu_solve (
2314
+ (lu , piv ), b_dp , overwrite_b = True , check_finite = False
2315
+ )
2316
+
2317
+ assert x .shape == b_dp .shape
2318
+
2319
+ assert dpnp .allclose (a_dp @ x , b_dp , rtol = 1e-6 , atol = 1e-6 )
2320
+
2321
+ @pytest .mark .parametrize ("shape" , [(0 , 0 ), (0 , 5 ), (5 , 5 )])
2322
+ @pytest .mark .parametrize ("rhs_cols" , [None , 0 , 3 ])
2323
+ def test_empty_shapes (self , shape , rhs_cols ):
2324
+ a_dp = dpnp .empty (shape , dtype = dpnp .default_float_type (), order = "F" )
2325
+ if min (shape ) > 0 :
2326
+ for i in range (min (shape )):
2327
+ a_dp [i , i ] = a_dp .dtype .type (1.0 )
2328
+
2329
+ n = shape [0 ]
2330
+ if rhs_cols is None :
2331
+ b_shape = (n ,)
2332
+ else :
2333
+ b_shape = (n , rhs_cols )
2334
+ b_dp = dpnp .empty (b_shape , dtype = dpnp .default_float_type (), order = "F" )
2335
+
2336
+ lu , piv = dpnp .linalg .lu_factor (a_dp , check_finite = False )
2337
+ x = dpnp .linalg .lu_solve ((lu , piv ), b_dp , check_finite = False )
2338
+
2339
+ assert x .shape == b_shape
2340
+
2341
+ @pytest .mark .parametrize ("bad" , [numpy .inf , - numpy .inf , numpy .nan ])
2342
+ def test_check_finite_raises (self , bad ):
2343
+ a_dp = dpnp .array ([[1.0 , 0.0 ], [0.0 , 1.0 ]], order = "F" )
2344
+ lu , piv = dpnp .linalg .lu_factor (a_dp , check_finite = False )
2345
+
2346
+ b_bad = dpnp .array ([1.0 , bad ], order = "F" )
2347
+ assert_raises (
2348
+ ValueError ,
2349
+ dpnp .linalg .lu_solve ,
2350
+ (lu , piv ),
2351
+ b_bad ,
2352
+ check_finite = True ,
2353
+ )
2354
+
2355
+
2147
2356
class TestMatrixPower :
2148
2357
@pytest .mark .parametrize ("dtype" , get_all_dtypes ())
2149
2358
@pytest .mark .parametrize (
0 commit comments