Skip to content

Commit 69998dd

Browse files
authored
use kwarg arch from ipex (#105) (#107)
* use cc from ipex * refine code (cherry picked from commit c3ecd6d)
1 parent 7fc9f9b commit 69998dd

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

python/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -420,10 +420,12 @@ def get_kernel_bin(self):
420420
return "spvbin"
421421

422422
def get_architecture_descriptor(self, **kwargs):
423-
dev_props = self.driver.utils.get_device_properties(torch.xpu.device(torch.xpu.current_device()).sycl_device) # noqa: E501
424-
max_work_group_size = dev_props['max_work_group_size']
425-
max_num_sub_groups = dev_props['max_num_sub_groups']
426-
sub_group_sizes = dev_props['sub_group_sizes']
423+
arch = kwargs.get("arch", None)
424+
if arch is None:
425+
arch = self.get_device_properties(self.get_current_device())
426+
max_work_group_size = arch['max_work_group_size']
427+
max_num_sub_groups = arch['max_num_sub_groups']
428+
sub_group_sizes = arch['sub_group_sizes']
427429
# TODO: chose a reasonable subgroup size
428430
threads_per_warp = 32
429431
assert threads_per_warp in sub_group_sizes, "Current platform does not support threads_per_warp to be 32" # noqa: E501

0 commit comments

Comments
 (0)