@@ -140,29 +140,37 @@ def parse_target(self, tgt_prop) -> dict:
140140 dev_prop ['has_subgroup_2d_block_io' ] = tgt_prop .get ('has_subgroup_2d_block_io' , False )
141141 dev_prop ['has_bfloat16_conversions' ] = tgt_prop .get ('has_bfloat16_conversions' , True )
142142
143- if self .device_arch and shutil .which ('ocloc' ):
144- if self .device_arch in self .device_props :
145- dev_prop .update (self .device_props [self .device_arch ])
146- return dev_prop
143+ if not self .device_arch :
144+ return dev_prop
145+
146+ if self .device_arch in self .device_props :
147+ dev_prop .update (self .device_props [self .device_arch ])
148+ return dev_prop
149+
150+ supported_extensions = set ()
151+
152+ if knobs .intel .device_extensions :
153+ supported_extensions .update (knobs .intel .device_extensions .split (' ' ))
154+ elif shutil .which ('ocloc' ):
147155 try :
148- ocloc_cmd = ['ocloc' , 'query' , 'CL_DEVICE_EXTENSIONS' , '-device' , self .device_arch ]
156+ cmd = ['ocloc' , 'query' , 'CL_DEVICE_EXTENSIONS' , '-device' , self .device_arch ]
149157 with tempfile .TemporaryDirectory () as temp_dir :
150- output = subprocess .check_output (ocloc_cmd , text = True , cwd = temp_dir )
151- supported_extensions = set ()
152- for extension in output .split (' ' ):
153- supported_extensions .add (extension )
154- ocloc_dev_prop = {}
155- ocloc_dev_prop [
156- 'has_subgroup_matrix_multiply_accumulate' ] = 'cl_intel_subgroup_matrix_multiply_accumulate' in supported_extensions
157- ocloc_dev_prop [
158- 'has_subgroup_matrix_multiply_accumulate_tensor_float32' ] = 'cl_intel_subgroup_matrix_multiply_accumulate_tensor_float32' in supported_extensions
159- ocloc_dev_prop ['has_subgroup_2d_block_io' ] = 'cl_intel_subgroup_2d_block_io' in supported_extensions
160- ocloc_dev_prop ['has_bfloat16_conversions' ] = 'cl_intel_bfloat16_conversions' in supported_extensions
161- self .device_props [self .device_arch ] = ocloc_dev_prop
162- dev_prop .update (ocloc_dev_prop )
158+ output = subprocess .check_output (cmd , text = True , cwd = temp_dir )
159+ supported_extensions .update (output .split (' ' ))
163160 except subprocess .CalledProcessError :
164161 # Note: LTS driver does not support ocloc query CL_DEVICE_EXTENSIONS.
165162 pass
163+
164+ ocloc_dev_prop = {}
165+ ocloc_dev_prop [
166+ 'has_subgroup_matrix_multiply_accumulate' ] = 'cl_intel_subgroup_matrix_multiply_accumulate' in supported_extensions
167+ ocloc_dev_prop [
168+ 'has_subgroup_matrix_multiply_accumulate_tensor_float32' ] = 'cl_intel_subgroup_matrix_multiply_accumulate_tensor_float32' in supported_extensions
169+ ocloc_dev_prop ['has_subgroup_2d_block_io' ] = 'cl_intel_subgroup_2d_block_io' in supported_extensions
170+ ocloc_dev_prop ['has_bfloat16_conversions' ] = 'cl_intel_bfloat16_conversions' in supported_extensions
171+ self .device_props [self .device_arch ] = ocloc_dev_prop
172+ dev_prop .update (ocloc_dev_prop )
173+
166174 return dev_prop
167175
168176 def parse_options (self , opts ) -> Any :
0 commit comments