@@ -2300,9 +2300,6 @@ def test_strided_rhs(self):
23002300 (4 ,),
23012301 (4 , 1 ),
23022302 (4 , 3 ),
2303- # (1, 4, 3),
2304- # (2, 4, 3),
2305- # (1, 1, 4, 3)
23062303 ],
23072304 )
23082305 def test_broadcast_rhs (self , b_shape ):
@@ -2358,6 +2355,241 @@ def test_check_finite_raises(self, bad):
23582355 )
23592356
23602357
class TestLuSolveBatched:
    """Tests for ``dpnp.linalg.lu_solve`` applied to stacked (N-D) LU factors.

    Each test factors a stack of square matrices with
    ``dpnp.linalg.lu_factor`` and then checks the solution returned by
    ``dpnp.linalg.lu_solve`` by substituting it back into the linear system.
    """

    @staticmethod
    def _make_nonsingular_nd_np(shape, dtype, order):
        """Build a NumPy stack of strictly diagonally dominant matrices.

        The array starts out as ``generate_random_numpy_array(shape, dtype,
        order)``; every diagonal entry of every trailing ``(n, n)`` matrix is
        then replaced by the sum of the absolute off-diagonal values of its
        row plus one, which makes each matrix non-singular.

        Parameters
        ----------
        shape : tuple of int
            Shape ``(..., n, n)`` of the requested stack.
        dtype : dtype
            Data type forwarded to ``generate_random_numpy_array``.
        order : {"C", "F"}
            Memory order of the returned array.

        Returns
        -------
        numpy.ndarray
            Array of the requested shape laid out in the requested order.
        """
        A = generate_random_numpy_array(shape, dtype, order)
        n = shape[-1]
        # Collapse all batch dimensions so matrices can be visited one by one.
        A3 = A.reshape((-1, n, n))
        diag_idx = numpy.arange(n)
        for B in A3:
            # Row-wise sum of |off-diagonal| entries.
            off = numpy.sum(numpy.abs(B), axis=1) - numpy.abs(numpy.diag(B))
            B[diag_idx, diag_idx] = A.dtype.type(off + 1.0)
        A = A3.reshape(shape)
        # Ensure reshapes did not break memory order
        A = numpy.array(A, order=order)
        return A

    @staticmethod
    def _expected_x_shape(a_shape, b_shape):
        """Return the shape the solution is expected to have.

        ``b`` is treated as a stack of matrices ``(..., n, nrhs)`` when it has
        at least two dimensions and its second-to-last one equals ``n``;
        otherwise it is treated as a stack of vectors ``(..., n)``. In both
        cases the batch dimensions of ``a`` and ``b`` are broadcast together.

        Parameters
        ----------
        a_shape : tuple of int
            Shape ``(..., n, n)`` of the coefficient stack.
        b_shape : tuple of int
            Shape of the right-hand side.

        Returns
        -------
        tuple of int
            Broadcast batch shape followed by ``(n, nrhs)`` or ``(n,)``.
        """
        n = a_shape[-1]
        assert a_shape[-2] == n

        a_batch = a_shape[:-2]
        if len(b_shape) >= 2 and b_shape[-2] == n:
            # b : (..., n, nrhs)
            k = b_shape[-1]
            b_batch = b_shape[:-2]
            exp_batch = numpy.broadcast_shapes(a_batch, b_batch)
            return exp_batch + (n, k)
        else:
            # b : (..., n)
            assert b_shape[-1] == n, "b's last dim must equal n"
            b_batch = b_shape[:-1]
            exp_batch = numpy.broadcast_shapes(a_batch, b_batch)
            return exp_batch + (n,)

    @pytest.mark.parametrize(
        "a_shape, b_shape",
        [
            ((1, 2, 2), (2,)),
            ((2, 4, 4), (4,)),
            ((2, 4, 4), (4, 3)),
            ((2, 4, 4), (2, 4, 4)),
            ((2, 4, 4), (1, 4, 3)),
            ((2, 4, 4), (2, 4, 2)),
            ((2, 3, 4, 4), (1, 3, 4, 2)),
            ((2, 3, 4, 4), (2, 1, 4, 2)),
            ((3, 4, 4), (1, 1, 4, 2)),
        ],
    )
    @pytest.mark.parametrize("order", ["C", "F"])
    @pytest.mark.parametrize(
        "dtype", get_all_dtypes(no_bool=True, no_none=True)
    )
    def test_lu_solve_batched(self, a_shape, b_shape, dtype, order):
        """Check result shape and residual for broadcastable batched inputs."""
        a_np = self._make_nonsingular_nd_np(a_shape, dtype, order)
        a_dp = dpnp.array(a_np, order=order)

        b_np = generate_random_numpy_array(b_shape, dtype, order)
        b_dp = dpnp.array(b_np, order=order)
        # ``overwrite_b=True`` allows the solver to reuse b's memory, so the
        # reference right-hand side must be captured before the call.
        b_ref = b_dp.copy()

        lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
        x = dpnp.linalg.lu_solve(
            (lu, piv), b_dp, overwrite_b=True, check_finite=False
        )

        exp_shape = self._expected_x_shape(a_shape, b_shape)
        assert x.shape == exp_shape

        if b_dp.ndim > 1:
            Ax = a_dp @ x
        else:
            # 1-D b: x is a stack of vectors, so apply a matrix-vector product.
            Ax = (a_dp @ x[..., None])[..., 0]
        b_exp = dpnp.broadcast_to(b_ref, exp_shape)
        assert dpnp.allclose(Ax, b_exp, rtol=1e-5, atol=1e-5)

    @pytest.mark.parametrize("trans", [0, 1, 2])
    @pytest.mark.parametrize("order", ["C", "F"])
    @pytest.mark.parametrize("dtype", get_float_complex_dtypes())
    def test_trans(self, trans, order, dtype):
        """Check the ``trans`` keyword.

        The solution must satisfy ``A x = b`` for ``trans=0``,
        ``A^T x = b`` for ``trans=1`` and ``A^H x = b`` for ``trans=2``.
        """
        a_shape = (3, 4, 4)
        b_shape = (3, 4, 2)

        a_np = self._make_nonsingular_nd_np(a_shape, dtype, order)
        a_dp = dpnp.array(a_np, order=order)
        b_dp = dpnp.array(
            generate_random_numpy_array(b_shape, dtype, order), order=order
        )

        lu, piv = dpnp.linalg.lu_factor(
            a_dp, overwrite_a=False, check_finite=False
        )
        x = dpnp.linalg.lu_solve(
            (lu, piv), b_dp, trans=trans, overwrite_b=False, check_finite=False
        )

        if trans == 0:
            lhs = a_dp @ x
        elif trans == 1:
            lhs = dpnp.swapaxes(a_dp, -1, -2) @ x
        else:  # trans == 2
            lhs = dpnp.conj(dpnp.swapaxes(a_dp, -1, -2)) @ x

        assert dpnp.allclose(lhs, b_dp, rtol=1e-5, atol=1e-5)

    @pytest.mark.parametrize("dtype", get_float_complex_dtypes())
    @pytest.mark.parametrize("order", ["C", "F"])
    def test_overwrite(self, dtype, order):
        """Check that ``overwrite_b=True`` leaves a batched ``b`` untouched.

        The test asserts that the result is a different object than ``b`` and
        that ``b`` still holds its original values after the solve.
        """
        a_np = self._make_nonsingular_nd_np((2, 4, 4), dtype, order)
        a_dp = dpnp.array(a_np, order=order)

        lu, piv = dpnp.linalg.lu_factor(
            a_dp, overwrite_a=False, check_finite=False
        )

        b_dp = dpnp.array(
            generate_random_numpy_array((2, 4, 2), dtype, "F"), order="F"
        )
        b_dp_orig = b_dp.copy()
        x = dpnp.linalg.lu_solve(
            (lu, piv), b_dp, overwrite_b=True, check_finite=False
        )

        assert x is not b_dp
        assert dpnp.allclose(b_dp, b_dp_orig)

        assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-5, atol=1e-5)

    def test_strided(self):
        """Check solving with strided views of both the matrices and ``b``."""
        n, B = 4, 6
        a_np = self._make_nonsingular_nd_np(
            (B, n, n), dpnp.default_float_type(), "F"
        )
        a_dp = dpnp.array(a_np, order="F")

        # Every second matrix of the stack.
        a_stride = a_dp[::2]
        rhs_full = (
            dpnp.arange(B * n * 3, dtype=dpnp.default_float_type()).reshape(
                B, n, 3, order="F"
            )
            + 1.0
        )
        # Matching batch stride, with the right-hand-side columns reversed.
        b_dp = rhs_full[::2, :, ::-1]

        lu, piv = dpnp.linalg.lu_factor(a_stride, check_finite=False)
        x = dpnp.linalg.lu_solve(
            (lu, piv), b_dp, overwrite_b=False, check_finite=False
        )

        assert dpnp.allclose(a_stride @ x, b_dp, rtol=1e-5, atol=1e-5)

    @pytest.mark.parametrize(
        "dtype_a", get_all_dtypes(no_bool=True, no_none=True)
    )
    @pytest.mark.parametrize(
        "dtype_b", get_all_dtypes(no_bool=True, no_none=True)
    )
    @pytest.mark.parametrize("b_shape", [(4, 2), (1, 4, 2), (2, 4, 2)])
    def test_diff_type(self, dtype_a, dtype_b, b_shape):
        """Check solving when ``a`` and ``b`` have different dtypes."""
        B, n, k = 2, 4, 2
        a_np = self._make_nonsingular_nd_np((B, n, n), dtype_a, "F")
        a_dp = dpnp.array(a_np, order="F")

        b_np = generate_random_numpy_array(b_shape, dtype_b, "F")
        b_dp = dpnp.array(b_np, order="F")

        lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
        x = dpnp.linalg.lu_solve((lu, piv), b_dp, check_finite=False)

        exp_shape = (B, n, k)
        assert x.shape == exp_shape

        b_exp = dpnp.broadcast_to(b_dp, exp_shape)
        # Compare in the dtype of the solution.
        assert dpnp.allclose(
            a_dp @ x, b_exp.astype(x.dtype, copy=False), rtol=1e-5, atol=1e-5
        )

    @pytest.mark.parametrize(
        "a_shape, b_shape",
        [
            ((0, 3, 3), (0, 3)),
            ((2, 0, 0), (2, 0)),
            ((0, 0, 0), (0, 0)),
        ],
    )
    def test_empty_inputs(self, a_shape, b_shape):
        """Check that empty inputs yield a result with the shape of ``b``."""
        a = dpnp.empty(a_shape, dtype=dpnp.default_float_type(), order="F")
        b = dpnp.empty(b_shape, dtype=dpnp.default_float_type(), order="F")

        lu, piv = dpnp.linalg.lu_factor(a, check_finite=False)
        x = dpnp.linalg.lu_solve((lu, piv), b, check_finite=False)

        assert x.shape == b_shape

    def test_check_finite_raises(self):
        """Check that a NaN in ``b`` raises ``ValueError`` with check_finite."""
        B, n = 2, 3
        a_np = self._make_nonsingular_nd_np(
            (B, n, n), dpnp.default_float_type(), "F"
        )
        a_dp = dpnp.array(a_np, order="F")
        lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)

        b_bad = dpnp.ones((B, n), dtype=dpnp.default_float_type(), order="F")
        b_bad[1, 0] = dpnp.nan
        assert_raises(
            ValueError,
            dpnp.linalg.lu_solve,
            (lu, piv),
            b_bad,
            check_finite=True,
        )

    @pytest.mark.parametrize(
        "a_shape, b_shape",
        [
            ((2, 4, 4), (2,)),
            ((2, 4, 4), (2, 4)),
            ((2, 4, 4), (4, 4, 2)),
            ((2, 4, 4), (2, 3, 4, 2)),
            ((2, 3, 4, 4), (3, 4)),
            ((2, 3, 4, 4), (2, 4)),
            ((2, 3, 4, 4), (2, 3, 5, 2)),
        ],
    )
    def test_invalid_shapes(self, a_shape, b_shape):
        """Check that the listed incompatible shapes raise ``ValueError``."""
        dtype = dpnp.default_float_type()
        a = dpnp.array(
            self._make_nonsingular_nd_np(a_shape, dtype, "F"), order="F"
        )
        b = dpnp.array(
            generate_random_numpy_array(b_shape, dtype, "F"), order="F"
        )

        lu, piv = dpnp.linalg.lu_factor(a, check_finite=False)
        with pytest.raises(ValueError):
            dpnp.linalg.lu_solve((lu, piv), b, check_finite=False)
2591+
2592+
23612593class TestMatrixPower :
23622594 @pytest .mark .parametrize ("dtype" , get_all_dtypes ())
23632595 @pytest .mark .parametrize (
0 commit comments