File tree Expand file tree Collapse file tree 1 file changed +6
-4
lines changed
Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Original file line number Diff line number Diff line change @@ -420,10 +420,12 @@ def get_kernel_bin(self):
420420 return "spvbin"
421421
422422 def get_architecture_descriptor (self , ** kwargs ):
423- dev_props = self .driver .utils .get_device_properties (torch .xpu .device (torch .xpu .current_device ()).sycl_device ) # noqa: E501
424- max_work_group_size = dev_props ['max_work_group_size' ]
425- max_num_sub_groups = dev_props ['max_num_sub_groups' ]
426- sub_group_sizes = dev_props ['sub_group_sizes' ]
423+ arch = kwargs .get ("arch" , None )
424+ if arch is None :
425+ arch = self .get_device_properties (self .get_current_device ())
426+ max_work_group_size = arch ['max_work_group_size' ]
427+ max_num_sub_groups = arch ['max_num_sub_groups' ]
428+ sub_group_sizes = arch ['sub_group_sizes' ]
427429 # TODO: chose a reasonable subgroup size
428430 threads_per_warp = 32
429431 assert threads_per_warp in sub_group_sizes , "Current platform does not support threads_per_warp to be 32" # noqa: E501
You can’t perform that action at this time.
0 commit comments