Skip to content

Commit b70c7f7

Browse files
Parse architecture from PyTorch instead of hard coding (#2995)
With pytorch/pytorch#138186, `architecture` is added to the XPU device properties. Instead of hard-coding `pvc` when invoking `ocloc`, this PR parses the device architecture from those properties and passes it to `ocloc` dynamically. --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent 2e7a5de commit b70c7f7

File tree

2 files changed

+87
-18
lines changed

2 files changed

+87
-18
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
//===- arch_parser.c ------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <sycl/sycl.hpp>
10+
11+
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
12+
#include <Python.h>
13+
#include <numpy/arrayobject.h>
14+
15+
// Translate the numeric architecture value of a device into a short device
// name.
//
// Expects `args` to hold a single integer: the value of a
// sycl::ext::oneapi::experimental::architecture enumerator.
//
// Returns a Python str: "pvc", "bmg" or "lnl" for the architectures handled
// below, or an empty string for any other value (the raw value is then printed
// to stdout). Returns NULL with a Python exception set when the argument
// cannot be parsed as an integer.
static PyObject *parseDeviceArch(PyObject *self, PyObject *args) {
  // The "K" format unit stores into an `unsigned long long`, so use exactly
  // that type as the destination.
  unsigned long long dev_arch = 0;
  // The parse must not live inside assert(): an NDEBUG build would drop the
  // call entirely and leave dev_arch unset, and a failed parse has to
  // propagate the exception raised by PyArg_ParseTuple instead of aborting.
  if (!PyArg_ParseTuple(args, "K", &dev_arch))
    return NULL;

  namespace syclex = sycl::ext::oneapi::experimental;
  const auto sycl_arch = static_cast<syclex::architecture>(dev_arch);

  // FIXME: Add support for more architectures.
  const char *arch = "";
  switch (sycl_arch) {
  case syclex::architecture::intel_gpu_pvc:
    arch = "pvc";
    break;
  case syclex::architecture::intel_gpu_bmg_g21:
    arch = "bmg";
    break;
  case syclex::architecture::intel_gpu_lnl_m:
    arch = "lnl";
    break;
  default:
    // A scoped enum cannot be passed through varargs for "%d"; print the raw
    // integer value with a matching conversion specifier instead.
    printf("sycl_arch = %llu", dev_arch);
    break;
  }

  return Py_BuildValue("s", arch);
}
39+
40+
// Method table of the arch_utils module. It exposes parseDeviceArch to Python
// as `parse_device_arch`; the all-NULL entry terminates the table.
static PyMethodDef ModuleMethods[] = {
    {"parse_device_arch", parseDeviceArch, METH_VARARGS,
     "parse device architecture"},
    {NULL, NULL, 0, NULL}, // sentinel
};
45+
46+
// Definition of the `arch_utils` extension module.
static struct PyModuleDef ModuleDef = {
    PyModuleDef_HEAD_INIT,
    "arch_utils",  // module name
    NULL,          // documentation
    -1,            // size
    ModuleMethods, // method table
};
50+
51+
// Initialization entry point of the `arch_utils` extension module.
//
// Returns the newly created module object, or NULL (with a Python exception
// set) when creating the module or registering its functions fails.
PyMODINIT_FUNC PyInit_arch_utils(void) {
  PyObject *m = PyModule_Create(&ModuleDef);
  if (m == NULL)
    return NULL;

  // PyModule_AddFunctions reports failure with -1; do not hand out a module
  // whose function registration failed, and release the reference we own.
  if (PyModule_AddFunctions(m, ModuleMethods) < 0) {
    Py_DECREF(m);
    return NULL;
  }
  return m;
}

third_party/intel/backend/compiler.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from triton.backends.compiler import BaseBackend
22
from triton._C.libtriton import ir, passes, llvm, intel
3+
from triton.backends.intel.driver import compile_module_from_src
34

45
from dataclasses import dataclass
56
import functools
@@ -96,6 +97,7 @@ def get_ops_per_channel(lhs_type, rhs_type):
9697

9798

9899
class XPUBackend(BaseBackend):
100+
device_props: dict = {}
99101

100102
# AdvancedPath pass pipeline for kernels using block pointers.
101103
class AdvancedPath:
@@ -127,6 +129,9 @@ def __init__(self, target: tuple) -> None:
127129
super().__init__(target)
128130
if not isinstance(target.arch, dict):
129131
raise TypeError("target.arch is not a dict")
132+
dirname = os.path.dirname(os.path.realpath(__file__))
133+
mod = compile_module_from_src(Path(os.path.join(dirname, "arch_parser.c")).read_text(), "arch_utils")
134+
self.parse_device_arch = mod.parse_device_arch
130135
self.properties = self.parse_target(target.arch)
131136
self.binary_ext = "spv"
132137

@@ -142,30 +147,37 @@ def parse_target(self, tgt_prop) -> dict:
142147
dev_prop['max_num_sub_groups'] = tgt_prop.get('max_num_sub_groups', None)
143148
dev_prop['sub_group_sizes'] = tgt_prop.get('sub_group_sizes', None)
144149
dev_prop['has_fp64'] = tgt_prop.get('has_fp64', None)
145-
if os.getenv("TRITON_INTEL_QUERY_DEVICE_EXTENSIONS", "0") == "1":
150+
dev_prop['has_subgroup_matrix_multiply_accumulate'] = tgt_prop.get('has_subgroup_matrix_multiply_accumulate',
151+
False)
152+
dev_prop['has_subgroup_matrix_multiply_accumulate_tensor_float32'] = tgt_prop.get(
153+
'has_subgroup_matrix_multiply_accumulate_tensor_float32', False)
154+
dev_prop['has_subgroup_2d_block_io'] = tgt_prop.get('has_subgroup_2d_block_io', False)
155+
dev_prop['has_bfloat16_conversions'] = tgt_prop.get('has_bfloat16_conversions', True)
156+
157+
device_arch = self.parse_device_arch(tgt_prop.get('architecture', 0))
158+
if device_arch:
159+
if device_arch in self.device_props:
160+
dev_prop.update(self.device_props[device_arch])
161+
return dev_prop
146162
try:
147-
# FIXME: Add support for other devices.
148-
ocloc_cmd = ['ocloc', 'query', 'CL_DEVICE_EXTENSIONS', '-device', 'pvc']
149-
result = subprocess.run(ocloc_cmd, check=True, capture_output=True, text=True)
150-
output = result.stdout
163+
ocloc_cmd = ['ocloc', 'query', 'CL_DEVICE_EXTENSIONS', '-device', device_arch]
164+
with tempfile.TemporaryDirectory() as temp_dir:
165+
output = subprocess.check_output(ocloc_cmd, text=True, cwd=temp_dir)
151166
supported_extensions = set()
152167
for extension in output.split(' '):
153168
supported_extensions.add(extension)
154-
dev_prop[
169+
ocloc_dev_prop = {}
170+
ocloc_dev_prop[
155171
'has_subgroup_matrix_multiply_accumulate'] = 'cl_intel_subgroup_matrix_multiply_accumulate' in supported_extensions
156-
dev_prop[
172+
ocloc_dev_prop[
157173
'has_subgroup_matrix_multiply_accumulate_tensor_float32'] = 'cl_intel_subgroup_matrix_multiply_accumulate_tensor_float32' in supported_extensions
158-
dev_prop['has_subgroup_2d_block_io'] = 'cl_intel_subgroup_2d_block_io' in supported_extensions
159-
dev_prop['has_bfloat16_conversions'] = 'cl_intel_bfloat16_conversions' in supported_extensions
160-
except subprocess.CalledProcessError as e:
161-
raise RuntimeError(f'`ocloc` failed with error code {e.returncode}')
162-
else:
163-
dev_prop['has_subgroup_matrix_multiply_accumulate'] = tgt_prop.get(
164-
'has_subgroup_matrix_multiply_accumulate', False)
165-
dev_prop['has_subgroup_matrix_multiply_accumulate_tensor_float32'] = tgt_prop.get(
166-
'has_subgroup_matrix_multiply_accumulate_tensor_float32', False)
167-
dev_prop['has_subgroup_2d_block_io'] = tgt_prop.get('has_subgroup_2d_block_io', False)
168-
dev_prop['has_bfloat16_conversions'] = tgt_prop.get('has_bfloat16_conversions', True)
174+
ocloc_dev_prop['has_subgroup_2d_block_io'] = 'cl_intel_subgroup_2d_block_io' in supported_extensions
175+
ocloc_dev_prop['has_bfloat16_conversions'] = 'cl_intel_bfloat16_conversions' in supported_extensions
176+
self.device_props[device_arch] = ocloc_dev_prop
177+
dev_prop.update(ocloc_dev_prop)
178+
except subprocess.CalledProcessError:
179+
# Note: LTS driver does not support ocloc query CL_DEVICE_EXTENSIONS.
180+
pass
169181
return dev_prop
170182

171183
def parse_options(self, opts) -> Any:

0 commit comments

Comments
 (0)