@@ -2022,13 +2022,17 @@ class TestLuFactorBatched:
20222022 @staticmethod
20232023 def _apply_pivots_rows (A_dp , piv_dp ):
20242024 m = A_dp .shape [0 ]
2025- rows = dpnp .arange (m )
2026- for i in range (int (piv_dp .shape [0 ])):
2027- r = int (piv_dp [i ].item ())
2025+
2026+ if m == 0 or piv_dp .size == 0 :
2027+ return A_dp
2028+
2029+ rows = list (range (m ))
2030+ piv_np = dpnp .asnumpy (piv_dp )
2031+ for i , r in enumerate (piv_np ):
20282032 if i != r :
2029- tmp = rows [i ]. copy ()
2030- rows [ i ] = rows [ r ]
2031- rows [ r ] = tmp
2033+ rows [ i ], rows [ r ] = rows [r ], rows [ i ]
2034+
2035+ rows = dpnp . asarray ( rows )
20322036 return A_dp [rows ]
20332037
20342038 @staticmethod
@@ -2118,15 +2122,15 @@ def test_strided(self):
21182122 assert_allclose (L @ U , PA , rtol = 1e-6 , atol = 1e-6 )
21192123
21202124 def test_singular_matrix (self ):
2121- a = dpnp .zeros ((3 , 2 , 2 ), dtype = dpnp .float64 )
2125+ a = dpnp .zeros ((3 , 2 , 2 ), dtype = dpnp .default_float_type () )
21222126 a [0 ] = dpnp .array ([[1.0 , 2.0 ], [2.0 , 4.0 ]])
21232127 a [1 ] = dpnp .eye (2 )
21242128 a [2 ] = dpnp .array ([[1.0 , 1.0 ], [1.0 , 1.0 ]])
21252129 with pytest .warns (RuntimeWarning , match = "Singular matrix" ):
21262130 dpnp .linalg .lu_factor (a , check_finite = False )
21272131
21282132 def test_check_finite_raises (self ):
2129- a = dpnp .ones ((2 , 3 , 3 ), dtype = dpnp .float64 , order = "F" )
2133+ a = dpnp .ones ((2 , 3 , 3 ), dtype = dpnp .default_float_type () , order = "F" )
21302134 a [1 , 0 , 0 ] = dpnp .nan
21312135 assert_raises (ValueError , dpnp .linalg .lu_factor , a , check_finite = True )
21322136
0 commit comments