@@ -826,3 +826,38 @@ def test_generic_container():
826826 assert isinstance (Z , dpt .usm_ndarray )
827827 assert Z ._pointer == X ._pointer
828828 assert Z .device == X .device
829+
830+
831+ def test_sycldevice_to_dldevice (all_root_devices ):
832+ for sycl_dev in all_root_devices :
833+ dev = dpt .sycldevice_to_dldevice (sycl_dev )
834+ assert type (dev ) is tuple
835+ assert len (dev ) == 2
836+ assert dev [0 ] == device_oneAPI
837+ assert dev [1 ] == all_root_devices .index (sycl_dev )
838+
839+
840+ def test_dldevice_to_sycldevice (all_root_devices ):
841+ for sycl_dev in all_root_devices :
842+ dldev = dpt .empty (0 , device = sycl_dev ).__dlpack_device__ ()
843+ dev = dpt .dldevice_to_sycldevice (dldev )
844+ assert type (dev ) is dpctl .SyclDevice
845+ assert dev == all_root_devices [dldev [1 ]]
846+
847+
848+ def test_dldevice_conversion_arg_validation ():
849+ bad_dldevice_type = (dpt .DLDeviceType .kDLCPU , 0 )
850+ with pytest .raises (ValueError ):
851+ dpt .dldevice_to_sycldevice (bad_dldevice_type )
852+
853+ bad_dldevice_len = bad_dldevice_type + (0 ,)
854+ with pytest .raises (ValueError ):
855+ dpt .dldevice_to_sycldevice (bad_dldevice_len )
856+
857+ bad_dldevice = dict ()
858+ with pytest .raises (TypeError ):
859+ dpt .dldevice_to_sycldevice (bad_dldevice )
860+
861+ bad_sycldevice = dict ()
862+ with pytest .raises (TypeError ):
863+ dpt .sycldevice_to_dldevice (bad_sycldevice )
0 commit comments