@@ -2380,12 +2380,6 @@ class TestQr:
23802380 )
23812381 @pytest .mark .parametrize ("mode" , ["r" , "raw" , "complete" , "reduced" ])
23822382 def test_qr (self , dtype , shape , mode ):
2383- if (
2384- is_cuda_device ()
2385- and mode in ["complete" , "reduced" ]
2386- and shape in [(16 , 16 ), (2 , 2 , 4 )]
2387- ):
2388- pytest .skip ("SAT-7589" )
23892383 a = generate_random_numpy_array (shape , dtype , seed_value = 81 )
23902384 ia = dpnp .array (a )
23912385
@@ -2398,24 +2392,48 @@ def test_qr(self, dtype, shape, mode):
23982392
23992393 # check decomposition
24002394 if mode in ("complete" , "reduced" ):
2401- if a .ndim == 2 :
2402- assert_almost_equal (
2403- dpnp .dot (dpnp_q , dpnp_r ),
2404- a ,
2405- decimal = 5 ,
2406- )
2407- else : # a.ndim > 2
2408- assert_almost_equal (
2409- dpnp .matmul (dpnp_q , dpnp_r ),
2410- a ,
2411- decimal = 5 ,
2412- )
2395+ assert_almost_equal (
2396+ dpnp .matmul (dpnp_q , dpnp_r ),
2397+ a ,
2398+ decimal = 5 ,
2399+ )
24132400 else : # mode=="raw"
24142401 assert_dtype_allclose (dpnp_q , np_q )
24152402
24162403 if mode in ("raw" , "r" ):
24172404 assert_dtype_allclose (dpnp_r , np_r )
24182405
2406+ @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_bool = True ))
2407+ @pytest .mark .parametrize (
2408+ "shape" ,
2409+ [(32 , 32 ), (8 , 16 , 16 )],
2410+ ids = [
2411+ "(32, 32)" ,
2412+ "(8, 16, 16)" ,
2413+ ],
2414+ )
2415+ @pytest .mark .parametrize ("mode" , ["r" , "raw" , "complete" , "reduced" ])
2416+ def test_qr_large (self , dtype , shape , mode ):
2417+ a = generate_random_numpy_array (shape , dtype , seed_value = 81 )
2418+ ia = dpnp .array (a )
2419+ if mode == "r" :
2420+ np_r = numpy .linalg .qr (a , mode )
2421+ dpnp_r = dpnp .linalg .qr (ia , mode )
2422+ else :
2423+ np_q , np_r = numpy .linalg .qr (a , mode )
2424+ dpnp_q , dpnp_r = dpnp .linalg .qr (ia , mode )
2425+ # check decomposition
2426+ if mode in ("complete" , "reduced" ):
2427+ assert_almost_equal (
2428+ dpnp .matmul (dpnp_q , dpnp_r ),
2429+ a ,
2430+ decimal = 5 ,
2431+ )
2432+ else : # mode=="raw"
2433+ assert_dtype_allclose (dpnp_q , np_q , factor = 12 )
2434+ if mode in ("raw" , "r" ):
2435+ assert_dtype_allclose (dpnp_r , np_r , factor = 12 )
2436+
24192437 @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_bool = True ))
24202438 @pytest .mark .parametrize (
24212439 "shape" ,
0 commit comments