@@ -2022,13 +2022,17 @@ class TestLuFactorBatched:
2022
2022
@staticmethod
2023
2023
def _apply_pivots_rows (A_dp , piv_dp ):
2024
2024
m = A_dp .shape [0 ]
2025
- rows = dpnp .arange (m )
2026
- for i in range (int (piv_dp .shape [0 ])):
2027
- r = int (piv_dp [i ].item ())
2025
+
2026
+ if m == 0 or piv_dp .size == 0 :
2027
+ return A_dp
2028
+
2029
+ rows = list (range (m ))
2030
+ piv_np = dpnp .asnumpy (piv_dp )
2031
+ for i , r in enumerate (piv_np ):
2028
2032
if i != r :
2029
- tmp = rows [i ]. copy ()
2030
- rows [ i ] = rows [ r ]
2031
- rows [ r ] = tmp
2033
+ rows [ i ], rows [ r ] = rows [r ], rows [ i ]
2034
+
2035
+ rows = dpnp . asarray ( rows )
2032
2036
return A_dp [rows ]
2033
2037
2034
2038
@staticmethod
@@ -2118,15 +2122,15 @@ def test_strided(self):
2118
2122
assert_allclose (L @ U , PA , rtol = 1e-6 , atol = 1e-6 )
2119
2123
2120
2124
def test_singular_matrix (self ):
2121
- a = dpnp .zeros ((3 , 2 , 2 ), dtype = dpnp .float64 )
2125
+ a = dpnp .zeros ((3 , 2 , 2 ), dtype = dpnp .default_float_type () )
2122
2126
a [0 ] = dpnp .array ([[1.0 , 2.0 ], [2.0 , 4.0 ]])
2123
2127
a [1 ] = dpnp .eye (2 )
2124
2128
a [2 ] = dpnp .array ([[1.0 , 1.0 ], [1.0 , 1.0 ]])
2125
2129
with pytest .warns (RuntimeWarning , match = "Singular matrix" ):
2126
2130
dpnp .linalg .lu_factor (a , check_finite = False )
2127
2131
2128
2132
def test_check_finite_raises (self ):
2129
- a = dpnp .ones ((2 , 3 , 3 ), dtype = dpnp .float64 , order = "F" )
2133
+ a = dpnp .ones ((2 , 3 , 3 ), dtype = dpnp .default_float_type () , order = "F" )
2130
2134
a [1 , 0 , 0 ] = dpnp .nan
2131
2135
assert_raises (ValueError , dpnp .linalg .lu_factor , a , check_finite = True )
2132
2136
0 commit comments