Skip to content

Commit 6f34de4

Browse files
authored
Add Triton knob device_extensions (#4466)
This knob overrides (via the environment variable `TRITON_INTEL_DEVICE_EXTENSIONS`) the device extensions returned by `ocloc`. It can be used for experiments and, temporarily, on systems without `ocloc`. Signed-off-by: Pavel Chekin <[email protected]>
1 parent 35f2005 commit 6f34de4

File tree

2 files changed

+31
-18
lines changed

2 files changed

+31
-18
lines changed

python/triton/knobs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,11 @@ class intel_knobs(base_knobs):
501501

502502
libdevice_path: env_opt_str = env_opt_str("TRITON_LIBDEVICE_PATH")
503503

504+
# Space separated list of device extensions, similar to the output of
505+
# `ocloc query CL_DEVICE_EXTENSIONS`. If not set, the compiler calls `ocloc` at runtime to get
506+
# the actual device extensions.
507+
device_extensions: env_opt_str = env_opt_str("TRITON_INTEL_DEVICE_EXTENSIONS")
508+
504509

505510
class amd_knobs(base_knobs):
506511
use_buffer_ops: env_bool = env_bool("AMDGCN_USE_BUFFER_OPS", True)

third_party/intel/backend/compiler.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -140,29 +140,37 @@ def parse_target(self, tgt_prop) -> dict:
140140
dev_prop['has_subgroup_2d_block_io'] = tgt_prop.get('has_subgroup_2d_block_io', False)
141141
dev_prop['has_bfloat16_conversions'] = tgt_prop.get('has_bfloat16_conversions', True)
142142

143-
if self.device_arch and shutil.which('ocloc'):
144-
if self.device_arch in self.device_props:
145-
dev_prop.update(self.device_props[self.device_arch])
146-
return dev_prop
143+
if not self.device_arch:
144+
return dev_prop
145+
146+
if self.device_arch in self.device_props:
147+
dev_prop.update(self.device_props[self.device_arch])
148+
return dev_prop
149+
150+
supported_extensions = set()
151+
152+
if knobs.intel.device_extensions:
153+
supported_extensions.update(knobs.intel.device_extensions.split(' '))
154+
elif shutil.which('ocloc'):
147155
try:
148-
ocloc_cmd = ['ocloc', 'query', 'CL_DEVICE_EXTENSIONS', '-device', self.device_arch]
156+
cmd = ['ocloc', 'query', 'CL_DEVICE_EXTENSIONS', '-device', self.device_arch]
149157
with tempfile.TemporaryDirectory() as temp_dir:
150-
output = subprocess.check_output(ocloc_cmd, text=True, cwd=temp_dir)
151-
supported_extensions = set()
152-
for extension in output.split(' '):
153-
supported_extensions.add(extension)
154-
ocloc_dev_prop = {}
155-
ocloc_dev_prop[
156-
'has_subgroup_matrix_multiply_accumulate'] = 'cl_intel_subgroup_matrix_multiply_accumulate' in supported_extensions
157-
ocloc_dev_prop[
158-
'has_subgroup_matrix_multiply_accumulate_tensor_float32'] = 'cl_intel_subgroup_matrix_multiply_accumulate_tensor_float32' in supported_extensions
159-
ocloc_dev_prop['has_subgroup_2d_block_io'] = 'cl_intel_subgroup_2d_block_io' in supported_extensions
160-
ocloc_dev_prop['has_bfloat16_conversions'] = 'cl_intel_bfloat16_conversions' in supported_extensions
161-
self.device_props[self.device_arch] = ocloc_dev_prop
162-
dev_prop.update(ocloc_dev_prop)
158+
output = subprocess.check_output(cmd, text=True, cwd=temp_dir)
159+
supported_extensions.update(output.split(' '))
163160
except subprocess.CalledProcessError:
164161
# Note: LTS driver does not support ocloc query CL_DEVICE_EXTENSIONS.
165162
pass
163+
164+
ocloc_dev_prop = {}
165+
ocloc_dev_prop[
166+
'has_subgroup_matrix_multiply_accumulate'] = 'cl_intel_subgroup_matrix_multiply_accumulate' in supported_extensions
167+
ocloc_dev_prop[
168+
'has_subgroup_matrix_multiply_accumulate_tensor_float32'] = 'cl_intel_subgroup_matrix_multiply_accumulate_tensor_float32' in supported_extensions
169+
ocloc_dev_prop['has_subgroup_2d_block_io'] = 'cl_intel_subgroup_2d_block_io' in supported_extensions
170+
ocloc_dev_prop['has_bfloat16_conversions'] = 'cl_intel_bfloat16_conversions' in supported_extensions
171+
self.device_props[self.device_arch] = ocloc_dev_prop
172+
dev_prop.update(ocloc_dev_prop)
173+
166174
return dev_prop
167175

168176
def parse_options(self, opts) -> Any:

0 commit comments

Comments
 (0)