@@ -140,29 +140,37 @@ def parse_target(self, tgt_prop) -> dict:
140
140
dev_prop ['has_subgroup_2d_block_io' ] = tgt_prop .get ('has_subgroup_2d_block_io' , False )
141
141
dev_prop ['has_bfloat16_conversions' ] = tgt_prop .get ('has_bfloat16_conversions' , True )
142
142
143
- if self .device_arch and shutil .which ('ocloc' ):
144
- if self .device_arch in self .device_props :
145
- dev_prop .update (self .device_props [self .device_arch ])
146
- return dev_prop
143
+ if not self .device_arch :
144
+ return dev_prop
145
+
146
+ if self .device_arch in self .device_props :
147
+ dev_prop .update (self .device_props [self .device_arch ])
148
+ return dev_prop
149
+
150
+ supported_extensions = set ()
151
+
152
+ if knobs .intel .device_extensions :
153
+ supported_extensions .update (knobs .intel .device_extensions .split (' ' ))
154
+ elif shutil .which ('ocloc' ):
147
155
try :
148
- ocloc_cmd = ['ocloc' , 'query' , 'CL_DEVICE_EXTENSIONS' , '-device' , self .device_arch ]
156
+ cmd = ['ocloc' , 'query' , 'CL_DEVICE_EXTENSIONS' , '-device' , self .device_arch ]
149
157
with tempfile .TemporaryDirectory () as temp_dir :
150
- output = subprocess .check_output (ocloc_cmd , text = True , cwd = temp_dir )
151
- supported_extensions = set ()
152
- for extension in output .split (' ' ):
153
- supported_extensions .add (extension )
154
- ocloc_dev_prop = {}
155
- ocloc_dev_prop [
156
- 'has_subgroup_matrix_multiply_accumulate' ] = 'cl_intel_subgroup_matrix_multiply_accumulate' in supported_extensions
157
- ocloc_dev_prop [
158
- 'has_subgroup_matrix_multiply_accumulate_tensor_float32' ] = 'cl_intel_subgroup_matrix_multiply_accumulate_tensor_float32' in supported_extensions
159
- ocloc_dev_prop ['has_subgroup_2d_block_io' ] = 'cl_intel_subgroup_2d_block_io' in supported_extensions
160
- ocloc_dev_prop ['has_bfloat16_conversions' ] = 'cl_intel_bfloat16_conversions' in supported_extensions
161
- self .device_props [self .device_arch ] = ocloc_dev_prop
162
- dev_prop .update (ocloc_dev_prop )
158
+ output = subprocess .check_output (cmd , text = True , cwd = temp_dir )
159
+ supported_extensions .update (output .split (' ' ))
163
160
except subprocess .CalledProcessError :
164
161
# Note: LTS driver does not support ocloc query CL_DEVICE_EXTENSIONS.
165
162
pass
163
+
164
+ ocloc_dev_prop = {}
165
+ ocloc_dev_prop [
166
+ 'has_subgroup_matrix_multiply_accumulate' ] = 'cl_intel_subgroup_matrix_multiply_accumulate' in supported_extensions
167
+ ocloc_dev_prop [
168
+ 'has_subgroup_matrix_multiply_accumulate_tensor_float32' ] = 'cl_intel_subgroup_matrix_multiply_accumulate_tensor_float32' in supported_extensions
169
+ ocloc_dev_prop ['has_subgroup_2d_block_io' ] = 'cl_intel_subgroup_2d_block_io' in supported_extensions
170
+ ocloc_dev_prop ['has_bfloat16_conversions' ] = 'cl_intel_bfloat16_conversions' in supported_extensions
171
+ self .device_props [self .device_arch ] = ocloc_dev_prop
172
+ dev_prop .update (ocloc_dev_prop )
173
+
166
174
return dev_prop
167
175
168
176
def parse_options (self , opts ) -> Any :
0 commit comments