@@ -2018,6 +2018,119 @@ def test_batched_not_supported(self):
20182018 assert_raises (NotImplementedError , dpnp .linalg .lu_factor , a_dp )
20192019
20202020
2021+ class TestLuFactorBatched :
2022+ @staticmethod
2023+ def _apply_pivots_rows (A_dp , piv_dp ):
2024+ m = A_dp .shape [0 ]
2025+ rows = dpnp .arange (m )
2026+ for i in range (int (piv_dp .shape [0 ])):
2027+ r = int (piv_dp [i ].item ())
2028+ if i != r :
2029+ tmp = rows [i ].copy ()
2030+ rows [i ] = rows [r ]
2031+ rows [r ] = tmp
2032+ return A_dp [rows ]
2033+
2034+ @staticmethod
2035+ def _split_lu (lu , m , n ):
2036+ L = dpnp .tril (lu , k = - 1 )
2037+ dpnp .fill_diagonal (L , 1 )
2038+ L = L [:, : min (m , n )]
2039+ U = dpnp .triu (lu )[: min (m , n ), :]
2040+ return L , U
2041+
2042+ @pytest .mark .parametrize (
2043+ "shape" ,
2044+ [(2 , 2 , 2 ), (3 , 4 , 4 ), (2 , 3 , 5 , 2 ), (4 , 1 , 3 )],
2045+ ids = ["(2,2,2)" , "(3,4,4)" , "(2,3,5,2)" , "(4,1,3)" ],
2046+ )
2047+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
2048+ @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_bool = True ))
2049+ def test_lu_factor_batched (self , shape , order , dtype ):
2050+ a_np = generate_random_numpy_array (shape , dtype , order )
2051+ a_dp = dpnp .array (a_np , order = order )
2052+
2053+ lu , piv = dpnp .linalg .lu_factor (
2054+ a_dp , check_finite = False , overwrite_a = False
2055+ )
2056+
2057+ assert lu .shape == a_dp .shape
2058+ m , n = shape [- 2 ], shape [- 1 ]
2059+ assert piv .shape == (* shape [:- 2 ], min (m , n ))
2060+ assert piv .dtype == dpnp .int64
2061+
2062+ a_3d = a_dp .reshape ((- 1 , m , n ))
2063+ lu_3d = lu .reshape ((- 1 , m , n ))
2064+ piv_2d = piv .reshape ((- 1 , min (m , n )))
2065+ for i in range (a_3d .shape [0 ]):
2066+ L , U = self ._split_lu (lu_3d [i ], m , n )
2067+ A_cast = a_3d [i ].astype (L .dtype , copy = False )
2068+ PA = self ._apply_pivots_rows (A_cast , piv_2d [i ])
2069+ assert_allclose (L @ U , PA , rtol = 1e-6 , atol = 1e-6 )
2070+
2071+ @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
2072+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
2073+ def test_overwrite (self , dtype , order ):
2074+ a_dp = dpnp .arange (2 * 2 * 3 , dtype = dtype ).reshape (3 , 2 , 2 , order = order )
2075+ a_dp_orig = a_dp .copy ()
2076+ lu , piv = dpnp .linalg .lu_factor (
2077+ a_dp , overwrite_a = True , check_finite = False
2078+ )
2079+
2080+ assert lu is not a_dp
2081+ assert_allclose (a_dp , a_dp_orig )
2082+
2083+ m = n = 2
2084+ lu_3d = lu .reshape ((- 1 , m , n ))
2085+ a_3d = a_dp .reshape ((- 1 , m , n ))
2086+ piv_2d = piv .reshape ((- 1 , min (m , n )))
2087+ for i in range (a_3d .shape [0 ]):
2088+ L , U = self ._split_lu (lu_3d [i ], m , n )
2089+ A_cast = a_3d [i ].astype (L .dtype , copy = False )
2090+ PA = self ._apply_pivots_rows (A_cast , piv_2d [i ])
2091+ assert_allclose (L @ U , PA , rtol = 1e-6 , atol = 1e-6 )
2092+
2093+ @pytest .mark .parametrize (
2094+ "shape" , [(0 , 2 , 2 ), (2 , 0 , 2 ), (2 , 2 , 0 ), (0 , 0 , 0 )]
2095+ )
2096+ def test_empty_inputs (self , shape ):
2097+ a = dpnp .empty (shape , dtype = dpnp .default_float_type (), order = "F" )
2098+
2099+ lu , piv = dpnp .linalg .lu_factor (a , check_finite = False )
2100+ assert lu .shape == shape
2101+ m , n = shape [- 2 :]
2102+ assert piv .shape == (* shape [:- 2 ], min (m , n ))
2103+
2104+ def test_strided (self ):
2105+ a = (
2106+ dpnp .arange (5 * 3 * 3 , dtype = dpnp .default_float_type ()).reshape (
2107+ 5 , 3 , 3 , order = "F"
2108+ )
2109+ + 0.1
2110+ )
2111+ a_stride = a [::2 ]
2112+ lu , piv = dpnp .linalg .lu_factor (a_stride , check_finite = False )
2113+ for i in range (a_stride .shape [0 ]):
2114+ L , U = self ._split_lu (lu [i ], 3 , 3 )
2115+ PA = self ._apply_pivots_rows (
2116+ a_stride [i ].astype (L .dtype , copy = False ), piv [i ]
2117+ )
2118+ assert_allclose (L @ U , PA , rtol = 1e-6 , atol = 1e-6 )
2119+
2120+ def test_singular_matrix (self ):
2121+ a = dpnp .zeros ((3 , 2 , 2 ), dtype = dpnp .float64 )
2122+ a [0 ] = dpnp .array ([[1.0 , 2.0 ], [2.0 , 4.0 ]])
2123+ a [1 ] = dpnp .eye (2 )
2124+ a [2 ] = dpnp .array ([[1.0 , 1.0 ], [1.0 , 1.0 ]])
2125+ with pytest .warns (RuntimeWarning , match = "Singular matrix" ):
2126+ dpnp .linalg .lu_factor (a , check_finite = False )
2127+
2128+ def test_check_finite_raises (self ):
2129+ a = dpnp .ones ((2 , 3 , 3 ), dtype = dpnp .float64 , order = "F" )
2130+ a [1 , 0 , 0 ] = dpnp .nan
2131+ assert_raises (ValueError , dpnp .linalg .lu_factor , a , check_finite = True )
2132+
2133+
20212134class TestMatrixPower :
20222135 @pytest .mark .parametrize ("dtype" , get_all_dtypes ())
20232136 @pytest .mark .parametrize (
0 commit comments