@@ -556,3 +556,70 @@ def test_as_f_contig_square(dt):
556556 x3 = dpt .flip (x , axis = 1 )
557557 y3 = dpt .asarray (x3 , order = "F" )
558558 assert dpt .all (x3 == y3 )
559+
560+
561+ class MockArrayWithBothProtocols :
562+ """
563+ Object that implements both __sycl_usm_array_interface__
564+ and __usm_ndarray__ properties.
565+ """
566+
567+ def __init__ (self , usm_ar ):
568+ if not isinstance (usm_ar , dpt .usm_ndarray ):
569+ raise TypeError
570+ self ._arr = usm_ar
571+
572+ @property
573+ def __usm_ndarray__ (self ):
574+ return self ._arr
575+
576+ @property
577+ def __sycl_usm_array_interface__ (self ):
578+ return self ._arr .__sycl_usm_array_interface__
579+
580+
581+ class MockArrayWithSUAIOnly :
582+ """
583+ Object that implements only the
584+ __sycl_usm_array_interface__ property.
585+ """
586+
587+ def __init__ (self , usm_ar ):
588+ if not isinstance (usm_ar , dpt .usm_ndarray ):
589+ raise TypeError
590+ self ._arr = usm_ar
591+
592+ @property
593+ def __sycl_usm_array_interface__ (self ):
594+ return self ._arr .__sycl_usm_array_interface__
595+
596+
597+ @pytest .mark .parametrize ("usm_type" , ["shared" , "device" , "host" ])
598+ def test_asarray_support_for_usm_ndarray_protocol (usm_type ):
599+ get_queue_or_skip ()
600+
601+ x = dpt .arange (256 , dtype = "i4" , usm_type = usm_type )
602+
603+ o1 = MockArrayWithBothProtocols (x )
604+ o2 = MockArrayWithSUAIOnly (x )
605+
606+ y1 = dpt .asarray (o1 )
607+ assert x .sycl_queue == y1 .sycl_queue
608+ assert x .usm_type == y1 .usm_type
609+ assert x .dtype == y1 .dtype
610+ assert y1 .usm_data .reference_obj is None
611+ assert dpt .all (x == y1 )
612+
613+ y2 = dpt .asarray (o2 )
614+ assert x .sycl_queue == y2 .sycl_queue
615+ assert x .usm_type == y2 .usm_type
616+ assert x .dtype == y2 .dtype
617+ assert not (y2 .usm_data .reference_obj is None )
618+ assert dpt .all (x == y2 )
619+
620+ y3 = dpt .asarray ([o1 , o2 ])
621+ assert x .sycl_queue == y3 .sycl_queue
622+ assert x .usm_type == y3 .usm_type
623+ assert x .dtype == y3 .dtype
624+ assert y3 .usm_data .reference_obj is None
625+ assert dpt .all (x [dpt .newaxis , :] == y3 )
0 commit comments