11from triton .backends .compiler import BaseBackend
22from triton ._C .libtriton import ir , passes , llvm , intel
3+ from triton .backends .intel .driver import compile_module_from_src
34
45from dataclasses import dataclass
56import functools
@@ -96,6 +97,7 @@ def get_ops_per_channel(lhs_type, rhs_type):
9697
9798
9899class XPUBackend (BaseBackend ):
100+ device_props : dict = {}
99101
100102 # AdvancedPath pass pipeline for kernels using block pointers.
101103 class AdvancedPath :
@@ -127,6 +129,9 @@ def __init__(self, target: tuple) -> None:
127129 super ().__init__ (target )
128130 if not isinstance (target .arch , dict ):
129131 raise TypeError ("target.arch is not a dict" )
132+ dirname = os .path .dirname (os .path .realpath (__file__ ))
133+ mod = compile_module_from_src (Path (os .path .join (dirname , "arch_parser.c" )).read_text (), "arch_utils" )
134+ self .parse_device_arch = mod .parse_device_arch
130135 self .properties = self .parse_target (target .arch )
131136 self .binary_ext = "spv"
132137
@@ -142,30 +147,37 @@ def parse_target(self, tgt_prop) -> dict:
142147 dev_prop ['max_num_sub_groups' ] = tgt_prop .get ('max_num_sub_groups' , None )
143148 dev_prop ['sub_group_sizes' ] = tgt_prop .get ('sub_group_sizes' , None )
144149 dev_prop ['has_fp64' ] = tgt_prop .get ('has_fp64' , None )
145- if os .getenv ("TRITON_INTEL_QUERY_DEVICE_EXTENSIONS" , "0" ) == "1" :
150+ dev_prop ['has_subgroup_matrix_multiply_accumulate' ] = tgt_prop .get ('has_subgroup_matrix_multiply_accumulate' ,
151+ False )
152+ dev_prop ['has_subgroup_matrix_multiply_accumulate_tensor_float32' ] = tgt_prop .get (
153+ 'has_subgroup_matrix_multiply_accumulate_tensor_float32' , False )
154+ dev_prop ['has_subgroup_2d_block_io' ] = tgt_prop .get ('has_subgroup_2d_block_io' , False )
155+ dev_prop ['has_bfloat16_conversions' ] = tgt_prop .get ('has_bfloat16_conversions' , True )
156+
157+ device_arch = self .parse_device_arch (tgt_prop .get ('architecture' , 0 ))
158+ if device_arch :
159+ if device_arch in self .device_props :
160+ dev_prop .update (self .device_props [device_arch ])
161+ return dev_prop
146162 try :
147- # FIXME: Add support for other devices.
148- ocloc_cmd = ['ocloc' , 'query' , 'CL_DEVICE_EXTENSIONS' , '-device' , 'pvc' ]
149- result = subprocess .run (ocloc_cmd , check = True , capture_output = True , text = True )
150- output = result .stdout
163+ ocloc_cmd = ['ocloc' , 'query' , 'CL_DEVICE_EXTENSIONS' , '-device' , device_arch ]
164+ with tempfile .TemporaryDirectory () as temp_dir :
165+ output = subprocess .check_output (ocloc_cmd , text = True , cwd = temp_dir )
151166 supported_extensions = set ()
152167 for extension in output .split (' ' ):
153168 supported_extensions .add (extension )
154- dev_prop [
169+ ocloc_dev_prop = {}
170+ ocloc_dev_prop [
155171 'has_subgroup_matrix_multiply_accumulate' ] = 'cl_intel_subgroup_matrix_multiply_accumulate' in supported_extensions
156- dev_prop [
172+ ocloc_dev_prop [
157173 'has_subgroup_matrix_multiply_accumulate_tensor_float32' ] = 'cl_intel_subgroup_matrix_multiply_accumulate_tensor_float32' in supported_extensions
158- dev_prop ['has_subgroup_2d_block_io' ] = 'cl_intel_subgroup_2d_block_io' in supported_extensions
159- dev_prop ['has_bfloat16_conversions' ] = 'cl_intel_bfloat16_conversions' in supported_extensions
160- except subprocess .CalledProcessError as e :
161- raise RuntimeError (f'`ocloc` failed with error code { e .returncode } ' )
162- else :
163- dev_prop ['has_subgroup_matrix_multiply_accumulate' ] = tgt_prop .get (
164- 'has_subgroup_matrix_multiply_accumulate' , False )
165- dev_prop ['has_subgroup_matrix_multiply_accumulate_tensor_float32' ] = tgt_prop .get (
166- 'has_subgroup_matrix_multiply_accumulate_tensor_float32' , False )
167- dev_prop ['has_subgroup_2d_block_io' ] = tgt_prop .get ('has_subgroup_2d_block_io' , False )
168- dev_prop ['has_bfloat16_conversions' ] = tgt_prop .get ('has_bfloat16_conversions' , True )
174+ ocloc_dev_prop ['has_subgroup_2d_block_io' ] = 'cl_intel_subgroup_2d_block_io' in supported_extensions
175+ ocloc_dev_prop ['has_bfloat16_conversions' ] = 'cl_intel_bfloat16_conversions' in supported_extensions
176+ self .device_props [device_arch ] = ocloc_dev_prop
177+ dev_prop .update (ocloc_dev_prop )
178+ except subprocess .CalledProcessError :
179+ # Note: LTS driver does not support ocloc query CL_DEVICE_EXTENSIONS.
180+ pass
169181 return dev_prop
170182
171183 def parse_options (self , opts ) -> Any :
0 commit comments