diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 269682f3f2..6d23b8696c 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -1,4 +1,4 @@ -from triton.backends.compiler import BaseBackend, Language +from triton.backends.compiler import BaseBackend, GPUTarget, Language from triton._C.libtriton import ir, passes, llvm, intel from triton.backends.intel.driver import compile_module_from_src from triton.backends.intel.track import track @@ -15,6 +15,11 @@ import subprocess from pathlib import Path +try: # XPUBackend allows metaclasses injection + from .meta import XPUBackendMeta +except ImportError: + XPUBackendMeta = type(BaseBackend) + @dataclass class XPUOptions: @@ -63,40 +68,42 @@ def hash(self): return hashlib.sha256(key.encode("utf-8")).hexdigest() -def min_dot_size(device_props: dict): - # (M, N, K) - # M: repeatCount. 1,2,4,8 - # N: executionSize. 16 for PVC, 8 for ATS - # K: systolicDepth x opsPerChan. systolicDepth must be 8 - repeat_count = 1 - sdepth = 8 - exec_size = min(device_props["sub_group_sizes"]) - - def get_ops_per_channel(lhs_type, rhs_type): - l_bitwidth = lhs_type.scalar.primitive_bitwidth - r_bitwidth = rhs_type.scalar.primitive_bitwidth - max_ops_per_chan = 32 / max(l_bitwidth, r_bitwidth) - return min(8, max_ops_per_chan) - - return lambda lhs_type, rhs_type: (repeat_count, exec_size, sdepth * get_ops_per_channel(lhs_type, rhs_type)) - - -class XPUBackend(BaseBackend): +class XPUBackend(BaseBackend, metaclass=XPUBackendMeta): + arch_to_impl = {} # Architecture id to backend implementation class mapping + binary_ext = "spv" + target_arch = "spir64" + device_props: dict = {} instrumentation = None @staticmethod - def supports_target(target: tuple): + def supports_target(target: GPUTarget): return target.backend == 'xpu' - def __init__(self, target: tuple) -> None: - super().__init__(target) + def __new__(cls, target: GPUTarget): if not isinstance(target.arch, dict): raise TypeError("target.arch is not a dict") - dirname = os.path.dirname(os.path.realpath(__file__)) - mod = compile_module_from_src(src=Path(os.path.join(dirname, "arch_parser.c")).read_text(), name="arch_utils") - self.device_arch = knobs.intel.device_arch or mod.parse_device_arch(target.arch.get('architecture', 0)) + if cls is not XPUBackend: + return super().__new__(cls) + arch = target.arch.get("architecture", 0) + if (impl := cls.arch_to_impl.get(arch, None)) is None: + # Try to find an arch-specific implementation in the .arch. submodule. 
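+            # (The lookup below runs once per architecture id: the resolved
+            # implementation class, or a plain subclass of this backend when no
+            # arch-specific module exists, is cached in cls.arch_to_impl.)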
+ if not (dev_arch := knobs.intel.device_arch): + dirname = os.path.dirname(os.path.realpath(__file__)) + parser = compile_module_from_src(src=Path(os.path.join(dirname, "arch_parser.c")).read_text(), + name="arch_utils") + dev_arch = parser.parse_device_arch(target.arch.get('architecture', 0)) + mod_name = f"{__package__}.arch.{dev_arch}" + try: + impl = __import__(mod_name, fromlist=["XPUBackendImpl"]).XPUBackendImpl + except ImportError: + impl = type(f"{mod_name}.XPUBackendImpl", (cls, ), {}) + impl.device_arch = dev_arch + cls.arch_to_impl[arch] = impl + return super().__new__(impl) + + def __init__(self, target: GPUTarget) -> None: + super().__init__(target) self.properties = self.parse_target(target.arch) - self.binary_ext = "spv" def get_target_name(self, options) -> str: return f"xpu:{self.device_arch}" @@ -120,21 +127,43 @@ def parse_target(self, tgt_prop) -> dict: dev_prop['has_subgroup_2d_block_io'] = tgt_prop.get('has_subgroup_2d_block_io', False) dev_prop['has_bfloat16_conversions'] = tgt_prop.get('has_bfloat16_conversions', True) + if self.device_arch in self.device_props: + dev_prop.update(self.device_props[self.device_arch]) + return dev_prop + return dev_prop def parse_options(self, opts) -> Any: - args = {k: opts[k] for k in XPUOptions.__dataclass_fields__.keys() if k in opts} + args = {k: v for k, v in opts.items() if k in XPUOptions.__dataclass_fields__} args["allow_fp8e4nv"] = True return XPUOptions(**args) def pack_metadata(self, metadata): return metadata + @staticmethod + def min_dot_size(device_props: dict): + # (M, N, K) + # M: repeatCount. 1,2,4,8 + # N: executionSize. 16 for PVC, 8 for ATS + # K: systolicDepth x opsPerChan. systolicDepth must be 8 + repeat_count = 1 + sdepth = 8 + exec_size = min(device_props["sub_group_sizes"]) + + def get_ops_per_channel(lhs_type, rhs_type): + l_bitwidth = lhs_type.scalar.primitive_bitwidth + r_bitwidth = rhs_type.scalar.primitive_bitwidth + max_ops_per_chan = 32 / max(l_bitwidth, r_bitwidth) + return min(8, max_ops_per_chan) + + return lambda lhs_type, rhs_type: (repeat_count, exec_size, sdepth * get_ops_per_channel(lhs_type, rhs_type)) + def get_codegen_implementation(self, options): from triton.language.extra.intel import convert_custom_float8 codegen_fns = {} codegen_fns["convert_custom_types"] = convert_custom_float8 - codegen_fns["min_dot_size"] = min_dot_size(self.properties) + codegen_fns["min_dot_size"] = self.min_dot_size(self.properties) return codegen_fns def get_module_map(self) -> Dict[str, ModuleType]: @@ -143,8 +172,8 @@ def get_module_map(self) -> Dict[str, ModuleType]: def load_dialects(self, ctx): intel.load_dialects(ctx) - if XPUBackend.instrumentation: - XPUBackend.instrumentation.load_dialects(ctx) + if self.instrumentation: + self.instrumentation.load_dialects(ctx) @staticmethod def validate_options(opt, properties): @@ -158,20 +187,15 @@ def validate_options(opt, properties): f"num_warps={opt.num_warps} is unsupported for the target (limit is {properties['max_num_sub_groups']})" ) - @staticmethod - def annotate_module(mod, properties, opt, target_arch): + @classmethod + def annotate_module(cls, module_opts, properties, opt): # Annotate module with information required by subsequent transformations. 
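+        # The pass-manager setup and the add_triton_annotate_module run now
+        # live in make_ttgir; this hook only fills in the options, so
+        # arch-specific subclasses can adjust them before the pass executes.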
- pm = ir.pass_manager(mod.context) - pm.enable_debug() - module_opts = intel.passes.ttgpuir.AnnotateModuleOptions() module_opts.min_sg_size = min(properties["sub_group_sizes"]) module_opts.support_sg_2d_block = properties["has_subgroup_2d_block_io"] module_opts.support_dpas = properties["has_subgroup_matrix_multiply_accumulate"] module_opts.support_bf16_conversion = properties["has_bfloat16_conversions"] module_opts.threads_per_warp = opt.warp_size - module_opts.target_arch = target_arch - intel.passes.ttgpuir.add_triton_annotate_module(pm, module_opts) - pm.run(mod, 'annotate_module') + module_opts.target_arch = cls.target_arch @staticmethod def get_split_barrier_scope(opt): @@ -182,9 +206,9 @@ def get_split_barrier_scope(opt): split_barriers_scope = intel.SplitBarrierScope.Subgroup return split_barriers_scope - @staticmethod + @classmethod @track - def make_ttir(mod, metadata, opt): + def make_ttir(cls, mod, metadata, opt): pm = ir.pass_manager(mod.context) pm.enable_debug() passes.common.add_inliner(pm) @@ -203,9 +227,9 @@ def make_ttir(mod, metadata, opt): pm.run(mod, 'make_ttir') return mod - @staticmethod + @classmethod @track - def make_ttgir(mod, metadata, opt, properties): + def make_ttgir(cls, mod, metadata, opt, properties): cluster_info = intel.ClusterInfo() if opt.cluster_dims is not None: cluster_info.clusterDimX = opt.cluster_dims[0] @@ -213,11 +237,16 @@ def make_ttgir(mod, metadata, opt, properties): cluster_info.clusterDimZ = opt.cluster_dims[2] # Annotate module with information required by subsequent transformations. - XPUBackend.annotate_module(mod, properties, opt, "spir64") + pm = ir.pass_manager(mod.context) + pm.enable_debug() + module_opts = intel.passes.ttgpuir.AnnotateModuleOptions() + cls.annotate_module(module_opts, properties, opt) + intel.passes.ttgpuir.add_triton_annotate_module(pm, module_opts) + pm.run(mod, 'annotate_module') # Overwrite the warp_size option with the module annotation. 
opt.warp_size = intel.get_threads_per_warp(mod) - XPUBackend.validate_options(opt, properties) + cls.validate_options(opt, properties) pm = ir.pass_manager(mod.context) pm.enable_debug() @@ -277,9 +306,15 @@ def gluon_to_ttgir(self, src, metadata, options): metadata["tensordesc_meta"] = mod.get_tensordesc_metadata() return mod - @staticmethod + @classmethod + def optimize_llvm_mod(cls, llvm_mod, options): + intel.set_spv_target_triple(llvm_mod) + with track("optimize_module") as tr: + intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, tr.callback("passes")) + + @classmethod @track - def make_llir(src, metadata, options): + def make_llir(cls, src, metadata, options): mod = src # TritonGPU -> LLVM-IR (MLIR) pm = ir.pass_manager(mod.context) @@ -291,8 +326,8 @@ def make_llir(src, metadata, options): intel.passes.ttgpuir.add_allocate_shared_memory(pm) passes.ttgpuir.add_allocate_global_scratch_memory(pm) # instrumentation point here so we can override IRs above (e.g., ttir and ttgir) - if XPUBackend.instrumentation: - XPUBackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context) + if cls.instrumentation: + cls.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context) intel.passes.ttgpuir.add_to_llvmir(pm) intel.passes.ttgpuir.add_gen_to_llvm(pm) passes.common.add_canonicalizer(pm) @@ -306,8 +341,8 @@ def make_llir(src, metadata, options): if not knobs.compilation.disable_line_info and not knobs.compilation.dump_ir_extract_di_local_variables: passes.llvmir.add_di_scope(pm) - if XPUBackend.instrumentation: - XPUBackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context) + if cls.instrumentation: + cls.instrumentation.patch("llvmir_to_llvm", pm, mod.context) pm.run(mod, 'make_llir') if knobs.compilation.dump_ir_extract_di_local_variables: @@ -332,15 +367,12 @@ def make_llir(src, metadata, options): llvm.init_targets() context = llvm.context() llvm_mod = llvm.to_module(mod, context) - intel.set_spv_target_triple(llvm_mod) intel.set_fast_math(llvm_mod) if options.extern_libs: paths = [path for (name, path) in options.extern_libs] llvm.link_extern_libs(llvm_mod, paths) - with track("optimize_module") as tr: - intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, tr.callback("passes")) - + cls.optimize_llvm_mod(llvm_mod, options) intel.post_process_llir(llvm_mod) # Get some metadata @@ -358,9 +390,9 @@ def make_llir(src, metadata, options): del context return ret - @staticmethod + @classmethod @track - def make_spv(src, metadata, options, device_arch): + def make_spv(cls, src, metadata, options): spirv, name = intel.translate_to_spirv(src) metadata["name"] = name if options.grf_mode == 'small': @@ -393,7 +425,7 @@ def make_spv(src, metadata, options, device_arch): fbin = fsrc.name + '.o' ocloc_cmd = [ - 'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch, + 'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', cls.device_arch, '-options', metadata["build_flags"] + shader_dump_opt ] @@ -433,7 +465,7 @@ def add_stages(self, stages, options, language): elif language == Language.GLUON: stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options) stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options) - stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options, self.device_arch) + stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options) if knobs.runtime.add_stages_inspection_hook is not None: knobs.runtime.add_stages_inspection_hook(self, stages, 
options, language, None)
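
For illustration only, below is a minimal sketch of an arch-specific implementation module that the new `XPUBackend.__new__` dispatch could pick up. The file location, the `triton.backends.intel.compiler` import path, and the property values are assumptions, not part of the patch; only the `XPUBackendImpl` class name, the `device_props` merge performed by `parse_target`, and the overridable class hooks come from the change itself.

# Hypothetical module: triton/backends/intel/arch/<device_arch>.py
# (imported by XPUBackend.__new__ as f"{__package__}.arch.{dev_arch}").
from triton.backends.intel.compiler import XPUBackend


class XPUBackendImpl(XPUBackend):
    # Per-arch device property overrides: parse_target() merges the entry
    # keyed by this class's device_arch into the detected properties.
    # The key and value below are purely illustrative; the key would normally
    # match this module's device arch name.
    device_props = {"example_arch": {"has_subgroup_2d_block_io": True}}

    @classmethod
    def optimize_llvm_mod(cls, llvm_mod, options):
        # Arch-specific hook point; this sketch simply defers to the base
        # O3 pipeline introduced in the patch.
        super().optimize_llvm_mod(llvm_mod, options)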