Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions python/triton/_internal_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,83 @@ def is_xpu():
target = get_current_target()
return False if target is None else target.backend == "xpu"

def is_xpu_arl_h():
target = get_current_target()
return target is not None and target.backend == 'xpu' and 'arl_h' in target.arch


def is_xpu_arl_s():
target = get_current_target()
return target is not None and target.backend == 'xpu' and 'arl_s' in target.arch


def is_xpu_bmg():
target = get_current_target()
return target is not None and target.backend == 'xpu' and 'bmg' in target.arch


def is_xpu_dg2():
target = get_current_target()
return target is not None and target.backend == 'xpu' and 'dg2' in target.arch


def is_xpu_lnl():
target = get_current_target()
return target is not None and target.backend == 'xpu' and 'lnl' in target.arch


def is_xpu_mtl():
target = get_current_target()
return target is not None and target.backend == 'xpu' and 'mtl' in target.arch


def is_xpu_pvc():
target = get_current_target()
return target is not None and target.backend == 'xpu' and 'pvc' in target.arch


def is_xpu_ptl_h():
target = get_current_target()
return target is not None and target.backend == 'xpu' and 'ptl_h' in target.arch


def is_xpu_ptl_u():
target = get_current_target()
return target is not None and target.backend == 'xpu' and 'ptl_u' in target.arch


def is_xpu_cri():
target = get_current_target()
return target is not None and target.backend == 'xpu' and 'cri' in target.arch


def is_xpu_lpg():
return is_xpu_arl_s()


def is_xpu_lpgP():
return is_xpu_arl_h()


def is_xpu_hpg():
return is_xpu_dg2()


def is_xpu_hpc():
return is_xpu_pvc()


Comment on lines +159 to +174
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wonder if those are really going to be useful

def is_xpu_xe2():
return is_xpu_bmg()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is pvc xe2?



def is_xpu_xe3():
return is_xpu_ptl_h() or is_xpu_ptl_u()


def is_xpu_xe3P():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[optional] I usually see xe3p with p lower case.

return is_xpu_cri()


def get_arch():
target = get_current_target()
Expand Down
2 changes: 1 addition & 1 deletion third_party/intel/backend/arch_parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ extern "C" EXPORT_FUNC const char *parse_device_arch(uint64_t dev_arch) {
break;
#endif
case sycl::ext::oneapi::experimental::architecture::unknown:
std::cerr << "unknown sycl_arch" << std::endl;
std::cerr << "unknown" << std::endl;
break;
default:
std::cerr << "sycl_arch not recognized: " << (uint64_t)sycl_arch
Expand Down
10 changes: 10 additions & 0 deletions third_party/intel/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,17 @@ def update_advanced_features(device, dev_property):
dev_property["has_subgroup_2d_block_io"] = check(device, b"cl_intel_subgroup_2d_block_io")
dev_property["has_bfloat16_conversions"] = check(device, b"cl_intel_bfloat16_conversions")

def update_device_arch(dev_property):
if not (arch := knobs.intel.device_arch):
dirname = os.path.dirname(os.path.realpath(__file__))
parser = compile_module_from_src(src=Path(os.path.join(dirname, "arch_parser.c")).read_text(),
name="arch_utils")
arch_name = parser.parse_device_arch(dev_property["architecture"])
dev_property["arch"] = arch

update_advanced_features(device, dev_property)
update_device_arch(dev_property)

return GPUTarget("xpu", dev_property, warp_size=32)

def build_proton_help_lib(self):
Expand Down
Loading