@@ -2022,6 +2022,119 @@ def test_batched_not_supported(self):
20222022 assert_raises (NotImplementedError , dpnp .linalg .lu_factor , a_dp )
20232023
20242024
2025+ class TestLuFactorBatched :
2026+ @staticmethod
2027+ def _apply_pivots_rows (A_dp , piv_dp ):
2028+ m = A_dp .shape [0 ]
2029+ rows = dpnp .arange (m )
2030+ for i in range (int (piv_dp .shape [0 ])):
2031+ r = int (piv_dp [i ].item ())
2032+ if i != r :
2033+ tmp = rows [i ].copy ()
2034+ rows [i ] = rows [r ]
2035+ rows [r ] = tmp
2036+ return A_dp [rows ]
2037+
2038+ @staticmethod
2039+ def _split_lu (lu , m , n ):
2040+ L = dpnp .tril (lu , k = - 1 )
2041+ dpnp .fill_diagonal (L , 1 )
2042+ L = L [:, : min (m , n )]
2043+ U = dpnp .triu (lu )[: min (m , n ), :]
2044+ return L , U
2045+
2046+ @pytest .mark .parametrize (
2047+ "shape" ,
2048+ [(2 , 2 , 2 ), (3 , 4 , 4 ), (2 , 3 , 5 , 2 ), (4 , 1 , 3 )],
2049+ ids = ["(2,2,2)" , "(3,4,4)" , "(2,3,5,2)" , "(4,1,3)" ],
2050+ )
2051+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
2052+ @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_bool = True ))
2053+ def test_lu_factor_batched (self , shape , order , dtype ):
2054+ a_np = generate_random_numpy_array (shape , dtype , order )
2055+ a_dp = dpnp .array (a_np , order = order )
2056+
2057+ lu , piv = dpnp .linalg .lu_factor (
2058+ a_dp , check_finite = False , overwrite_a = False
2059+ )
2060+
2061+ assert lu .shape == a_dp .shape
2062+ m , n = shape [- 2 ], shape [- 1 ]
2063+ assert piv .shape == (* shape [:- 2 ], min (m , n ))
2064+ assert piv .dtype == dpnp .int64
2065+
2066+ a_3d = a_dp .reshape ((- 1 , m , n ))
2067+ lu_3d = lu .reshape ((- 1 , m , n ))
2068+ piv_2d = piv .reshape ((- 1 , min (m , n )))
2069+ for i in range (a_3d .shape [0 ]):
2070+ L , U = self ._split_lu (lu_3d [i ], m , n )
2071+ A_cast = a_3d [i ].astype (L .dtype , copy = False )
2072+ PA = self ._apply_pivots_rows (A_cast , piv_2d [i ])
2073+ assert_allclose (L @ U , PA , rtol = 1e-6 , atol = 1e-6 )
2074+
2075+ @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
2076+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
2077+ def test_overwrite (self , dtype , order ):
2078+ a_dp = dpnp .arange (2 * 2 * 3 , dtype = dtype ).reshape (3 , 2 , 2 , order = order )
2079+ a_dp_orig = a_dp .copy ()
2080+ lu , piv = dpnp .linalg .lu_factor (
2081+ a_dp , overwrite_a = True , check_finite = False
2082+ )
2083+
2084+ assert lu is not a_dp
2085+ assert_allclose (a_dp , a_dp_orig )
2086+
2087+ m = n = 2
2088+ lu_3d = lu .reshape ((- 1 , m , n ))
2089+ a_3d = a_dp .reshape ((- 1 , m , n ))
2090+ piv_2d = piv .reshape ((- 1 , min (m , n )))
2091+ for i in range (a_3d .shape [0 ]):
2092+ L , U = self ._split_lu (lu_3d [i ], m , n )
2093+ A_cast = a_3d [i ].astype (L .dtype , copy = False )
2094+ PA = self ._apply_pivots_rows (A_cast , piv_2d [i ])
2095+ assert_allclose (L @ U , PA , rtol = 1e-6 , atol = 1e-6 )
2096+
2097+ @pytest .mark .parametrize (
2098+ "shape" , [(0 , 2 , 2 ), (2 , 0 , 2 ), (2 , 2 , 0 ), (0 , 0 , 0 )]
2099+ )
2100+ def test_empty_inputs (self , shape ):
2101+ a = dpnp .empty (shape , dtype = dpnp .default_float_type (), order = "F" )
2102+
2103+ lu , piv = dpnp .linalg .lu_factor (a , check_finite = False )
2104+ assert lu .shape == shape
2105+ m , n = shape [- 2 :]
2106+ assert piv .shape == (* shape [:- 2 ], min (m , n ))
2107+
2108+ def test_strided (self ):
2109+ a = (
2110+ dpnp .arange (5 * 3 * 3 , dtype = dpnp .default_float_type ()).reshape (
2111+ 5 , 3 , 3 , order = "F"
2112+ )
2113+ + 0.1
2114+ )
2115+ a_stride = a [::2 ]
2116+ lu , piv = dpnp .linalg .lu_factor (a_stride , check_finite = False )
2117+ for i in range (a_stride .shape [0 ]):
2118+ L , U = self ._split_lu (lu [i ], 3 , 3 )
2119+ PA = self ._apply_pivots_rows (
2120+ a_stride [i ].astype (L .dtype , copy = False ), piv [i ]
2121+ )
2122+ assert_allclose (L @ U , PA , rtol = 1e-6 , atol = 1e-6 )
2123+
2124+ def test_singular_matrix (self ):
2125+ a = dpnp .zeros ((3 , 2 , 2 ), dtype = dpnp .float64 )
2126+ a [0 ] = dpnp .array ([[1.0 , 2.0 ], [2.0 , 4.0 ]])
2127+ a [1 ] = dpnp .eye (2 )
2128+ a [2 ] = dpnp .array ([[1.0 , 1.0 ], [1.0 , 1.0 ]])
2129+ with pytest .warns (RuntimeWarning , match = "Singular matrix" ):
2130+ dpnp .linalg .lu_factor (a , check_finite = False )
2131+
2132+ def test_check_finite_raises (self ):
2133+ a = dpnp .ones ((2 , 3 , 3 ), dtype = dpnp .float64 , order = "F" )
2134+ a [1 , 0 , 0 ] = dpnp .nan
2135+ assert_raises (ValueError , dpnp .linalg .lu_factor , a , check_finite = True )
2136+
2137+
20252138class TestMatrixPower :
20262139 @pytest .mark .parametrize ("dtype" , get_all_dtypes ())
20272140 @pytest .mark .parametrize (
0 commit comments