@@ -2018,6 +2018,119 @@ def test_batched_not_supported(self):
2018
2018
assert_raises (NotImplementedError , dpnp .linalg .lu_factor , a_dp )
2019
2019
2020
2020
2021
+ class TestLuFactorBatched :
2022
+ @staticmethod
2023
+ def _apply_pivots_rows (A_dp , piv_dp ):
2024
+ m = A_dp .shape [0 ]
2025
+ rows = dpnp .arange (m )
2026
+ for i in range (int (piv_dp .shape [0 ])):
2027
+ r = int (piv_dp [i ].item ())
2028
+ if i != r :
2029
+ tmp = rows [i ].copy ()
2030
+ rows [i ] = rows [r ]
2031
+ rows [r ] = tmp
2032
+ return A_dp [rows ]
2033
+
2034
+ @staticmethod
2035
+ def _split_lu (lu , m , n ):
2036
+ L = dpnp .tril (lu , k = - 1 )
2037
+ dpnp .fill_diagonal (L , 1 )
2038
+ L = L [:, : min (m , n )]
2039
+ U = dpnp .triu (lu )[: min (m , n ), :]
2040
+ return L , U
2041
+
2042
+ @pytest .mark .parametrize (
2043
+ "shape" ,
2044
+ [(2 , 2 , 2 ), (3 , 4 , 4 ), (2 , 3 , 5 , 2 ), (4 , 1 , 3 )],
2045
+ ids = ["(2,2,2)" , "(3,4,4)" , "(2,3,5,2)" , "(4,1,3)" ],
2046
+ )
2047
+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
2048
+ @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_bool = True ))
2049
+ def test_lu_factor_batched (self , shape , order , dtype ):
2050
+ a_np = generate_random_numpy_array (shape , dtype , order )
2051
+ a_dp = dpnp .array (a_np , order = order )
2052
+
2053
+ lu , piv = dpnp .linalg .lu_factor (
2054
+ a_dp , check_finite = False , overwrite_a = False
2055
+ )
2056
+
2057
+ assert lu .shape == a_dp .shape
2058
+ m , n = shape [- 2 ], shape [- 1 ]
2059
+ assert piv .shape == (* shape [:- 2 ], min (m , n ))
2060
+ assert piv .dtype == dpnp .int64
2061
+
2062
+ a_3d = a_dp .reshape ((- 1 , m , n ))
2063
+ lu_3d = lu .reshape ((- 1 , m , n ))
2064
+ piv_2d = piv .reshape ((- 1 , min (m , n )))
2065
+ for i in range (a_3d .shape [0 ]):
2066
+ L , U = self ._split_lu (lu_3d [i ], m , n )
2067
+ A_cast = a_3d [i ].astype (L .dtype , copy = False )
2068
+ PA = self ._apply_pivots_rows (A_cast , piv_2d [i ])
2069
+ assert_allclose (L @ U , PA , rtol = 1e-6 , atol = 1e-6 )
2070
+
2071
+ @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
2072
+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
2073
+ def test_overwrite (self , dtype , order ):
2074
+ a_dp = dpnp .arange (2 * 2 * 3 , dtype = dtype ).reshape (3 , 2 , 2 , order = order )
2075
+ a_dp_orig = a_dp .copy ()
2076
+ lu , piv = dpnp .linalg .lu_factor (
2077
+ a_dp , overwrite_a = True , check_finite = False
2078
+ )
2079
+
2080
+ assert lu is not a_dp
2081
+ assert_allclose (a_dp , a_dp_orig )
2082
+
2083
+ m = n = 2
2084
+ lu_3d = lu .reshape ((- 1 , m , n ))
2085
+ a_3d = a_dp .reshape ((- 1 , m , n ))
2086
+ piv_2d = piv .reshape ((- 1 , min (m , n )))
2087
+ for i in range (a_3d .shape [0 ]):
2088
+ L , U = self ._split_lu (lu_3d [i ], m , n )
2089
+ A_cast = a_3d [i ].astype (L .dtype , copy = False )
2090
+ PA = self ._apply_pivots_rows (A_cast , piv_2d [i ])
2091
+ assert_allclose (L @ U , PA , rtol = 1e-6 , atol = 1e-6 )
2092
+
2093
+ @pytest .mark .parametrize (
2094
+ "shape" , [(0 , 2 , 2 ), (2 , 0 , 2 ), (2 , 2 , 0 ), (0 , 0 , 0 )]
2095
+ )
2096
+ def test_empty_inputs (self , shape ):
2097
+ a = dpnp .empty (shape , dtype = dpnp .default_float_type (), order = "F" )
2098
+
2099
+ lu , piv = dpnp .linalg .lu_factor (a , check_finite = False )
2100
+ assert lu .shape == shape
2101
+ m , n = shape [- 2 :]
2102
+ assert piv .shape == (* shape [:- 2 ], min (m , n ))
2103
+
2104
+ def test_strided (self ):
2105
+ a = (
2106
+ dpnp .arange (5 * 3 * 3 , dtype = dpnp .default_float_type ()).reshape (
2107
+ 5 , 3 , 3 , order = "F"
2108
+ )
2109
+ + 0.1
2110
+ )
2111
+ a_stride = a [::2 ]
2112
+ lu , piv = dpnp .linalg .lu_factor (a_stride , check_finite = False )
2113
+ for i in range (a_stride .shape [0 ]):
2114
+ L , U = self ._split_lu (lu [i ], 3 , 3 )
2115
+ PA = self ._apply_pivots_rows (
2116
+ a_stride [i ].astype (L .dtype , copy = False ), piv [i ]
2117
+ )
2118
+ assert_allclose (L @ U , PA , rtol = 1e-6 , atol = 1e-6 )
2119
+
2120
+ def test_singular_matrix (self ):
2121
+ a = dpnp .zeros ((3 , 2 , 2 ), dtype = dpnp .float64 )
2122
+ a [0 ] = dpnp .array ([[1.0 , 2.0 ], [2.0 , 4.0 ]])
2123
+ a [1 ] = dpnp .eye (2 )
2124
+ a [2 ] = dpnp .array ([[1.0 , 1.0 ], [1.0 , 1.0 ]])
2125
+ with pytest .warns (RuntimeWarning , match = "Singular matrix" ):
2126
+ dpnp .linalg .lu_factor (a , check_finite = False )
2127
+
2128
+ def test_check_finite_raises (self ):
2129
+ a = dpnp .ones ((2 , 3 , 3 ), dtype = dpnp .float64 , order = "F" )
2130
+ a [1 , 0 , 0 ] = dpnp .nan
2131
+ assert_raises (ValueError , dpnp .linalg .lu_factor , a , check_finite = True )
2132
+
2133
+
2021
2134
class TestMatrixPower :
2022
2135
@pytest .mark .parametrize ("dtype" , get_all_dtypes ())
2023
2136
@pytest .mark .parametrize (
0 commit comments