171 changes: 101 additions & 70 deletions third_party/intel/backend/compiler.py
@@ -1,4 +1,4 @@
from triton.backends.compiler import BaseBackend, Language
from triton.backends.compiler import BaseBackend, GPUTarget, Language
from triton._C.libtriton import ir, passes, llvm, intel
from triton.backends.intel.driver import compile_module_from_src
from triton.backends.intel.track import track
@@ -15,6 +15,11 @@
import subprocess
from pathlib import Path

try: # XPUBackend allows metaclasses injection
from .meta import XPUBackendMeta
except ImportError:
XPUBackendMeta = type(BaseBackend)
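
The optional .meta module is not part of this diff; the try/except above only requires that, when present, it export an XPUBackendMeta compatible with type(BaseBackend). A minimal sketch of what such a hypothetical module could look like (the file name and the hook body are assumptions, not code from this PR):

# Hypothetical third_party/intel/backend/meta.py -- illustrative only.
from triton.backends.compiler import BaseBackend


class XPUBackendMeta(type(BaseBackend)):
    """Injected as the metaclass of XPUBackend when this module is importable."""

    def __new__(mcls, name, bases, namespace):
        cls = super().__new__(mcls, name, bases, namespace)
        # Example hook: out-of-tree tooling could register or patch backend
        # classes here as they are defined.
        return cls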


@dataclass
class XPUOptions:
@@ -63,40 +68,42 @@ def hash(self):
return hashlib.sha256(key.encode("utf-8")).hexdigest()


def min_dot_size(device_props: dict):
# (M, N, K)
# M: repeatCount. 1,2,4,8
# N: executionSize. 16 for PVC, 8 for ATS
# K: systolicDepth x opsPerChan. systolicDepth must be 8
repeat_count = 1
sdepth = 8
exec_size = min(device_props["sub_group_sizes"])

def get_ops_per_channel(lhs_type, rhs_type):
l_bitwidth = lhs_type.scalar.primitive_bitwidth
r_bitwidth = rhs_type.scalar.primitive_bitwidth
max_ops_per_chan = 32 / max(l_bitwidth, r_bitwidth)
return min(8, max_ops_per_chan)

return lambda lhs_type, rhs_type: (repeat_count, exec_size, sdepth * get_ops_per_channel(lhs_type, rhs_type))


class XPUBackend(BaseBackend):
class XPUBackend(BaseBackend, metaclass=XPUBackendMeta):
arch_to_impl = {} # Architecture id to backend implementation class mapping
binary_ext = "spv"
target_arch = "spir64"
device_props: dict = {}
instrumentation = None

@staticmethod
def supports_target(target: tuple):
def supports_target(target: GPUTarget):
return target.backend == 'xpu'

def __init__(self, target: tuple) -> None:
super().__init__(target)
def __new__(cls, target: GPUTarget):
if not isinstance(target.arch, dict):
raise TypeError("target.arch is not a dict")
dirname = os.path.dirname(os.path.realpath(__file__))
mod = compile_module_from_src(src=Path(os.path.join(dirname, "arch_parser.c")).read_text(), name="arch_utils")
self.device_arch = knobs.intel.device_arch or mod.parse_device_arch(target.arch.get('architecture', 0))
if cls is not XPUBackend:
return super().__new__(cls)
arch = target.arch.get("architecture", 0)
if (impl := cls.arch_to_impl.get(arch, None)) is None:
# Try to find an arch-specific implementation in the .arch.<name> submodule.
if not (dev_arch := knobs.intel.device_arch):
dirname = os.path.dirname(os.path.realpath(__file__))
parser = compile_module_from_src(src=Path(os.path.join(dirname, "arch_parser.c")).read_text(),
name="arch_utils")
dev_arch = parser.parse_device_arch(target.arch.get('architecture', 0))
mod_name = f"{__package__}.arch.{dev_arch}"
try:
impl = __import__(mod_name, fromlist=["XPUBackendImpl"]).XPUBackendImpl
except ImportError:
impl = type(f"{mod_name}.XPUBackendImpl", (cls, ), {})
impl.device_arch = dev_arch
cls.arch_to_impl[arch] = impl
return super().__new__(impl)
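
For illustration: __new__ above imports {__package__}.arch.<device_arch> and expects it to expose an XPUBackendImpl class, falling back to an empty subclass of the backend when the import fails. A minimal sketch of what such an arch-specific submodule could look like (the bmg file name and the overridden property are assumptions; only the XPUBackendImpl symbol is required by the code above):

# Hypothetical third_party/intel/backend/arch/bmg.py -- illustrative only.
from ..compiler import XPUBackend


class XPUBackendImpl(XPUBackend):
    # parse_target() merges self.device_props[self.device_arch] into the generic
    # property dict, so per-arch overrides can be declared here; the key must
    # match the parsed device_arch name.
    device_props = {"bmg": {"has_subgroup_2d_block_io": True}}
    # device_arch itself is assigned by XPUBackend.__new__ after the import,
    # so the submodule does not need to set it.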

def __init__(self, target: GPUTarget) -> None:
super().__init__(target)
self.properties = self.parse_target(target.arch)
self.binary_ext = "spv"

def get_target_name(self, options) -> str:
return f"xpu:{self.device_arch}"
@@ -120,21 +127,43 @@ def parse_target(self, tgt_prop) -> dict:
dev_prop['has_subgroup_2d_block_io'] = tgt_prop.get('has_subgroup_2d_block_io', False)
dev_prop['has_bfloat16_conversions'] = tgt_prop.get('has_bfloat16_conversions', True)

if self.device_arch in self.device_props:
dev_prop.update(self.device_props[self.device_arch])
return dev_prop

return dev_prop

def parse_options(self, opts) -> Any:
args = {k: opts[k] for k in XPUOptions.__dataclass_fields__.keys() if k in opts}
args = {k: v for k, v in opts.items() if k in XPUOptions.__dataclass_fields__}
args["allow_fp8e4nv"] = True
return XPUOptions(**args)

def pack_metadata(self, metadata):
return metadata

@staticmethod
def min_dot_size(device_props: dict):
# (M, N, K)
# M: repeatCount. 1,2,4,8
# N: executionSize. 16 for PVC, 8 for ATS
# K: systolicDepth x opsPerChan. systolicDepth must be 8
repeat_count = 1
sdepth = 8
exec_size = min(device_props["sub_group_sizes"])

def get_ops_per_channel(lhs_type, rhs_type):
l_bitwidth = lhs_type.scalar.primitive_bitwidth
r_bitwidth = rhs_type.scalar.primitive_bitwidth
max_ops_per_chan = 32 / max(l_bitwidth, r_bitwidth)
return min(8, max_ops_per_chan)

return lambda lhs_type, rhs_type: (repeat_count, exec_size, sdepth * get_ops_per_channel(lhs_type, rhs_type))
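
A quick worked example of the tuple this helper produces, assuming a PVC-like device whose smallest sub-group size is 16; the fake types below only model the scalar.primitive_bitwidth field the closure reads:

from types import SimpleNamespace

def fake_type(bitwidth):
    # Stand-in for a Triton tensor type; only the field used above is modeled.
    return SimpleNamespace(scalar=SimpleNamespace(primitive_bitwidth=bitwidth))

fn = XPUBackend.min_dot_size({"sub_group_sizes": [16, 32]})  # exec_size = 16
fn(fake_type(16), fake_type(16))  # (1, 16, 16.0): fp16 -> ops_per_chan = min(8, 32/16) = 2
fn(fake_type(8), fake_type(8))    # (1, 16, 32.0): int8 -> ops_per_chan = min(8, 32/8) = 4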

def get_codegen_implementation(self, options):
from triton.language.extra.intel import convert_custom_float8
codegen_fns = {}
codegen_fns["convert_custom_types"] = convert_custom_float8
codegen_fns["min_dot_size"] = min_dot_size(self.properties)
codegen_fns["min_dot_size"] = self.min_dot_size(self.properties)
return codegen_fns

def get_module_map(self) -> Dict[str, ModuleType]:
@@ -143,8 +172,8 @@ def get_module_map(self) -> Dict[str, ModuleType]:

def load_dialects(self, ctx):
intel.load_dialects(ctx)
if XPUBackend.instrumentation:
XPUBackend.instrumentation.load_dialects(ctx)
if self.instrumentation:
self.instrumentation.load_dialects(ctx)

@staticmethod
def validate_options(opt, properties):
@@ -158,20 +187,15 @@ def validate_options(opt, properties):
f"num_warps={opt.num_warps} is unsupported for the target (limit is {properties['max_num_sub_groups']})"
)

@staticmethod
def annotate_module(mod, properties, opt, target_arch):
@classmethod
def annotate_module(cls, module_opts, properties, opt):
# Annotate module with information required by subsequent transformations.
pm = ir.pass_manager(mod.context)
pm.enable_debug()
module_opts = intel.passes.ttgpuir.AnnotateModuleOptions()
module_opts.min_sg_size = min(properties["sub_group_sizes"])
module_opts.support_sg_2d_block = properties["has_subgroup_2d_block_io"]
module_opts.support_dpas = properties["has_subgroup_matrix_multiply_accumulate"]
module_opts.support_bf16_conversion = properties["has_bfloat16_conversions"]
module_opts.threads_per_warp = opt.warp_size
module_opts.target_arch = target_arch
intel.passes.ttgpuir.add_triton_annotate_module(pm, module_opts)
pm.run(mod, 'annotate_module')
module_opts.target_arch = cls.target_arch

@staticmethod
def get_split_barrier_scope(opt):
@@ -182,11 +206,16 @@ def get_split_barrier_scope(opt):
split_barriers_scope = intel.SplitBarrierScope.Subgroup
return split_barriers_scope

@staticmethod
@track
def make_ttir(mod, metadata, opt):
pm = ir.pass_manager(mod.context)
@classmethod
def create_pass_manager(cls, context):
Contributor: what's the benefit of create_pass_manager? is it going to be different for other arch?

Contributor Author: In the initial version it also added passes to the pm - 65c9f8a#diff-5ec0c116ff4361ab1b5db32636583aaf2da599bda46558aab877dc5670a77a39R206 . I've removed this functionality, but kept the method to save on duplication of 2 lines of code :)

Contributor: IMO it is ok not to save 2 lines of code duplication in exchange for easy comparison with other backends.

pm = ir.pass_manager(context)
pm.enable_debug()
return pm
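
As the thread above notes, create_pass_manager only centralizes the two-line pass-manager setup; the practical benefit is that every pipeline in this class now builds its pass manager through cls.create_pass_manager(), giving arch-specific subclasses a single override point. A hedged sketch of such a hypothetical override (the subclass name and the extra pass are illustrative assumptions, not part of this PR):

# Hypothetical arch-specific subclass -- illustrative only.
class XPUBackendImpl(XPUBackend):

    @classmethod
    def create_pass_manager(cls, context):
        pm = super().create_pass_manager(context)
        # Example customization: seed every pipeline with a canonicalizer pass.
        passes.common.add_canonicalizer(pm)
        return pm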

@classmethod
@track
def make_ttir(cls, mod, metadata, opt):
pm = cls.create_pass_manager(mod.context)
passes.common.add_inliner(pm)
intel.passes.ttir.add_convert_tdesc_to_block_pointer(pm)
passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
@@ -203,24 +232,27 @@ def make_ttir(mod, metadata, opt):
pm.run(mod, 'make_ttir')
return mod

@staticmethod
@classmethod
@track
def make_ttgir(mod, metadata, opt, properties):
def make_ttgir(cls, mod, metadata, opt, properties):
cluster_info = intel.ClusterInfo()
if opt.cluster_dims is not None:
cluster_info.clusterDimX = opt.cluster_dims[0]
cluster_info.clusterDimY = opt.cluster_dims[1]
cluster_info.clusterDimZ = opt.cluster_dims[2]

# Annotate module with information required by subsequent transformations.
XPUBackend.annotate_module(mod, properties, opt, "spir64")
pm = cls.create_pass_manager(mod.context)
module_opts = intel.passes.ttgpuir.AnnotateModuleOptions()
cls.annotate_module(module_opts, properties, opt)
intel.passes.ttgpuir.add_triton_annotate_module(pm, module_opts)
pm.run(mod, 'annotate_module')

# Overwrite the warp_size option with the module annotation.
opt.warp_size = intel.get_threads_per_warp(mod)
XPUBackend.validate_options(opt, properties)
cls.validate_options(opt, properties)

pm = ir.pass_manager(mod.context)
pm.enable_debug()
pm = cls.create_pass_manager(mod.context)
passes.ttir.add_convert_to_ttgpuir(pm, "xpu", opt.num_warps, opt.warp_size, opt.num_ctas)
# optimize TTGIR
intel.passes.ttgpuir.add_coalesce(pm)
@@ -277,22 +309,26 @@ def gluon_to_ttgir(self, src, metadata, options):
metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
return mod

@staticmethod
@classmethod
def optimize_llvm_mod(cls, llvm_mod, options):
intel.set_spv_target_triple(llvm_mod)
with track("optimize_module") as tr:
intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, tr.callback("passes"))

@classmethod
@track
def make_llir(src, metadata, options):
def make_llir(cls, src, metadata, options):
mod = src
# TritonGPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
pm.enable_debug()

pm = cls.create_pass_manager(mod.context)
passes.convert.add_scf_to_cf(pm)
passes.gluon.add_inliner(pm)
passes.convert.add_index_to_llvmir(pm)
intel.passes.ttgpuir.add_allocate_shared_memory(pm)
passes.ttgpuir.add_allocate_global_scratch_memory(pm)
# instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
if XPUBackend.instrumentation:
XPUBackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
if cls.instrumentation:
cls.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
intel.passes.ttgpuir.add_to_llvmir(pm)
intel.passes.ttgpuir.add_gen_to_llvm(pm)
passes.common.add_canonicalizer(pm)
@@ -306,15 +342,14 @@ def make_llir(src, metadata, options):
if not knobs.compilation.disable_line_info and not knobs.compilation.dump_ir_extract_di_local_variables:
passes.llvmir.add_di_scope(pm)

if XPUBackend.instrumentation:
XPUBackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
if cls.instrumentation:
cls.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
pm.run(mod, 'make_llir')

if knobs.compilation.dump_ir_extract_di_local_variables:
# comments below on why separate it
if not knobs.compilation.disable_line_info:
pm = ir.pass_manager(mod.context)
pm.enable_debug()
pm = cls.create_pass_manager(mod.context)
passes.llvmir.add_di_scope(pm)
pm.run(mod, 'make_llir.disable_line_info')

@@ -323,24 +358,20 @@ def make_llir(src, metadata, options):
# pass and add_di_scope has to be run separately, otherwise if we
# put them into previous pipline, it trigger a segmentfault without
# any error message; could be due to a bug in mlir or pybind11
pm = ir.pass_manager(mod.context)
pm.enable_debug()
pm = cls.create_pass_manager(mod.context)
passes.llvmir.add_di_local_variable(pm)
pm.run(mod, 'make_llir.dump_ir_extract_di_local_variables')

# LLVM-IR (MLIR) -> LLVM-IR (LLVM)
llvm.init_targets()
context = llvm.context()
llvm_mod = llvm.to_module(mod, context)
intel.set_spv_target_triple(llvm_mod)
intel.set_fast_math(llvm_mod)
if options.extern_libs:
paths = [path for (name, path) in options.extern_libs]
llvm.link_extern_libs(llvm_mod, paths)

with track("optimize_module") as tr:
intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, tr.callback("passes"))

cls.optimize_llvm_mod(llvm_mod, options)
intel.post_process_llir(llvm_mod)

# Get some metadata
@@ -358,9 +389,9 @@ def make_llir(src, metadata, options):
del context
return ret

@staticmethod
@classmethod
@track
def make_spv(src, metadata, options, device_arch):
def make_spv(cls, src, metadata, options):
spirv, name = intel.translate_to_spirv(src)
metadata["name"] = name
if options.grf_mode == 'small':
@@ -393,7 +424,7 @@ def make_spv(src, metadata, options, device_arch):
fbin = fsrc.name + '.o'

ocloc_cmd = [
'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch,
'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', cls.device_arch,
'-options', metadata["build_flags"] + shader_dump_opt
]

@@ -433,7 +464,7 @@ def add_stages(self, stages, options, language):
elif language == Language.GLUON:
stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options)
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options, self.device_arch)
stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options)
if knobs.runtime.add_stages_inspection_hook is not None:
knobs.runtime.add_stages_inspection_hook(self, stages, options, language, None)
