Skip to content

Commit 106c399

Browse files
Query device extensions using ocloc (#2466)
Before this PR, Triton relies on PyTorch to query SYCL API to determine if a device extension is supported. To lower the dependency to PyTorch, this PR uses `ocloc` to query the list of supported device extensions. It is currently under an env var `TRITON_INTEL_QUERY_DEVICE_EXTENSIONS`, as the `ocloc` command requires specifying the device type, and this PR hard coded it to PVC, so it cannot yet be enabled by default. --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent d3a8eb0 commit 106c399

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

third_party/intel/backend/compiler.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,30 @@ def parse_target(self, tgt_prop) -> dict:
140140
dev_prop['max_num_sub_groups'] = tgt_prop.get('max_num_sub_groups', None)
141141
dev_prop['sub_group_sizes'] = tgt_prop.get('sub_group_sizes', None)
142142
dev_prop['has_fp64'] = tgt_prop.get('has_fp64', None)
143-
dev_prop['has_subgroup_matrix_multiply_accumulate'] = tgt_prop.get('has_subgroup_matrix_multiply_accumulate',
144-
False)
145-
dev_prop['has_subgroup_matrix_multiply_accumulate_tensor_float32'] = tgt_prop.get(
146-
'has_subgroup_matrix_multiply_accumulate_tensor_float32', False)
147-
dev_prop['has_subgroup_2d_block_io'] = tgt_prop.get('has_subgroup_2d_block_io', False)
148-
dev_prop['has_bfloat16_conversions'] = tgt_prop.get('has_bfloat16_conversions', True)
143+
if os.getenv("TRITON_INTEL_QUERY_DEVICE_EXTENSIONS", "0") == "1":
144+
try:
145+
# FIXME: Add support for other devices.
146+
ocloc_cmd = ['ocloc', 'query', 'CL_DEVICE_EXTENSIONS', '-device', 'pvc']
147+
result = subprocess.run(ocloc_cmd, check=True, capture_output=True, text=True)
148+
output = result.stdout
149+
supported_extensions = set()
150+
for extension in output.split(' '):
151+
supported_extensions.add(extension)
152+
dev_prop[
153+
'has_subgroup_matrix_multiply_accumulate'] = 'cl_intel_subgroup_matrix_multiply_accumulate' in supported_extensions
154+
dev_prop[
155+
'has_subgroup_matrix_multiply_accumulate_tensor_float32'] = 'cl_intel_subgroup_matrix_multiply_accumulate_tensor_float32' in supported_extensions
156+
dev_prop['has_subgroup_2d_block_io'] = 'cl_intel_subgroup_2d_block_io' in supported_extensions
157+
dev_prop['has_bfloat16_conversions'] = 'cl_intel_bfloat16_conversions' in supported_extensions
158+
except subprocess.CalledProcessError as e:
159+
raise RuntimeError(f'`ocloc` failed with error code {e.returncode}')
160+
else:
161+
dev_prop['has_subgroup_matrix_multiply_accumulate'] = tgt_prop.get(
162+
'has_subgroup_matrix_multiply_accumulate', False)
163+
dev_prop['has_subgroup_matrix_multiply_accumulate_tensor_float32'] = tgt_prop.get(
164+
'has_subgroup_matrix_multiply_accumulate_tensor_float32', False)
165+
dev_prop['has_subgroup_2d_block_io'] = tgt_prop.get('has_subgroup_2d_block_io', False)
166+
dev_prop['has_bfloat16_conversions'] = tgt_prop.get('has_bfloat16_conversions', True)
149167
return dev_prop
150168

151169
def parse_options(self, opts) -> Any:

0 commit comments

Comments
 (0)