@@ -1932,23 +1932,66 @@ def test_svd(shape, full_matrices, compute_uv, device):
1932
1932
assert_sycl_queue_equal (dpnp_s_queue , expected_queue )
1933
1933
1934
1934
1935
- @pytest .mark .parametrize (
1936
- "device_from" ,
1937
- valid_devices ,
1938
- ids = [device .filter_string for device in valid_devices ],
1939
- )
1940
- @pytest .mark .parametrize (
1941
- "device_to" ,
1942
- valid_devices ,
1943
- ids = [device .filter_string for device in valid_devices ],
1944
- )
1945
- def test_to_device (device_from , device_to ):
1946
- data = [1.0 , 1.0 , 1.0 , 1.0 , 1.0 ]
1947
-
1948
- x = dpnp .array (data , dtype = dpnp .float32 , device = device_from )
1949
- y = x .to_device (device_to )
1935
+ class TestToDevice :
1936
+ @pytest .mark .parametrize (
1937
+ "device_from" ,
1938
+ valid_devices ,
1939
+ ids = [device .filter_string for device in valid_devices ],
1940
+ )
1941
+ @pytest .mark .parametrize (
1942
+ "device_to" ,
1943
+ valid_devices ,
1944
+ ids = [device .filter_string for device in valid_devices ],
1945
+ )
1946
+ def test_basic (self , device_from , device_to ):
1947
+ data = [1.0 , 1.0 , 1.0 , 1.0 , 1.0 ]
1948
+ x = dpnp .array (data , dtype = dpnp .float32 , device = device_from )
1949
+
1950
+ y = x .to_device (device_to )
1951
+ assert y .sycl_device == device_to
1952
+ assert (x .asnumpy () == y .asnumpy ()).all ()
1953
+
1954
+ def test_to_queue (self ):
1955
+ x = dpnp .full (100 , 2 , dtype = dpnp .int64 )
1956
+ q_prof = dpctl .SyclQueue (x .sycl_device , property = "enable_profiling" )
1957
+
1958
+ y = x .to_device (q_prof )
1959
+ assert (x .asnumpy () == y .asnumpy ()).all ()
1960
+ assert_sycl_queue_equal (y .sycl_queue , q_prof )
1961
+
1962
+ def test_stream (self ):
1963
+ x = dpnp .full (100 , 2 , dtype = dpnp .int64 )
1964
+ q_prof = dpctl .SyclQueue (x .sycl_device , property = "enable_profiling" )
1965
+ q_exec = dpctl .SyclQueue (x .sycl_device )
1966
+
1967
+ y = x .to_device (q_prof , stream = q_exec )
1968
+ assert (x .asnumpy () == y .asnumpy ()).all ()
1969
+ assert_sycl_queue_equal (y .sycl_queue , q_prof )
1970
+
1971
+ q_exec = dpctl .SyclQueue (x .sycl_device )
1972
+ _ = dpnp .linspace (0 , 20 , num = 10 ** 5 , sycl_queue = q_exec )
1973
+ y = x .to_device (q_prof , stream = q_exec )
1974
+ assert (x .asnumpy () == y .asnumpy ()).all ()
1975
+ assert_sycl_queue_equal (y .sycl_queue , q_prof )
1976
+
1977
+ def test_stream_no_sync (self ):
1978
+ x = dpnp .full (100 , 2 , dtype = dpnp .int64 )
1979
+ q_prof = dpctl .SyclQueue (x .sycl_device , property = "enable_profiling" )
1980
+
1981
+ for stream in [None , x .sycl_queue ]:
1982
+ y = x .to_device (q_prof , stream = stream )
1983
+ assert (x .asnumpy () == y .asnumpy ()).all ()
1984
+ assert_sycl_queue_equal (y .sycl_queue , q_prof )
1950
1985
1951
- assert y .sycl_device == device_to
1986
+ @pytest .mark .parametrize (
1987
+ "stream" ,
1988
+ [1 , dict (), dpctl .SyclDevice ()],
1989
+ ids = ["scalar" , "dictionary" , "device" ],
1990
+ )
1991
+ def test_invalid_stream (self , stream ):
1992
+ x = dpnp .ones (2 , dtype = dpnp .int64 )
1993
+ q_prof = dpctl .SyclQueue (x .sycl_device , property = "enable_profiling" )
1994
+ assert_raises (TypeError , x .to_device , q_prof , stream = stream )
1952
1995
1953
1996
1954
1997
@pytest .mark .parametrize (
0 commit comments