@@ -2300,9 +2300,6 @@ def test_strided_rhs(self):
2300
2300
(4 ,),
2301
2301
(4 , 1 ),
2302
2302
(4 , 3 ),
2303
- # (1, 4, 3),
2304
- # (2, 4, 3),
2305
- # (1, 1, 4, 3)
2306
2303
],
2307
2304
)
2308
2305
def test_broadcast_rhs (self , b_shape ):
@@ -2358,6 +2355,241 @@ def test_check_finite_raises(self, bad):
2358
2355
)
2359
2356
2360
2357
2358
+ class TestLuSolveBatched :
2359
+ @staticmethod
2360
+ def _make_nonsingular_nd_np (shape , dtype , order ):
2361
+ A = generate_random_numpy_array (shape , dtype , order )
2362
+ n = shape [- 1 ]
2363
+ A3 = A .reshape ((- 1 , n , n ))
2364
+ for B in A3 :
2365
+ off = numpy .sum (numpy .abs (B ), axis = 1 ) - numpy .abs (numpy .diag (B ))
2366
+ B [numpy .arange (n ), numpy .arange (n )] = A .dtype .type (off + 1.0 )
2367
+ A = A3 .reshape (shape )
2368
+ # Ensure reshapes did not break memory order
2369
+ A = numpy .array (A , order = order )
2370
+ return A
2371
+
2372
+ @staticmethod
2373
+ def _expected_x_shape (a_shape , b_shape ):
2374
+ n = a_shape [- 1 ]
2375
+ assert a_shape [- 2 ] == n
2376
+
2377
+ a_batch = a_shape [:- 2 ]
2378
+ if len (b_shape ) >= 2 and b_shape [- 2 ] == n :
2379
+ # b : (..., n, nrhs)
2380
+ k = b_shape [- 1 ]
2381
+ b_batch = b_shape [:- 2 ]
2382
+ exp_batch = numpy .broadcast_shapes (a_batch , b_batch )
2383
+ return exp_batch + (n , k )
2384
+ else :
2385
+ # b : (..., n)
2386
+ assert b_shape [- 1 ] == n , "b's last dim must equal n"
2387
+ b_batch = b_shape [:- 1 ]
2388
+ exp_batch = numpy .broadcast_shapes (a_batch , b_batch )
2389
+ return exp_batch + (n ,)
2390
+
2391
+ @pytest .mark .parametrize (
2392
+ "a_shape, b_shape" ,
2393
+ [
2394
+ ((1 , 2 , 2 ), (2 ,)),
2395
+ ((2 , 4 , 4 ), (4 ,)),
2396
+ ((2 , 4 , 4 ), (4 , 3 )),
2397
+ ((2 , 4 , 4 ), (2 , 4 , 4 )),
2398
+ ((2 , 4 , 4 ), (1 , 4 , 3 )),
2399
+ ((2 , 4 , 4 ), (2 , 4 , 2 )),
2400
+ ((2 , 3 , 4 , 4 ), (1 , 3 , 4 , 2 )),
2401
+ ((2 , 3 , 4 , 4 ), (2 , 1 , 4 , 2 )),
2402
+ ((3 , 4 , 4 ), (1 , 1 , 4 , 2 )),
2403
+ ],
2404
+ )
2405
+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
2406
+ @pytest .mark .parametrize (
2407
+ "dtype" , get_all_dtypes (no_bool = True , no_none = True )
2408
+ )
2409
+ def test_lu_solve_batched (self , a_shape , b_shape , dtype , order ):
2410
+ a_np = self ._make_nonsingular_nd_np (a_shape , dtype , order )
2411
+ a_dp = dpnp .array (a_np , order = order )
2412
+
2413
+ b_np = generate_random_numpy_array (b_shape , dtype , order )
2414
+ b_dp = dpnp .array (b_np , order = order )
2415
+
2416
+ lu , piv = dpnp .linalg .lu_factor (a_dp , check_finite = False )
2417
+ x = dpnp .linalg .lu_solve (
2418
+ (lu , piv ), b_dp , overwrite_b = True , check_finite = False
2419
+ )
2420
+
2421
+ exp_shape = self ._expected_x_shape (a_shape , b_shape )
2422
+ assert x .shape == exp_shape
2423
+
2424
+ if b_dp .ndim > 1 :
2425
+ Ax = a_dp @ x
2426
+ else :
2427
+ Ax = (a_dp @ x [..., None ])[..., 0 ]
2428
+ b_exp = dpnp .broadcast_to (b_dp , exp_shape )
2429
+ assert dpnp .allclose (Ax , b_exp , rtol = 1e-5 , atol = 1e-5 )
2430
+
2431
+ @pytest .mark .parametrize ("trans" , [0 , 1 , 2 ])
2432
+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
2433
+ @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
2434
+ def test_trans (self , trans , order , dtype ):
2435
+ a_shape = (3 , 4 , 4 )
2436
+ b_shape = (3 , 4 , 2 )
2437
+
2438
+ a_np = self ._make_nonsingular_nd_np (a_shape , dtype , order )
2439
+ a_dp = dpnp .array (a_np , order = order )
2440
+ b_dp = dpnp .array (
2441
+ generate_random_numpy_array (b_shape , dtype , order ), order = order
2442
+ )
2443
+
2444
+ lu , piv = dpnp .linalg .lu_factor (
2445
+ a_dp , overwrite_a = False , check_finite = False
2446
+ )
2447
+ x = dpnp .linalg .lu_solve (
2448
+ (lu , piv ), b_dp , trans = trans , overwrite_b = False , check_finite = False
2449
+ )
2450
+
2451
+ if trans == 0 :
2452
+ lhs = a_dp @ x
2453
+ elif trans == 1 :
2454
+ lhs = dpnp .swapaxes (a_dp , - 1 , - 2 ) @ x
2455
+ else : # trans == 2
2456
+ lhs = dpnp .conj (dpnp .swapaxes (a_dp , - 1 , - 2 )) @ x
2457
+
2458
+ assert dpnp .allclose (lhs , b_dp , rtol = 1e-5 , atol = 1e-5 )
2459
+
2460
+ @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
2461
+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
2462
+ def test_overwrite (self , dtype , order ):
2463
+ a_np = self ._make_nonsingular_nd_np ((2 , 4 , 4 ), dtype , order )
2464
+ a_dp = dpnp .array (a_np , order = order )
2465
+
2466
+ lu , piv = dpnp .linalg .lu_factor (
2467
+ a_dp , overwrite_a = False , check_finite = False
2468
+ )
2469
+
2470
+ b_dp = dpnp .array (
2471
+ generate_random_numpy_array ((2 , 4 , 2 ), dtype , "F" ), order = "F"
2472
+ )
2473
+ b_dp_orig = b_dp .copy ()
2474
+ x = dpnp .linalg .lu_solve (
2475
+ (lu , piv ), b_dp , overwrite_b = True , check_finite = False
2476
+ )
2477
+
2478
+ assert x is not b_dp
2479
+ assert dpnp .allclose (b_dp , b_dp_orig )
2480
+
2481
+ assert dpnp .allclose (a_dp @ x , b_dp , rtol = 1e-5 , atol = 1e-5 )
2482
+
2483
+ def test_strided (self ):
2484
+ n , B = 4 , 6
2485
+ a_np = self ._make_nonsingular_nd_np (
2486
+ (B , n , n ), dpnp .default_float_type (), "F"
2487
+ )
2488
+ a_dp = dpnp .array (a_np , order = "F" )
2489
+
2490
+ a_stride = a_dp [::2 ]
2491
+ rhs_full = (
2492
+ dpnp .arange (B * n * 3 , dtype = dpnp .default_float_type ()).reshape (
2493
+ B , n , 3 , order = "F"
2494
+ )
2495
+ + 1.0
2496
+ )
2497
+ b_dp = rhs_full [::2 , :, ::- 1 ]
2498
+
2499
+ lu , piv = dpnp .linalg .lu_factor (a_stride , check_finite = False )
2500
+ x = dpnp .linalg .lu_solve (
2501
+ (lu , piv ), b_dp , overwrite_b = False , check_finite = False
2502
+ )
2503
+
2504
+ assert dpnp .allclose (a_stride @ x , b_dp , rtol = 1e-5 , atol = 1e-5 )
2505
+
2506
+ @pytest .mark .parametrize (
2507
+ "dtype_a" , get_all_dtypes (no_bool = True , no_none = True )
2508
+ )
2509
+ @pytest .mark .parametrize (
2510
+ "dtype_b" , get_all_dtypes (no_bool = True , no_none = True )
2511
+ )
2512
+ @pytest .mark .parametrize ("b_shape" , [(4 , 2 ), (1 , 4 , 2 ), (2 , 4 , 2 )])
2513
+ def test_diff_type (self , dtype_a , dtype_b , b_shape ):
2514
+ B , n , k = 2 , 4 , 2
2515
+ a_np = self ._make_nonsingular_nd_np ((B , n , n ), dtype_a , "F" )
2516
+ a_dp = dpnp .array (a_np , order = "F" )
2517
+
2518
+ b_np = generate_random_numpy_array (b_shape , dtype_b , "F" )
2519
+ b_dp = dpnp .array (b_np , order = "F" )
2520
+
2521
+ lu , piv = dpnp .linalg .lu_factor (a_dp , check_finite = False )
2522
+ x = dpnp .linalg .lu_solve ((lu , piv ), b_dp , check_finite = False )
2523
+
2524
+ exp_shape = (B , n , k )
2525
+ assert x .shape == exp_shape
2526
+
2527
+ b_exp = dpnp .broadcast_to (b_dp , exp_shape )
2528
+ assert dpnp .allclose (
2529
+ a_dp @ x , b_exp .astype (x .dtype , copy = False ), rtol = 1e-5 , atol = 1e-5
2530
+ )
2531
+
2532
+ @pytest .mark .parametrize (
2533
+ "a_shape, b_shape" ,
2534
+ [
2535
+ ((0 , 3 , 3 ), (0 , 3 )),
2536
+ ((2 , 0 , 0 ), (2 , 0 )),
2537
+ ((0 , 0 , 0 ), (0 , 0 )),
2538
+ ],
2539
+ )
2540
+ def test_empty_inputs (self , a_shape , b_shape ):
2541
+ a = dpnp .empty (a_shape , dtype = dpnp .default_float_type (), order = "F" )
2542
+ b = dpnp .empty (b_shape , dtype = dpnp .default_float_type (), order = "F" )
2543
+
2544
+ lu , piv = dpnp .linalg .lu_factor (a , check_finite = False )
2545
+ x = dpnp .linalg .lu_solve ((lu , piv ), b , check_finite = False )
2546
+
2547
+ assert x .shape == b_shape
2548
+
2549
+ def test_check_finite_raises (self ):
2550
+ B , n = 2 , 3
2551
+ a_np = self ._make_nonsingular_nd_np (
2552
+ (B , n , n ), dpnp .default_float_type (), "F"
2553
+ )
2554
+ a_dp = dpnp .array (a_np , order = "F" )
2555
+ lu , piv = dpnp .linalg .lu_factor (a_dp , check_finite = False )
2556
+
2557
+ b_bad = dpnp .ones ((B , n ), dtype = dpnp .default_float_type (), order = "F" )
2558
+ b_bad [1 , 0 ] = dpnp .nan
2559
+ assert_raises (
2560
+ ValueError ,
2561
+ dpnp .linalg .lu_solve ,
2562
+ (lu , piv ),
2563
+ b_bad ,
2564
+ check_finite = True ,
2565
+ )
2566
+
2567
+ @pytest .mark .parametrize (
2568
+ "a_shape, b_shape" ,
2569
+ [
2570
+ ((2 , 4 , 4 ), (2 ,)),
2571
+ ((2 , 4 , 4 ), (2 , 4 )),
2572
+ ((2 , 4 , 4 ), (4 , 4 , 2 )),
2573
+ ((2 , 4 , 4 ), (2 , 3 , 4 , 2 )),
2574
+ ((2 , 3 , 4 , 4 ), (3 , 4 )),
2575
+ ((2 , 3 , 4 , 4 ), (2 , 4 )),
2576
+ ((2 , 3 , 4 , 4 ), (2 , 3 , 5 , 2 )),
2577
+ ],
2578
+ )
2579
+ def test_invalid_shapes (self , a_shape , b_shape ):
2580
+ dtype = dpnp .default_float_type ()
2581
+ a = dpnp .array (
2582
+ self ._make_nonsingular_nd_np (a_shape , dtype , "F" ), order = "F"
2583
+ )
2584
+ b = dpnp .array (
2585
+ generate_random_numpy_array (b_shape , dtype , "F" ), order = "F"
2586
+ )
2587
+
2588
+ lu , piv = dpnp .linalg .lu_factor (a , check_finite = False )
2589
+ with pytest .raises (ValueError ):
2590
+ dpnp .linalg .lu_solve ((lu , piv ), b , check_finite = False )
2591
+
2592
+
2361
2593
class TestMatrixPower :
2362
2594
@pytest .mark .parametrize ("dtype" , get_all_dtypes ())
2363
2595
@pytest .mark .parametrize (
0 commit comments