@@ -2022,6 +2022,119 @@ def test_batched_not_supported(self):
2022
2022
assert_raises (NotImplementedError , dpnp .linalg .lu_factor , a_dp )
2023
2023
2024
2024
2025
+ class TestLuFactorBatched :
2026
+ @staticmethod
2027
+ def _apply_pivots_rows (A_dp , piv_dp ):
2028
+ m = A_dp .shape [0 ]
2029
+ rows = dpnp .arange (m )
2030
+ for i in range (int (piv_dp .shape [0 ])):
2031
+ r = int (piv_dp [i ].item ())
2032
+ if i != r :
2033
+ tmp = rows [i ].copy ()
2034
+ rows [i ] = rows [r ]
2035
+ rows [r ] = tmp
2036
+ return A_dp [rows ]
2037
+
2038
+ @staticmethod
2039
+ def _split_lu (lu , m , n ):
2040
+ L = dpnp .tril (lu , k = - 1 )
2041
+ dpnp .fill_diagonal (L , 1 )
2042
+ L = L [:, : min (m , n )]
2043
+ U = dpnp .triu (lu )[: min (m , n ), :]
2044
+ return L , U
2045
+
2046
+ @pytest .mark .parametrize (
2047
+ "shape" ,
2048
+ [(2 , 2 , 2 ), (3 , 4 , 4 ), (2 , 3 , 5 , 2 ), (4 , 1 , 3 )],
2049
+ ids = ["(2,2,2)" , "(3,4,4)" , "(2,3,5,2)" , "(4,1,3)" ],
2050
+ )
2051
+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
2052
+ @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_bool = True ))
2053
+ def test_lu_factor_batched (self , shape , order , dtype ):
2054
+ a_np = generate_random_numpy_array (shape , dtype , order )
2055
+ a_dp = dpnp .array (a_np , order = order )
2056
+
2057
+ lu , piv = dpnp .linalg .lu_factor (
2058
+ a_dp , check_finite = False , overwrite_a = False
2059
+ )
2060
+
2061
+ assert lu .shape == a_dp .shape
2062
+ m , n = shape [- 2 ], shape [- 1 ]
2063
+ assert piv .shape == (* shape [:- 2 ], min (m , n ))
2064
+ assert piv .dtype == dpnp .int64
2065
+
2066
+ a_3d = a_dp .reshape ((- 1 , m , n ))
2067
+ lu_3d = lu .reshape ((- 1 , m , n ))
2068
+ piv_2d = piv .reshape ((- 1 , min (m , n )))
2069
+ for i in range (a_3d .shape [0 ]):
2070
+ L , U = self ._split_lu (lu_3d [i ], m , n )
2071
+ A_cast = a_3d [i ].astype (L .dtype , copy = False )
2072
+ PA = self ._apply_pivots_rows (A_cast , piv_2d [i ])
2073
+ assert_allclose (L @ U , PA , rtol = 1e-6 , atol = 1e-6 )
2074
+
2075
+ @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
2076
+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
2077
+ def test_overwrite (self , dtype , order ):
2078
+ a_dp = dpnp .arange (2 * 2 * 3 , dtype = dtype ).reshape (3 , 2 , 2 , order = order )
2079
+ a_dp_orig = a_dp .copy ()
2080
+ lu , piv = dpnp .linalg .lu_factor (
2081
+ a_dp , overwrite_a = True , check_finite = False
2082
+ )
2083
+
2084
+ assert lu is not a_dp
2085
+ assert_allclose (a_dp , a_dp_orig )
2086
+
2087
+ m = n = 2
2088
+ lu_3d = lu .reshape ((- 1 , m , n ))
2089
+ a_3d = a_dp .reshape ((- 1 , m , n ))
2090
+ piv_2d = piv .reshape ((- 1 , min (m , n )))
2091
+ for i in range (a_3d .shape [0 ]):
2092
+ L , U = self ._split_lu (lu_3d [i ], m , n )
2093
+ A_cast = a_3d [i ].astype (L .dtype , copy = False )
2094
+ PA = self ._apply_pivots_rows (A_cast , piv_2d [i ])
2095
+ assert_allclose (L @ U , PA , rtol = 1e-6 , atol = 1e-6 )
2096
+
2097
+ @pytest .mark .parametrize (
2098
+ "shape" , [(0 , 2 , 2 ), (2 , 0 , 2 ), (2 , 2 , 0 ), (0 , 0 , 0 )]
2099
+ )
2100
+ def test_empty_inputs (self , shape ):
2101
+ a = dpnp .empty (shape , dtype = dpnp .default_float_type (), order = "F" )
2102
+
2103
+ lu , piv = dpnp .linalg .lu_factor (a , check_finite = False )
2104
+ assert lu .shape == shape
2105
+ m , n = shape [- 2 :]
2106
+ assert piv .shape == (* shape [:- 2 ], min (m , n ))
2107
+
2108
+ def test_strided (self ):
2109
+ a = (
2110
+ dpnp .arange (5 * 3 * 3 , dtype = dpnp .default_float_type ()).reshape (
2111
+ 5 , 3 , 3 , order = "F"
2112
+ )
2113
+ + 0.1
2114
+ )
2115
+ a_stride = a [::2 ]
2116
+ lu , piv = dpnp .linalg .lu_factor (a_stride , check_finite = False )
2117
+ for i in range (a_stride .shape [0 ]):
2118
+ L , U = self ._split_lu (lu [i ], 3 , 3 )
2119
+ PA = self ._apply_pivots_rows (
2120
+ a_stride [i ].astype (L .dtype , copy = False ), piv [i ]
2121
+ )
2122
+ assert_allclose (L @ U , PA , rtol = 1e-6 , atol = 1e-6 )
2123
+
2124
+ def test_singular_matrix (self ):
2125
+ a = dpnp .zeros ((3 , 2 , 2 ), dtype = dpnp .float64 )
2126
+ a [0 ] = dpnp .array ([[1.0 , 2.0 ], [2.0 , 4.0 ]])
2127
+ a [1 ] = dpnp .eye (2 )
2128
+ a [2 ] = dpnp .array ([[1.0 , 1.0 ], [1.0 , 1.0 ]])
2129
+ with pytest .warns (RuntimeWarning , match = "Singular matrix" ):
2130
+ dpnp .linalg .lu_factor (a , check_finite = False )
2131
+
2132
+ def test_check_finite_raises (self ):
2133
+ a = dpnp .ones ((2 , 3 , 3 ), dtype = dpnp .float64 , order = "F" )
2134
+ a [1 , 0 , 0 ] = dpnp .nan
2135
+ assert_raises (ValueError , dpnp .linalg .lu_factor , a , check_finite = True )
2136
+
2137
+
2025
2138
class TestMatrixPower :
2026
2139
@pytest .mark .parametrize ("dtype" , get_all_dtypes ())
2027
2140
@pytest .mark .parametrize (
0 commit comments