diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 155c2824e..2928cbc14 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -155,10 +155,12 @@ def visit_Return(self, node: ast.Return) -> bool: def visit_Assign(self, node: ast.Assign) -> bool: # There couldn't be an early return + # x = ... return False def visit_AugAssign(self, node: ast.AugAssign) -> bool: # There couldn't be an early return + # x += ... return False def visit_Module(self, node: ast.Module) -> bool: @@ -168,6 +170,13 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> bool: return self._visit_stmts(node.body) def visit_If(self, node: ast.If) -> bool: + # TODO: optimize the following case in which we actually don't have + # a return when static_cond is false: + # if dynamic_cond + # if static_cond + # func_with_return + # else + # func_without_return ret = self._visit_stmts(node.body) if node.orelse: ret = ret or self._visit_stmts(node.orelse) @@ -192,6 +201,9 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n self.begin_line = begin_line - 1 self.builder.set_loc(file_name, begin_line, 0) self.builder.options = options + # dict of functions provided by the backend. Below are the list of possible functions: + # Convert custom types not natively supported on HW. + # convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None) self.builder.codegen_fns = codegen_fns self.builder.module_map = {} if module_map is None else module_map self.module = self.builder.create_module() if module is None else module @@ -474,7 +486,10 @@ def visit_AnnAssign(self, node): return self.visit_Assign(node) def visit_Assign(self, node): - # flagtree: First, do normal assignment processing + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization("anno_CodeGenerator_visit_Assign") + _names = [] if isinstance(node, ast.AnnAssign): _names += [self.visit(node.target)] @@ -498,30 +513,10 @@ def visit_Assign(self, node): not isinstance(value, native_nontensor_types): value = language.semantic.to_tensor(value, self.builder) self.set_value(name, value) - - # flagtree: After normal processing, check if we need to add hint annotation - if hasattr(node, 'lineno') and hasattr(self, 'jit_fn'): - line_num = node.lineno - # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later - function_def = self.jit_fn.parse() - line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) - flagtree_hints = line_flagtree_hints.get(line_num) - - # Check if this is a tl.load call with dot_pad_only_k hint - if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and - isinstance(node.value, ast.Call) and - isinstance(node.value.func, ast.Attribute) and - isinstance(node.value.func.value, ast.Name) and - node.value.func.value.id == 'tl' and - node.value.func.attr == 'load'): - - # Add hint annotation to the loaded tensor(s) - for name, value in zip(names, values): - if _is_triton_value(value): - # print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}") - # Create hint annotation - hint_val = self.builder.get_unit_attr() - self.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val) + + #flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + 
flagtree_backend_specialization("ext_CodeGenerator_visit_Assign_hint_anno", self, node, names, values) def visit_AugAssign(self, node): name = node.target.id @@ -828,6 +823,8 @@ def visit_While(self, node): liveins, insert_block = sr ip, last_loc = self._get_insertion_point_and_loc() + # loop body (the after region) + # loop_block = self.builder.create_block() dummy = self.builder.create_block() self.builder.set_insertion_point_to_start(dummy) self.scf_stack.append(node) @@ -921,8 +918,11 @@ def visit_For(self, node): return num_stages = None loop_unroll_factor = None - bind_sub_block = None - if IteratorClass in [language.range, language.parallel]: + + #flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + bind_sub_block = flagtree_backend_specialization("init_bind_sub_block") + if IteratorClass in [language.range] + ([language.parallel] if flagtree_backend_specialization("is_visit_For_support_parallel") else []): iterator = IteratorClass(*iter_args, **iter_kwargs) # visit iterator arguments # note: only `range` iterator is supported now @@ -932,8 +932,9 @@ def visit_For(self, node): step = iterator.step num_stages = iterator.num_stages loop_unroll_factor = iterator.loop_unroll_factor - if (IteratorClass is language.parallel): - bind_sub_block = iterator.bind_sub_block + + #flagtree backend specialization + bind_sub_block = flagtree_backend_specialization("set_bind_sub_block_when_parallel", IteratorClass, iterator, bind_sub_block) elif IteratorClass is range: # visit iterator arguments # note: only `range` iterator is supported now @@ -943,20 +944,10 @@ def visit_For(self, node): step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1)) else: raise RuntimeError('Only `range` and `static_range` iterators are currently supported') - - # flagtree: After normal processing, check if we need to override bind_sub_block - if hasattr(node, 'lineno') and hasattr(self, 'jit_fn'): - line_num = node.lineno - # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later - function_def = self.jit_fn.parse() - line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) - flagtree_hints = line_flagtree_hints.get(line_num) - - # Check if this is a range/for loop with bind_sub_block hint - if flagtree_hints and 'bind_sub_block' in flagtree_hints: - bind_sub_block = True - # print(f"[FLAGTREE] Found bind_sub_block hint at line {line_num}") - + + #flagtree backend specialization + bind_sub_block = flagtree_backend_specialization("check_override_bind_sub_block", self, node, bind_sub_block) + # handle negative constant step (not supported by scf.for in MLIR) negative_step = False if _is_constexpr(step) and step.value < 0: @@ -1021,7 +1012,8 @@ def visit_For(self, node): if loop_unroll_factor is not None: for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor)) if (bind_sub_block is not None) and bind_sub_block: - for_op.set_attr("bind_sub_block", self.builder.get_bool_attr(bind_sub_block)) + #flagtree backend specialization + flagtree_backend_specialization("forop_setattr_for_bind_sub_block", self, for_op, bind_sub_block) self.scf_stack.append(node) self.builder.set_insertion_point_to_start(for_op.get_body(0)) @@ -1105,7 +1097,11 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): generator.visit(fn.parse()) except Exception as e: # Wrap the error in the callee with the location of the call. 
- raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from e + + #flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + raise CompilationError(self.jit_fn.src, self.cur_node, + repr(e) if flagtree_backend_specialization('need_repr_in_CodeGenerator_CompilationError') else None) from e callee_ret_type = generator.ret_type self.function_ret_types[fn_name] = callee_ret_type @@ -1149,7 +1145,11 @@ def visit_Call(self, node): # itself). But when calling a function, we raise as `from e` to # preserve the traceback of the original error, which may e.g. # be in core.py. - raise CompilationError(self.jit_fn.src, node, repr(e)) from e + + #flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + raise CompilationError(self.jit_fn.src, node, + repr(e) if flagtree_backend_specialization('need_repr_in_CodeGenerator_CompilationError') else None) from e if fn in self.builtin_namespace.values(): args = map(_unwrap_if_constexpr, args) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 6a8359d6f..a7bc26bd2 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -4,7 +4,6 @@ from .._C.libtriton import get_cache_invalidating_env_vars, ir from ..backends import backends from ..backends.compiler import GPUTarget, AttrsDescriptor -from ..backends.ascend.compiler import AscendAttrsDescriptor from .. import __version__ from ..runtime.autotuner import OutOfResources from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager @@ -12,7 +11,6 @@ from ..tools.disasm import get_sass # TODO: this shouldn't be here from .code_generator import ast_to_ttir -from .errors import MLIRCompilationError from pathlib import Path import re import functools @@ -87,8 +85,9 @@ def __init__(self, fn, signature, constants=None, attrs=None) -> None: for k in self.constants.keys(): if not isinstance(k, str): raise TypeError("Constants keys must be string") - if self.attrs is None: - self.attrs = AscendAttrsDescriptor() + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization("ext_ASTSource_attrs", self) def hash(self): sorted_sig = [v for k, v in sorted(self.signature.items())] @@ -252,12 +251,11 @@ def compile(src, target=None, options=None): # cache hit! metadata = json.loads(Path(metadata_path).read_text()) return CompiledKernel(src, metadata_group, hash) - compile_speed_opt = os.getenv("TRITON_ASCEND_COMPILE_SPEED_OPT", 'false').lower() in ('true', '1') - if (compile_speed_opt): - ttir_path = f"{file_name}.ttir" - if (metadata_path is None) and (fn_cache_manager.has_file(ttir_path)): - # Already compile once but failed. 
So directly return
-            raise Exception("already failed once")
+
+    # flagtree backend specialization
+    from triton.runtime.driver import flagtree_backend_specialization
+    flagtree_backend_specialization("opt_ascend_compile_speed", file_name, metadata_path, fn_cache_manager)
+
     # initialize metadata
     metadata = {
         "hash": hash,
@@ -287,14 +285,10 @@
         try:
             next_module = compile_ir(module, metadata)
         except Exception as e:
-            if (ext == "ttadapter"):
-                stage_name = "ConvertTritonIRToLinalgIR"
-            elif (ext == "npubin"):
-                stage_name = "ConvertLinalgRToBinary"
-            else:
-                stage_name = "MLIRCompile"
-            error_detail = e.stderr.decode('utf-8') if hasattr(e, 'stderr') and e.stderr else str(e)
-            raise MLIRCompilationError(stage_name, error_detail)
+            # flagtree backend specialization
+            from triton.runtime.driver import flagtree_backend_specialization
+            flagtree_backend_specialization("handle_compile_error", e, ext)
+
         ir_filename = f"{file_name}.{ext}"
         if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None):
             print(f"\nOverriding kernel with file {full_name}")
@@ -406,9 +400,12 @@ def _init_handles(self):
         # TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
         self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
             self.name, self.kernel, self.metadata.shared, device)
-
-        # This mechanism introduces heavy runtime overhead.
-        # Commenting __getattribute__ requires explicitly calling _init_handles()
+    def __getattribute__(self, name):
+        # flagtree backend specialization
+        from triton.runtime.driver import flagtree_backend_specialization
+        if name == 'run' and flagtree_backend_specialization("is_CompiledKernel_getattribute_need_init_handles"):
+            self._init_handles()
+        return super().__getattribute__(name)

     def launch_metadata(self, grid, stream, *args):
         if CompiledKernel.launch_enter_hook is None:
@@ -431,8 +428,11 @@ def __getitem__(self, grid):
         self._init_handles()

         def runner(*args, stream=None):
-            if stream is None:
-                stream = self.metadata.stream
+
+            # flagtree backend specialization
+            from triton.runtime.driver import flagtree_backend_specialization
+            stream = flagtree_backend_specialization("set_CompiledKernel_metadata_stream", self, stream)
+
             if stream is None:
                 device = driver.active.get_current_device()
                 stream = driver.active.get_current_stream(device)
diff --git a/python/triton/compiler/errors.py b/python/triton/compiler/errors.py
index 5242258ad..6d1cc470a 100644
--- a/python/triton/compiler/errors.py
+++ b/python/triton/compiler/errors.py
@@ -51,18 +51,6 @@ class UnsupportedLanguageConstruct(CompilationError):
     pass


-class MLIRCompilationError(TritonError):
-    def __init__(self, stage_name: Optional[str], message: Optional[str] = None):
-        self.stage_name = stage_name
-        self.message = f"\n" \
-            f"{self.format_line_delim('[ERROR][Triton][BEG]')}" \
-            f"[{self.stage_name}] encounters error:\n" \
-            f"{self.filter_message(message)}" \
-            f"{self.format_line_delim('[ERROR][Triton][END]')}"
-    def __str__(self):
-        return self.message
-    def filter_message(self, message):
-        # Content starting from "Stack dump without symbol names" means nothing to the users
-        return message.split("Stack dump without symbol names")[0]
-    def format_line_delim(self, keyword):
-        return f"///------------------{keyword}------------------\n"
\ No newline at end of file
+# flagtree backend specialization
+from triton.runtime.driver import flagtree_backend_class_specialization
+MLIRCompilationError = 
flagtree_backend_class_specialization("MLIRCompilationError") diff --git a/python/triton/language/_utils.py b/python/triton/language/_utils.py index d0ca8c734..df9eaeaaf 100644 --- a/python/triton/language/_utils.py +++ b/python/triton/language/_utils.py @@ -1,19 +1,28 @@ from __future__ import annotations -from typing import List, TYPE_CHECKING, Any, Union, Dict +from typing import List +# flagtree backend specialization +from triton.runtime.driver import flagtree_backend_specialization +from typing import TYPE_CHECKING if TYPE_CHECKING: - from .language import core - IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type] - ObjPath = tuple[int, ...] + IterableType, ObjPath = flagtree_backend_specialization('get_language_utils_IterableType_ObjPath') + + +TRITON_MAX_TENSOR_NUMEL = flagtree_backend_specialization('get_triton_max_tensor_numel') + + +def is_power_of_two(x): + return (x & (x - 1)) == 0 -TRITON_MAX_TENSOR_NUMEL = 1048576 def validate_block_shape(shape: List[int]): numel = 1 for i, d in enumerate(shape): if not isinstance(d, int): raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]") + if flagtree_backend_specialization('is_block_shape_check_power_of_two') and not is_power_of_two(d): + raise ValueError(f"Shape element {i} must be a power of 2") numel *= d if numel > TRITON_MAX_TENSOR_NUMEL: @@ -21,19 +30,11 @@ def validate_block_shape(shape: List[int]): return numel -BITWIDTH_DICT: Dict[str, int] = { - **{f"u{n}": n - for n in (1, 8, 16, 32, 64)}, - **{f"i{n}": n - for n in (1, 8, 16, 32, 64)}, - **{f"fp{n}": n - for n in (16, 32, 64)}, - **{f"fp8{suffix}": 8 - for suffix in ("e4nv", "e4b15", "e4b8", "e5", "e5b16")}, - "bf16": 16, - "void": 0, -} +# flagtree backend specialization +from triton.runtime.driver import flagtree_backend_specialization +BITWIDTH_DICT = flagtree_backend_specialization('get_language_utils_BITWIDTH_DICT') -def get_primitive_bitwidth(dtype: str) -> int: - return BITWIDTH_DICT[dtype] +# flagtree backend specialization +from triton.runtime.driver import flagtree_backend_func_specialization +get_primitive_bitwidth = flagtree_backend_func_specialization("get_primitive_bitwidth") diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 67753e129..95061c610 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -4,7 +4,7 @@ import os import time import inspect -from typing import Dict, List +from typing import Dict from .jit import KernelInterface from .errors import OutOfResources @@ -28,6 +28,7 @@ def __init__( rep=None, use_cuda_graph=False, do_bench=None, + # flagtree backend specialization auto_profile_dir=None, ): """ @@ -36,9 +37,15 @@ def __init__( 'top_k': number of configs to bench 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. 
""" + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if not configs: self.configs = [ - Config({}) + flagtree_backend_specialization('get_spec_default_Autotuner_configs') + if flagtree_backend_specialization('has_spec_default_Autotuner_configs') + else Config({}, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, + reg_dec_producer=0, reg_inc_consumer=0) ] else: self.configs = configs @@ -100,7 +107,9 @@ def _post_hook(kwargs, exception): self.num_warmups = warmup self.num_reps = rep self.use_cuda_graph = use_cuda_graph - self.auto_profile_dir = auto_profile_dir + + # flagtree backend specialization + flagtree_backend_specialization('set_Autotuner_auto_profile_dir', self, auto_profile_dir) # If we got explicitly called via the old interface, raise a warning # and proceed with the old behavior. @@ -133,7 +142,7 @@ def _post_hook(kwargs, exception): self.do_bench = do_bench def _bench(self, *args, config, **meta): - from ..compiler.errors import CompileTimeAssertionFailure, MLIRCompilationError + from ..compiler.errors import CompileTimeAssertionFailure # check for conflicts, i.e. meta-parameters both provided # as kwargs and by the autotuner @@ -163,46 +172,15 @@ def kernel_call(): self.post_hook(full_nargs, exception=None) + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + try: return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) - except (OutOfResources, CompileTimeAssertionFailure, MLIRCompilationError) as e: + except (OutOfResources, CompileTimeAssertionFailure) + \ + flagtree_backend_specialization("ext_Autotuner_do_bench_MLIRCompilationError") as e: return [float("inf"), float("inf"), float("inf")] - def _profile(self, *args, config, **meta): - from triton.testing import do_bench_npu - - # check for conflicts, i.e. meta-parameters both provided - # as kwargs and by the autotuner - conflicts = meta.keys() & config.kwargs.keys() - if conflicts: - raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." 
- " Make sure that you don't re-define auto-tuned symbols.") - # augment meta-parameters with tunable ones - current = dict(meta, **config.all_kwargs()) - full_nargs = {**self.nargs, **current} - - def kernel_call(): - if config.pre_hook: - config.pre_hook(full_nargs) - self.pre_hook(full_nargs) - try: - self.fn.run( - *args, - **current, - ) - except Exception as e: - try: - self.post_hook(full_nargs, exception=e) - finally: - # Throw exception raised by `self.fn.run` - raise - - self.post_hook(full_nargs, exception=None) - - do_bench_npu( - kernel_call, prof_dir=self.auto_profile_dir, keep_res=True - ) - def run(self, *args, **kwargs): self.nargs = dict(zip(self.arg_names, args)) used_cached_result = True @@ -234,8 +212,10 @@ def run(self, *args, **kwargs): print(f"Triton autotuning for function {self.base_fn.__name__} finished after " f"{self.bench_time:.2f}s; best config selected: {self.best_config};") - if not used_cached_result and self.auto_profile_dir is not None: - self._profile(*args, config=self.best_config, **kwargs) + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('ext_Autotuner_profile', self, used_cached_result, args, kwargs) + if config.pre_hook is not None: full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()} config.pre_hook(full_nargs) @@ -315,18 +295,14 @@ def __init__(self, kwargs, num_warps=None, num_stages=None, num_ctas=None, num_b self.maxnreg = maxnreg self.pre_hook = pre_hook - - # BiShengIR Options allowed for autotune - self.multibuffer = bishengir_options.get("multibuffer", None) # Compiler Default True - self.unit_flag = bishengir_options.get("unit_flag", None) # Compiler Default False - self.limit_auto_multi_buffer_only_for_local_buffer = bishengir_options.get("limit_auto_multi_buffer_only_for_local_buffer", None) # Compiler Default False - self.limit_auto_multi_buffer_of_local_buffer = bishengir_options.get("limit_auto_multi_buffer_of_local_buffer", None) # Compiler Default no-limit - self.set_workspace_multibuffer = bishengir_options.get("set_workspace_multibuffer", None) # Compiler Default 1 - self.enable_hivm_auto_cv_balance = bishengir_options.get("enable_hivm_auto_cv_balance", None) # Compiler Default True - self.tile_mix_vector_loop = bishengir_options.get("tile_mix_vector_loop", None) # Compiler Default 1 - self.tile_mix_cube_loop = bishengir_options.get("tile_mix_cube_loop", None) # Compiler Default 1 + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('set_Config_BiShengIR_options', self, bishengir_options) def all_kwargs(self): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + return { **self.kwargs, **{ k: v @@ -339,17 +315,7 @@ def all_kwargs(self): ("reg_dec_producer", self.reg_dec_producer), ("reg_inc_consumer", self.reg_inc_consumer), ("maxnreg", self.maxnreg), - - ("multibuffer", self.multibuffer), - ("enable_hivm_auto_cv_balance", self.enable_hivm_auto_cv_balance), - ("unit_flag", self.unit_flag), - ("limit_auto_multi_buffer_only_for_local_buffer", \ - self.limit_auto_multi_buffer_only_for_local_buffer), - ("limit_auto_multi_buffer_of_local_buffer", self.limit_auto_multi_buffer_of_local_buffer), - ("set_workspace_multibuffer", self.set_workspace_multibuffer), - ("tile_mix_vector_loop", self.tile_mix_vector_loop), - ("tile_mix_cube_loop", self.tile_mix_cube_loop), - ) if v is not None + ) + 
flagtree_backend_specialization('ext_Config_all_kwargs', self) if v is not None } } @@ -366,15 +332,10 @@ def __str__(self): res.append(f"reg_inc_consumer: {self.reg_inc_consumer}") res.append(f"maxnreg: {self.maxnreg}") - res.append(f"multibuffer: {self.multibuffer}") - res.append(f"enable_hivm_auto_cv_balance: {self.enable_hivm_auto_cv_balance}") - res.append(f"unit_flag: {self.unit_flag}") - res.append(f"limit_auto_multi_buffer_only_for_local_buffer: \ - {self.limit_auto_multi_buffer_only_for_local_buffer}") - res.append(f"limit_auto_multi_buffer_of_local_buffer: {self.limit_auto_multi_buffer_of_local_buffer}") - res.append(f"set_workspace_multibuffer: {self.set_workspace_multibuffer}") - res.append(f"tile_mix_vector_loop: {self.tile_mix_vector_loop}") - res.append(f"tile_mix_cube_loop: {self.tile_mix_cube_loop}") + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization('ext_Config_to_str', res, self) + return ", ".join(res) @@ -440,15 +401,16 @@ def kernel(x_ptr, x_size, **META): """ def decorator(fn): + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization if split_params or tiling_params: - from .autotiling_tuner import AutoTilingTuner - return AutoTilingTuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, - post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, - use_cuda_graph=use_cuda_graph, do_bench=do_bench, auto_profile_dir=auto_profile_dir, - split_params=split_params, tiling_params=tiling_params, low_dims=low_dims, - dual_reduction=dual_reduction, persistent_reduction=persistent_reduction) - else: - return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, + return flagtree_backend_specialization('new_AutoTilingTuner', fn, configs, key, reset_to_zero, restore_value, pre_hook, + post_hook, prune_configs_by, warmup, rep, + use_cuda_graph, do_bench, auto_profile_dir, + split_params, tiling_params, low_dims, + dual_reduction, persistent_reduction) + + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, use_cuda_graph=use_cuda_graph, do_bench=do_bench, auto_profile_dir=auto_profile_dir) diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index c3b97a764..513c585bb 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -58,3 +58,35 @@ def reset_active(self): driver = DriverConfig() + + +# flagtree backend specialization +def flagtree_backend_specialization(function_name: str, *args, **kwargs): + if hasattr(driver.active, "flagtree_backend_specialization"): + flagtree_backend_specialization = driver.active.flagtree_backend_specialization + if hasattr(flagtree_backend_specialization, function_name): + func = getattr(flagtree_backend_specialization, function_name) + return func(*args, **kwargs) + raise RuntimeError(f"{function_name} not found in flagtree_backend_specialization") + raise RuntimeError(f"flagtree_backend_specialization not found in {driver.active}") + + +# flagtree backend func specialization +def flagtree_backend_func_specialization(function_name: str): + if hasattr(driver.active, "flagtree_backend_specialization"): + flagtree_backend_specialization = driver.active.flagtree_backend_specialization + if hasattr(flagtree_backend_specialization, function_name): + func = 
getattr(flagtree_backend_specialization, function_name) + return func + raise RuntimeError(f"{function_name} not found in flagtree_backend_specialization") + raise RuntimeError(f"flagtree_backend_specialization not found in {driver.active}") + +# flagtree backend class specialization +def flagtree_backend_class_specialization(class_name: str): + if hasattr(driver.active, "flagtree_backend_specialization"): + flagtree_backend_specialization = driver.active.flagtree_backend_specialization + if hasattr(flagtree_backend_specialization, class_name): + cls = getattr(flagtree_backend_specialization, class_name) + return cls + raise RuntimeError(f"{class_name} not found in flagtree_backend_specialization") + raise RuntimeError(f"flagtree_backend_specialization not found in {driver.active}") diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 08422611f..aa7d8b800 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -6,14 +6,11 @@ import os import re import textwrap -import tokenize from collections import defaultdict from functools import cached_property from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple from ..runtime.driver import driver -from ..backends.ascend.compiler import AscendAttrsDescriptor from types import ModuleType -from io import StringIO TRITON_MODULE = __name__[:-len(".runtime.jit")] @@ -331,6 +328,7 @@ def __getitem__(self, grid) -> T: memorizes the grid. """ return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) + # return cast(T, functools.partial(cast(Callable, self.run), grid=grid)) def serialize_specialization_data(name, signature, constants, attrs, options, key): @@ -568,7 +566,10 @@ def run(self, *args, grid, warmup, **kwargs): # parse options from ..compiler import make_backend device = driver.active.get_current_device() - if ('stream' not in kwargs.keys()): + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if flagtree_backend_specialization("is_set_stream_in_kwargs", kwargs): stream = driver.active.get_current_stream(device) target = driver.active.get_current_target() backend = make_backend(target) @@ -593,20 +594,17 @@ def run(self, *args, grid, warmup, **kwargs): # deprecated arguments assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used" assert "device" not in kwargs, "device option is deprecated; current device will be used" - # assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if flagtree_backend_specialization("is_stream_option_deprecated"): + assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" + for k in excess_kwargs: if k not in options.__dict__: raise KeyError("Keyword argument %s was specified but unrecognised" % k) - ignor_params = ["debug", "sanitize_overflow", "llvm_version", "kernel_name", \ - "allowed_dot_input_precisions", "multibuffer", "stream"] - not_work_params = [] - for k in kwargs: - if k in ignor_params: - continue - elif k in excess_kwargs: - not_work_params.append(k) - if len(not_work_params) != 0: - print("[WARNING] Please DO NOT tune args {}!".format(not_work_params)) + + flagtree_backend_specialization("ignore_params_in_JITFunction_run", kwargs, excess_kwargs) bound_vals = tuple(bound_args.values()) @@ -660,16 +658,17 @@ def 
run(self, *args, grid, warmup, **kwargs): grid_0 = grid[0] grid_1 = grid[1] if grid_size > 1 else 1 grid_2 = grid[2] if grid_size > 2 else 1 - grid_all_size = grid_0 * grid_1 * grid_2 - if os.getenv("TRITON_ALL_BLOCKS_PARALLEL", "0") == "0": - if grid_all_size > 65535: - raise RuntimeError("grid should be less than 65536! You can try \"export TRITON_ALL_BLOCKS_PARALLEL=1\" to avoid this problem.") - if ('stream' in kwargs.keys()): - stream = kwargs["stream"] + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + flagtree_backend_specialization("check_grid_size", grid_0, grid_1, grid_2) + stream = flagtree_backend_specialization("set_stream_from_kwargs", kwargs, stream) + # launch kernel launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals) - # explicitly define run method and load kernel binary - kernel._init_handles() + + flagtree_backend_specialization("explicit_load_kernel_library", kernel) + kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals) return kernel @@ -749,6 +748,7 @@ def warmup(self, *args, grid, **kwargs): def preload(self, specialization_data): from ..compiler import compile, ASTSource + from triton.backends.compiler import AttrsDescriptor import json import triton.language as tl device = driver.active.get_current_device() @@ -761,7 +761,13 @@ def preload(self, specialization_data): for key, value in deserialized_obj['constants'].items() } signature = dict(deserialized_obj['signature'].items()) - src = ASTSource(self, signature, constants, AscendAttrsDescriptor.from_dict(deserialized_obj['attrs'])) + + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + src = ASTSource(self, signature, constants, + flagtree_backend_specialization('get_JITFunction_spec_attr', deserialized_obj) + if flagtree_backend_specialization('is_JITFunction_spec_attr') + else AttrsDescriptor.from_dict(deserialized_obj['attrs'])) options = { key: tuple(value) if isinstance(value, list) else value for key, value in deserialized_obj['options'].items() @@ -775,29 +781,18 @@ def preload(self, specialization_data): # the user might want to monkey-patch self.src dynamically. # Our unit tests do this, for example. 
def parse(self): - # Maps line numbers to comment hints - line_flagtree_hints = {} - code_str = self.src - g = tokenize.generate_tokens(StringIO(code_str).readline) - for tok_type, tok_text, start, end, _ in g: - if tok_type == tokenize.COMMENT: - comment = tok_text.replace(" ", "").strip() - if comment.startswith('#@hint:'): - flagtree_hints = comment[len('#@hint:'):].strip() - # Record the line number of the comment - line_num = start[0] - line_flagtree_hints[line_num] = flagtree_hints - - # print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}") + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + line_flagtree_hints = flagtree_backend_specialization('maps_line_numbers_to_comment_hints', self) tree = ast.parse(self.src) assert isinstance(tree, ast.Module) assert len(tree.body) == 1 assert isinstance(tree.body[0], ast.FunctionDef) - # Attach the line number to comment mapping to the function definition node - tree.body[0].line_flagtree_hints = line_flagtree_hints - + # flagtree backend specialization + flagtree_backend_specialization('attach_line_number_to_comment_mapping', tree, line_flagtree_hints) + return tree def __call__(self, *args, **kwargs): diff --git a/python/triton/testing.py b/python/triton/testing.py index b929ef22c..92d88baee 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -1,8 +1,6 @@ import functools import os import subprocess -import multiprocessing -import os import sys from contextlib import contextmanager from typing import Any, Dict, List @@ -114,10 +112,10 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m assert return_mode in ["min", "max", "mean", "median", "all"] import torch - enable_bench_npu = os.getenv("TRITON_BENCH_METHOD", 'default').lower() == 'npu' - if torch.npu.is_available() and enable_bench_npu: - avg_time = do_bench_npu(fn, warmup=max(5, warmup), active=max(30, rep)) - return _summarize_statistics(torch.tensor([avg_time], dtype=torch.float), quantiles, return_mode) + # flagtree backend specialization + from triton.runtime.driver import flagtree_backend_specialization + if flagtree_backend_specialization('is_do_bench_npu'): + return flagtree_backend_specialization('ext_do_bench_npu') di = runtime.driver.active.get_device_interface() @@ -164,99 +162,6 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) return _summarize_statistics(times, quantiles, return_mode) -def collect_files(base_dir): - import pandas as pd - for root, dirs, files in os.walk(base_dir): - for file in files: - if file != 'op_statistic.csv': - continue - target_file = os.path.join(root, file) - df = pd.read_csv(target_file) - triton_rows = df[df['OP Type'].str.startswith('triton', na=False)] - if not triton_rows.empty: - return triton_rows['Avg Time(us)'].values[0] - return float('inf') - return float('inf') - - -def collect_single(base_dir: str, key: str = None) -> float: - if not os.path.exists(base_dir): - return float('inf') - - import pandas as pd - for root, _, files in os.walk(base_dir): - for file in files: - if file != 'op_statistic.csv': - continue - target_file = os.path.join(root, file) - df = pd.read_csv(target_file) - if key is not None: - key_rows = df[df['OP Type'].str.startswith(key, na=False)] - if not key_rows.empty: - return key_rows['Avg Time(us)'].values[0] - return float('inf') - else: - # default: read 
the first row except header - return df.loc[0, 'Avg Time(us)'] - - return float('inf') - - -def do_bench_npu(fn, warmup=5, active=30, prof_dir=None, keep_res=False): - import torch - import torch_npu - from datetime import datetime, timezone - - # warmup kernel - fn() - torch.npu.synchronize() - - experimental_config = torch_npu.profiler._ExperimentalConfig( - aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, - profiler_level=torch_npu.profiler.ProfilerLevel.Level1, - l2_cache=False, - data_simplification=False - ) - skip_first = 1 - wait = 0 - repeat = 1 - total = skip_first + (wait + warmup + active) * repeat - - if prof_dir is not None: - torch_path = prof_dir - else: - process = multiprocessing.current_process() - pid = process.pid - process_name = process.name - timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S") - base_path = os.path.join(runtime.cache.get_home_dir(), ".triton", "profile_results") - torch_path = os.path.join(base_path, f"prof_{timestamp}_{process_name}-{pid}") - with torch_npu.profiler.profile( - activities=[ - torch_npu.profiler.ProfilerActivity.NPU - ], - schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, skip_first=skip_first), - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_path), - record_shapes=False, - profile_memory=False, - with_stack=False, - with_flops=False, - with_modules=False, - experimental_config=experimental_config, - ) as prof: - for _ in range(total): - fn() - prof.step() - torch.npu.synchronize() - - time = collect_single(torch_path) - - if not keep_res: - import shutil - if os.path.exists(torch_path): - shutil.rmtree(torch_path) - - return time def assert_close(x, y, atol=None, rtol=None, err_msg=''): """ @@ -431,6 +336,7 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b ax.legend() ax.set_xlabel(bench.xlabel or first_x) ax.set_ylabel(bench.ylabel) + # ax.set_title(bench.plot_name) ax.set_xscale("log" if bench.x_log else "linear") ax.set_yscale("log" if bench.y_log else "linear") if show_plots: @@ -609,205 +515,7 @@ def get_max_simd_tflops(dtype, clock_rate, device=None): tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 return tflops -# Patch the triton language API here because triton's __init__.py -# import testing in the last stages. -from .language.tensor_descriptor import ( - tensor_descriptor, - tensor_descriptor_type, -) - -from .language.core_ext import ( - dot, - cast, - gather, - get_element, - insert_slice, - extract_slice, - trans, - __add__, - __radd__, - __sub__, - __rsub__, - __mul__, - __rmul__, - __lshift__, - __rshift__, - parallel, - compile_hint, - make_tensor_descriptor, - load_tensor_descriptor, - store_tensor_descriptor, - multibuffer, - sync_block_all, - sync_block_set, - sync_block_wait, - dtype_to_ir, - sort -) -from .language.standard_ext import flip, sigmoid, softmax, isfinited, finitef, rint, atan2 -from .language.math_ext import ( - umulhi, - exp, - exp2, - log, - log2, - cos, - sin, - sqrt, - sqrt_rn, - rsqrt, - div_rn, - erf, - tanh, - floor, - ceil, - _check_dtype, - fma, -) -from .language.semantic_ext import ( - arange, - floordiv, - atom_red_typechecking_impl, - atomic_cas, - atomic_max, - atomic_min, - _load_legacy, - maximum, - minimum, - mod, - invert, - logical_and, - logical_or, - not_, - and_, - or_, - xor_, - minus, - dot_scaled, -) -from . 
import language - -language.cast = cast -language.dot = dot -language.flip = flip -language.sigmoid = sigmoid -language.softmax = softmax -language.gather = gather -language.insert_slice = insert_slice -language.extract_slice = extract_slice -language.get_element = get_element -language.tensor.__add__ = __add__ -language.tensor.__radd__ = __radd__ -language.tensor.__sub__ = __sub__ -language.tensor.__rsub__ = __rsub__ -language.tensor.__mul__ = __mul__ -language.tensor.__rmul__ = __rmul__ -language.tensor.__lshift__ = __lshift__ -language.tensor.__rshift__ = __rshift__ -language.trans = trans -language.parallel = parallel -language.compile_hint = compile_hint -language.sort = sort -language.multibuffer = multibuffer -language.sync_block_all = sync_block_all -language.sync_block_set = sync_block_set -language.sync_block_wait = sync_block_wait -language.make_tensor_descriptor = make_tensor_descriptor -language.tensor_descriptor = tensor_descriptor -language.tensor_descriptor_type = tensor_descriptor_type -language.load_tensor_descriptor = load_tensor_descriptor -language.store_tensor_descriptor = store_tensor_descriptor - -language.semantic.arange = arange -language.semantic.floordiv = floordiv -language.semantic.atom_red_typechecking_impl = atom_red_typechecking_impl -language.semantic.atomic_cas = atomic_cas -language.semantic.atomic_max = atomic_max -language.semantic.atomic_min = atomic_min -language.semantic._load_legacy = _load_legacy -language.semantic.maximum = maximum -language.semantic.minimum = minimum -language.semantic.invert = invert -language.semantic.logical_and = logical_and -language.semantic.logical_or = logical_or -language.semantic.mod = mod -language.semantic.not_ = not_ -language.semantic.and_ = and_ -language.semantic.or_ = or_ -language.semantic.xor_ = xor_ -language.semantic.minus = minus -language.semantic.dot_scaled = dot_scaled - -language.umulhi = umulhi -language.exp = exp -language.exp2 = exp2 -language.log = log -language.log2 = log2 -language.cos = cos -language.sin = sin -language.sqrt = sqrt -language.sqrt_rn = sqrt_rn -language.rsqrt = rsqrt -language.div_rn = div_rn -language.erf = erf -language.tanh = tanh -language.floor = floor -language.ceil = ceil -language.core.dtype.to_ir = dtype_to_ir -language.fma = fma -language.math.umulhi = umulhi -language.math.exp = exp -language.math.exp2 = exp2 -language.math.log = log -language.math.log2 = log2 -language.math.cos = cos -language.math.sin = sin -language.math.sqrt = sqrt -language.math.sqrt_rn = sqrt_rn -language.math.rsqrt = rsqrt -language.math.div_rn = div_rn -language.math.erf = erf -language.math.tanh = tanh -language.math.floor = floor -language.math.ceil = ceil -language.math._check_dtype = _check_dtype -language.math.fma = fma -language.math.isnan = language.extra.ascend.libdevice.isnan -language.math.isinf = language.extra.ascend.libdevice.isinf -language.math.reciprocal = language.extra.ascend.libdevice.reciprocal -language.math.log1p = language.extra.ascend.libdevice.log1p -language.math.relu = language.extra.ascend.libdevice.relu -language.math.tan = language.extra.ascend.libdevice.tan -language.math.atan = language.extra.ascend.libdevice.atan -language.math.tanh = language.extra.ascend.libdevice.tanh -language.math.ilogb = language.extra.ascend.libdevice.ilogb -language.math.ldexp = language.extra.ascend.libdevice.ldexp -language.math.pow = language.extra.ascend.libdevice.pow -language.math.flip = language.extra.ascend.libdevice.flip -language.math.atan2 = 
language.extra.ascend.libdevice.atan2 -language.math.div_rz = language.extra.ascend.libdevice.div_rz -language.math.fmod = language.extra.ascend.libdevice.fmod -language.math.trunc = language.extra.ascend.libdevice.trunc -language.math.round = language.extra.ascend.libdevice.round -language.math.finitef = finitef -language.math.isfinited = isfinited -language.math.rint = rint -language.math.atan2 = atan2 -language.extra.ascend.libdevice.umulhi = language.math.umulhi -language.extra.ascend.libdevice.exp = language.math.exp -language.extra.ascend.libdevice.exp2 = language.math.exp2 -language.extra.ascend.libdevice.log = language.math.log -language.extra.ascend.libdevice.log2 = language.math.log2 -language.extra.ascend.libdevice.cos = language.math.cos -language.extra.ascend.libdevice.sin = language.math.sin -language.extra.ascend.libdevice.sqrt = language.math.sqrt -language.extra.ascend.libdevice.sqrt_rn = language.math.sqrt_rn -language.extra.ascend.libdevice.rsqrt = language.math.rsqrt -language.extra.ascend.libdevice.div_rn = language.math.div_rn -language.extra.ascend.libdevice.erf = language.math.erf -language.extra.ascend.libdevice.tanh = language.math.tanh -language.extra.ascend.libdevice.floor = language.math.floor -language.extra.ascend.libdevice.ceil = language.math.ceil -language.extra.ascend.libdevice.fdiv = language.math.fdiv -language.extra.ascend.libdevice.fma = language.math.fma -language.extra.ascend.libdevice.abs = language.math.abs + +# flagtree backend specialization +from triton.runtime.driver import flagtree_backend_specialization +flagtree_backend_specialization('patch_triton_language') diff --git a/third_party/ascend/backend/driver.py b/third_party/ascend/backend/driver.py index 8ea0873a8..603bb84cc 100644 --- a/third_party/ascend/backend/driver.py +++ b/third_party/ascend/backend/driver.py @@ -105,6 +105,9 @@ class NPUDriver(DriverBase): def __init__(self): self.utils = NPUUtils() self.launcher_cls = NPULauncher + # flagtree backend specialization + from triton.backends.ascend import flagtree_backend_specialization + self.flagtree_backend_specialization = flagtree_backend_specialization super().__init__() @classmethod diff --git a/third_party/ascend/backend/flagtree_backend_specialization/__init__.py b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py new file mode 100644 index 000000000..9cce49ef4 --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/__init__.py @@ -0,0 +1,55 @@ +from .triton.compiler.compiler import * +from .triton.compiler.errors import * +from .triton.compiler.code_generator import * +from .triton.runtime.jit import * +from .triton.runtime.autotuner import * +from .triton.language._utils import * +from .triton.testing import * + +__all__ = [ + # compiler.compiler + 'ext_ASTSource_attrs', + 'opt_ascend_compile_speed', + 'set_CompiledKernel_metadata_stream', + 'handle_compile_error', + 'is_CompiledKernel_getattribute_need_init_handles', + # compiler.code_generator + 'anno_CodeGenerator_visit_Assign', + 'ext_CodeGenerator_visit_Assign_hint_anno', + 'init_bind_sub_block', + 'is_visit_For_support_parallel', + 'set_bind_sub_block_when_parallel', + 'check_override_bind_sub_block', + 'forop_setattr_for_bind_sub_block', + 'need_repr_in_CodeGenerator_CompilationError', + # runtime.jit + 'is_set_stream_in_kwargs', + 'is_stream_option_deprecated', + 'ignore_params_in_JITFunction_run', + 'set_stream_from_kwargs', + 'check_grid_size', + 'explicit_load_kernel_library', + 'is_JITFunction_spec_attr', + 
'get_JITFunction_spec_attr', + 'maps_line_numbers_to_comment_hints', + 'attach_line_number_to_comment_mapping', + # runtime.autotuner + 'set_Autotuner_auto_profile_dir', + 'has_spec_default_Autotuner_configs', + 'get_spec_default_Autotuner_configs', + 'ext_Autotuner_do_bench_MLIRCompilationError', + 'ext_Autotuner_profile', + 'set_Config_BiShengIR_options', + 'ext_Config_all_kwargs', + 'ext_Config_to_str', + 'new_AutoTilingTuner', + # language._utils + 'get_language_utils_IterableType_ObjPath', + 'get_triton_max_tensor_numel', + 'is_block_shape_check_power_of_two', + 'get_language_utils_BITWIDTH_DICT', + # testing + 'is_do_bench_npu', + 'ext_do_bench_npu', + 'patch_triton_language' +] diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/code_generator.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/code_generator.py new file mode 100644 index 000000000..b90a9f99e --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/code_generator.py @@ -0,0 +1,63 @@ +def anno_CodeGenerator_visit_Assign(): + # flagtree: First, do normal assignment processing + return + +def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values): + import ast + from triton.compiler.code_generator import _is_triton_value + # flagtree: After normal processing, check if we need to add hint annotation + if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'): + line_num = node.lineno + # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later + function_def = code_generator.jit_fn.parse() + line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) + flagtree_hints = line_flagtree_hints.get(line_num) + + # Check if this is a tl.load call with dot_pad_only_k hint + if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and + isinstance(node.value, ast.Call) and + isinstance(node.value.func, ast.Attribute) and + isinstance(node.value.func.value, ast.Name) and + node.value.func.value.id == 'tl' and + node.value.func.attr == 'load'): + + # Add hint annotation to the loaded tensor(s) + for name, value in zip(names, values): + if _is_triton_value(value): + # print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}") + # Create hint annotation + hint_val = code_generator.builder.get_unit_attr() + code_generator.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val) + +def init_bind_sub_block(): + return None + +def is_visit_For_support_parallel(): + return True + +def set_bind_sub_block_when_parallel(IteratorClass, iterator, bind_sub_block): + import triton.language as language + if (IteratorClass is language.parallel): + return iterator.bind_sub_block + return bind_sub_block + +def check_override_bind_sub_block(code_generator, node, bind_sub_block): + # flagtree: After normal processing, check if we need to override bind_sub_block + if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'): + line_num = node.lineno + # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later + function_def = code_generator.jit_fn.parse() + line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) + flagtree_hints = line_flagtree_hints.get(line_num) + + # Check if this is a range/for loop with bind_sub_block hint + if flagtree_hints and 'bind_sub_block' in flagtree_hints: + return True + # print(f"[FLAGTREE] Found bind_sub_block hint at line {line_num}") + return 
bind_sub_block
+
+def forop_setattr_for_bind_sub_block(code_generator, for_op, bind_sub_block):
+    for_op.set_attr("bind_sub_block", code_generator.builder.get_bool_attr(bind_sub_block))
+
+def need_repr_in_CodeGenerator_CompilationError():
+    return True
diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/compiler.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/compiler.py
new file mode 100644
index 000000000..9bf881463
--- /dev/null
+++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/compiler.py
@@ -0,0 +1,32 @@
+def ext_ASTSource_attrs(ast_source):
+    from triton.backends.ascend.compiler import AscendAttrsDescriptor
+    if ast_source.attrs is None:
+        ast_source.attrs = AscendAttrsDescriptor()
+
+def opt_ascend_compile_speed(file_name, metadata_path, fn_cache_manager):
+    import os
+    compile_speed_opt = os.getenv("TRITON_ASCEND_COMPILE_SPEED_OPT", 'false').lower() in ('true', '1')
+    if (compile_speed_opt):
+        ttir_path = f"{file_name}.ttir"
+        if (metadata_path is None) and (fn_cache_manager.has_file(ttir_path)):
+            # Already compile once but failed. So directly return
+            raise Exception("already failed once")
+
+def set_CompiledKernel_metadata_stream(compiled_kernel, stream):
+    if stream is None:
+        return compiled_kernel.metadata.stream
+    return stream
+
+def handle_compile_error(e, ext):
+    from triton.compiler.errors import MLIRCompilationError
+    if (ext == "ttadapter"):
+        stage_name = "ConvertTritonIRToLinalgIR"
+    elif (ext == "npubin"):
+        stage_name = "ConvertLinalgRToBinary"
+    else:
+        stage_name = "MLIRCompile"
+    error_detail = e.stderr.decode('utf-8') if hasattr(e, 'stderr') and e.stderr else str(e)
+    raise MLIRCompilationError(stage_name, error_detail)
+
+def is_CompiledKernel_getattribute_need_init_handles():
+    return False
diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/errors.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/errors.py
new file mode 100644
index 000000000..b1ef43a3b
--- /dev/null
+++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/compiler/errors.py
@@ -0,0 +1,20 @@
+import importlib.util
+import sys
+from typing import Optional
+from triton.compiler.errors import TritonError
+
+class MLIRCompilationError(TritonError):
+    def __init__(self, stage_name: Optional[str], message: Optional[str] = None):
+        self.stage_name = stage_name
+        self.message = f"\n" \
+            f"{self.format_line_delim('[ERROR][Triton][BEG]')}" \
+            f"[{self.stage_name}] encounters error:\n" \
+            f"{self.filter_message(message)}" \
+            f"{self.format_line_delim('[ERROR][Triton][END]')}"
+    def __str__(self):
+        return self.message
+    def filter_message(self, message):
+        # Content starting from "Stack dump without symbol names" means nothing to the users
+        return message.split("Stack dump without symbol names")[0]
+    def format_line_delim(self, keyword):
+        return f"///------------------{keyword}------------------\n"
diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/language/_utils.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/_utils.py
new file mode 100644
index 000000000..cfbe4681a
--- /dev/null
+++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/language/_utils.py
@@ -0,0 +1,41 @@
+from typing import TYPE_CHECKING, Any, Union, Dict
+
+def get_language_utils_IterableType_ObjPath():
+    from triton.language import core
+    IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type]
+    ObjPath = tuple[int, ...]
+    return IterableType, ObjPath
+
+
+TRITON_MAX_TENSOR_NUMEL = 1048576
+
+def get_triton_max_tensor_numel():
+    return TRITON_MAX_TENSOR_NUMEL
+
+
+def is_block_shape_check_power_of_two():
+    return False
+
+
+BITWIDTH_DICT: Dict[str, int] = {
+    **{f"u{n}": n
+       for n in (1, 8, 16, 32, 64)},
+    **{f"i{n}": n
+       for n in (1, 8, 16, 32, 64)},
+    **{f"fp{n}": n
+       for n in (16, 32, 64)},
+    **{f"fp8{suffix}": 8
+       for suffix in ("e4nv", "e4b15", "e4b8", "e5", "e5b16")},
+    "bf16": 16,
+    "void": 0,
+}
+
+
+def get_language_utils_BITWIDTH_DICT():
+    return BITWIDTH_DICT
+
+
+def get_primitive_bitwidth(dtype: str) -> int:
+    return BITWIDTH_DICT[dtype]
+
+
diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/autotuner.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/autotuner.py
new file mode 100644
index 000000000..545a9e879
--- /dev/null
+++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/autotuner.py
@@ -0,0 +1,99 @@
+def set_Autotuner_auto_profile_dir(autotuner, auto_profile_dir):
+    autotuner.auto_profile_dir = auto_profile_dir
+
+def has_spec_default_Autotuner_configs():
+    return True
+
+def get_spec_default_Autotuner_configs():
+    from triton.runtime.autotuner import Config
+    return Config({})
+
+def ext_Autotuner_do_bench_MLIRCompilationError():
+    from triton.compiler.errors import MLIRCompilationError
+    return (MLIRCompilationError,)
+
+def _profile(autotuner, *args, config, **meta):
+    from triton.testing import do_bench_npu
+
+    # check for conflicts, i.e. meta-parameters both provided
+    # as kwargs and by the autotuner
+    conflicts = meta.keys() & config.kwargs.keys()
+    if conflicts:
+        raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
+ " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.all_kwargs()) + full_nargs = {**autotuner.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + autotuner.pre_hook(full_nargs) + try: + autotuner.fn.run( + *args, + **current, + ) + except Exception as e: + try: + autotuner.post_hook(full_nargs, exception=e) + finally: + # Throw exception raised by `autotuner.fn.run` + raise + + autotuner.post_hook(full_nargs, exception=None) + + do_bench_npu( + kernel_call, prof_dir=autotuner.auto_profile_dir, keep_res=True + ) + +def ext_Autotuner_profile(autotuner, used_cached_result, args, kwargs): + if not used_cached_result and autotuner.auto_profile_dir is not None: + _profile(*args, config=autotuner.best_config, **kwargs) + +def set_Config_BiShengIR_options(config, bishengir_options): + # BiShengIR Options allowed for autotune + config.multibuffer = bishengir_options.get("multibuffer", None) # Compiler Default True + config.unit_flag = bishengir_options.get("unit_flag", None) # Compiler Default False + config.limit_auto_multi_buffer_only_for_local_buffer = bishengir_options.get("limit_auto_multi_buffer_only_for_local_buffer", None) # Compiler Default False + config.limit_auto_multi_buffer_of_local_buffer = bishengir_options.get("limit_auto_multi_buffer_of_local_buffer", None) # Compiler Default no-limit + config.set_workspace_multibuffer = bishengir_options.get("set_workspace_multibuffer", None) # Compiler Default 1 + config.enable_hivm_auto_cv_balance = bishengir_options.get("enable_hivm_auto_cv_balance", None) # Compiler Default True + config.tile_mix_vector_loop = bishengir_options.get("tile_mix_vector_loop", None) # Compiler Default 1 + config.tile_mix_cube_loop = bishengir_options.get("tile_mix_cube_loop", None) # Compiler Default 1 + +def ext_Config_all_kwargs(config): + return ( + ("multibuffer", config.multibuffer), + ("enable_hivm_auto_cv_balance", config.enable_hivm_auto_cv_balance), + ("unit_flag", config.unit_flag), + ("limit_auto_multi_buffer_only_for_local_buffer", \ + config.limit_auto_multi_buffer_only_for_local_buffer), + ("limit_auto_multi_buffer_of_local_buffer", config.limit_auto_multi_buffer_of_local_buffer), + ("set_workspace_multibuffer", config.set_workspace_multibuffer), + ("tile_mix_vector_loop", config.tile_mix_vector_loop), + ("tile_mix_cube_loop", config.tile_mix_cube_loop) + ) + +def ext_Config_to_str(res, config): + res.append(f"multibuffer: {config.multibuffer}") + res.append(f"enable_hivm_auto_cv_balance: {config.enable_hivm_auto_cv_balance}") + res.append(f"unit_flag: {config.unit_flag}") + res.append(f"limit_auto_multi_buffer_only_for_local_buffer: \ + {config.limit_auto_multi_buffer_only_for_local_buffer}") + res.append(f"limit_auto_multi_buffer_of_local_buffer: {config.limit_auto_multi_buffer_of_local_buffer}") + res.append(f"set_workspace_multibuffer: {config.set_workspace_multibuffer}") + res.append(f"tile_mix_vector_loop: {config.tile_mix_vector_loop}") + res.append(f"tile_mix_cube_loop: {config.tile_mix_cube_loop}") + +def new_AutoTilingTuner(fn, configs, key, reset_to_zero, restore_value, pre_hook, + post_hook, prune_configs_by, warmup, rep, + use_cuda_graph, do_bench, auto_profile_dir, + split_params, tiling_params, low_dims, + dual_reduction, persistent_reduction): + from triton.runtime.autotiling_tuner import AutoTilingTuner + return AutoTilingTuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, 
+def new_AutoTilingTuner(fn, configs, key, reset_to_zero, restore_value, pre_hook,
+                        post_hook, prune_configs_by, warmup, rep,
+                        use_cuda_graph, do_bench, auto_profile_dir,
+                        split_params, tiling_params, low_dims,
+                        dual_reduction, persistent_reduction):
+    from triton.runtime.autotiling_tuner import AutoTilingTuner
+    return AutoTilingTuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value,
+                           pre_hook=pre_hook,
+                           post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
+                           use_cuda_graph=use_cuda_graph, do_bench=do_bench, auto_profile_dir=auto_profile_dir,
+                           split_params=split_params, tiling_params=tiling_params, low_dims=low_dims,
+                           dual_reduction=dual_reduction, persistent_reduction=persistent_reduction)
diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/jit.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/jit.py
new file mode 100644
index 000000000..5dd119f27
--- /dev/null
+++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/runtime/jit.py
@@ -0,0 +1,64 @@
+def is_set_stream_in_kwargs(kwargs):
+    return 'stream' not in kwargs
+
+def is_stream_option_deprecated():
+    return False
+
+def ignore_params_in_JITFunction_run(kwargs, excess_kwargs):
+    ignore_params = ["debug", "sanitize_overflow", "llvm_version", "kernel_name", \
+                     "allowed_dot_input_precisions", "multibuffer", "stream"]
+    not_work_params = []
+    for k in kwargs:
+        if k in ignore_params:
+            continue
+        elif k in excess_kwargs:
+            not_work_params.append(k)
+    if len(not_work_params) != 0:
+        print("[WARNING] Please DO NOT tune args {}!".format(not_work_params))
+
+def set_stream_from_kwargs(kwargs, stream):
+    if ('stream' in kwargs.keys()):
+        return kwargs["stream"]
+    return stream
+
+def check_grid_size(grid_0, grid_1, grid_2):
+    import os
+    grid_all_size = grid_0 * grid_1 * grid_2
+    if os.getenv("TRITON_ALL_BLOCKS_PARALLEL", "0") == "0":
+        if grid_all_size > 65535:
+            raise RuntimeError("total grid size must be less than 65536! You can try \"export TRITON_ALL_BLOCKS_PARALLEL=1\" to avoid this problem.")
+
+def explicit_load_kernel_library(kernel):
+    # explicitly define run method and load kernel binary
+    kernel._init_handles()
+
+def is_JITFunction_spec_attr():
+    return True
+
+def get_JITFunction_spec_attr(deserialized_obj):
+    from triton.backends.ascend.compiler import AscendAttrsDescriptor
+    return AscendAttrsDescriptor.from_dict(deserialized_obj['attrs'])
+
+def maps_line_numbers_to_comment_hints(jit_fn):
+    import tokenize
+    from io import StringIO
+    # Maps line numbers to comment hints
+    line_flagtree_hints = {}
+    code_str = jit_fn.src
+    g = tokenize.generate_tokens(StringIO(code_str).readline)
+    for tok_type, tok_text, start, end, _ in g:
+        if tok_type == tokenize.COMMENT:
+            comment = tok_text.replace(" ", "").strip()
+            if comment.startswith('#@hint:'):
+                flagtree_hints = comment[len('#@hint:'):].strip()
+                # Record the line number of the comment
+                line_num = start[0]
+                line_flagtree_hints[line_num] = flagtree_hints
+
+                # print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}")
+
+    return line_flagtree_hints
+
+def attach_line_number_to_comment_mapping(tree, line_flagtree_hints):
+    # Attach the line number to comment mapping to the function definition node
+    tree.body[0].line_flagtree_hints = line_flagtree_hints
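An end-to-end sketch of the #@hint: comment flow implemented above. The kernel source and the SimpleNamespace stand-in for a JITFunction are illustrative only; in the real path, jit_fn.src holds the decorated function's source and the parsed tree comes from JITFunction.parse():

    import ast
    from types import SimpleNamespace

    src = (
        "@triton.jit\n"
        "def kernel(ptr, BLOCK: tl.constexpr):\n"
        "    for i in range(0, 4):  #@hint: bind_sub_block\n"
        "        x = tl.load(ptr)  #@hint: dot_pad_only_k\n"
    )

    jit_fn = SimpleNamespace(src=src)
    hints = maps_line_numbers_to_comment_hints(jit_fn)
    # -> {3: 'bind_sub_block', 4: 'dot_pad_only_k'}, keyed by 1-based source lines

    tree = ast.parse(src)
    attach_line_number_to_comment_mapping(tree, hints)
    assert tree.body[0].line_flagtree_hints == hints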
diff --git a/third_party/ascend/backend/flagtree_backend_specialization/triton/testing.py b/third_party/ascend/backend/flagtree_backend_specialization/triton/testing.py
new file mode 100644
index 000000000..15c856d9a
--- /dev/null
+++ b/third_party/ascend/backend/flagtree_backend_specialization/triton/testing.py
@@ -0,0 +1,316 @@
+import torch
+import os
+
+def is_do_bench_npu():
+    enable_bench_npu = os.getenv("TRITON_BENCH_METHOD", 'default').lower() == 'npu'
+    if torch.npu.is_available() and enable_bench_npu:
+        return True
+    return False
+
+
+def collect_files(base_dir):
+    import pandas as pd
+    for root, dirs, files in os.walk(base_dir):
+        for file in files:
+            if file != 'op_statistic.csv':
+                continue
+            target_file = os.path.join(root, file)
+            df = pd.read_csv(target_file)
+            triton_rows = df[df['OP Type'].str.startswith('triton', na=False)]
+            if not triton_rows.empty:
+                return triton_rows['Avg Time(us)'].values[0]
+            return float('inf')
+    return float('inf')
+
+
+def collect_single(base_dir: str, key: str = None) -> float:
+    if not os.path.exists(base_dir):
+        return float('inf')
+
+    import pandas as pd
+    for root, _, files in os.walk(base_dir):
+        for file in files:
+            if file != 'op_statistic.csv':
+                continue
+            target_file = os.path.join(root, file)
+            df = pd.read_csv(target_file)
+            if key is not None:
+                key_rows = df[df['OP Type'].str.startswith(key, na=False)]
+                if not key_rows.empty:
+                    return key_rows['Avg Time(us)'].values[0]
+                return float('inf')
+            else:
+                # default: read the first data row (header excluded)
+                return df.loc[0, 'Avg Time(us)']
+
+    return float('inf')
+
+
+def do_bench_npu(fn, warmup=5, active=30, prof_dir=None, keep_res=False):
+    import torch_npu
+    import multiprocessing
+    from triton import runtime
+    from datetime import datetime, timezone
+
+    # warmup kernel
+    fn()
+    torch.npu.synchronize()
+
+    experimental_config = torch_npu.profiler._ExperimentalConfig(
+        aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
+        profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
+        l2_cache=False,
+        data_simplification=False
+    )
+    skip_first = 1
+    wait = 0
+    repeat = 1
+    total = skip_first + (wait + warmup + active) * repeat
+
+    if prof_dir is not None:
+        torch_path = prof_dir
+    else:
+        process = multiprocessing.current_process()
+        pid = process.pid
+        process_name = process.name
+        timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
+        base_path = os.path.join(runtime.cache.get_home_dir(), ".triton", "profile_results")
+        torch_path = os.path.join(base_path, f"prof_{timestamp}_{process_name}-{pid}")
+    with torch_npu.profiler.profile(
+        activities=[
+            torch_npu.profiler.ProfilerActivity.NPU
+        ],
+        schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, skip_first=skip_first),
+        on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_path),
+        record_shapes=False,
+        profile_memory=False,
+        with_stack=False,
+        with_flops=False,
+        with_modules=False,
+        experimental_config=experimental_config,
+    ) as prof:
+        for _ in range(total):
+            fn()
+            prof.step()
+    torch.npu.synchronize()
+
+    time = collect_single(torch_path)
+
+    if not keep_res:
+        import shutil
+        if os.path.exists(torch_path):
+            shutil.rmtree(torch_path)
+
+    return time
+
+
+def ext_do_bench_npu(fn, warmup, rep, quantiles, return_mode):
+    import torch
+    from triton.testing import _summarize_statistics
+    avg_time = do_bench_npu(fn, warmup=max(5, warmup), active=max(30, rep))
+    return _summarize_statistics(torch.tensor([avg_time], dtype=torch.float), quantiles, return_mode)
+
+
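A hedged usage sketch for the NPU benchmarking helpers above; it assumes a working torch_npu/CANN environment with TRITON_BENCH_METHOD=npu exported, the matmul workload is an arbitrary stand-in for a Triton kernel launch, and the profiling directory path is hypothetical:

    import os
    import torch
    import torch_npu  # noqa: F401  (registers the "npu" device)

    os.environ["TRITON_BENCH_METHOD"] = "npu"

    if is_do_bench_npu():
        a = torch.randn(1024, 1024, device="npu")
        b = torch.randn(1024, 1024, device="npu")

        # Average time in microseconds, read back from the generated op_statistic.csv.
        t_us = do_bench_npu(lambda: torch.matmul(a, b), warmup=5, active=30)

        # Passing prof_dir/keep_res mirrors what the autotuner's _profile() path does,
        # so the profiling results stay on disk for inspection (directory name is illustrative).
        t_us = do_bench_npu(lambda: torch.matmul(a, b),
                            prof_dir="/tmp/triton_prof_example", keep_res=True)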
+def patch_triton_language():
+    # Patch the triton language API here because triton's __init__.py
+    # imports testing in the last stages.
+    from triton.language.tensor_descriptor import (
+        tensor_descriptor,
+        tensor_descriptor_type,
+    )
+
+    from triton.language.core_ext import (
+        dot,
+        cast,
+        gather,
+        get_element,
+        insert_slice,
+        extract_slice,
+        trans,
+        __add__,
+        __radd__,
+        __sub__,
+        __rsub__,
+        __mul__,
+        __rmul__,
+        __lshift__,
+        __rshift__,
+        parallel,
+        compile_hint,
+        make_tensor_descriptor,
+        load_tensor_descriptor,
+        store_tensor_descriptor,
+        multibuffer,
+        sync_block_all,
+        sync_block_set,
+        sync_block_wait,
+        dtype_to_ir,
+        sort
+    )
+    from triton.language.standard_ext import flip, sigmoid, softmax, isfinited, finitef, rint, atan2
+    from triton.language.math_ext import (
+        umulhi,
+        exp,
+        exp2,
+        log,
+        log2,
+        cos,
+        sin,
+        sqrt,
+        sqrt_rn,
+        rsqrt,
+        div_rn,
+        erf,
+        tanh,
+        floor,
+        ceil,
+        _check_dtype,
+        fma,
+    )
+    from triton.language.semantic_ext import (
+        arange,
+        floordiv,
+        atom_red_typechecking_impl,
+        atomic_cas,
+        atomic_max,
+        atomic_min,
+        _load_legacy,
+        maximum,
+        minimum,
+        mod,
+        invert,
+        logical_and,
+        logical_or,
+        not_,
+        and_,
+        or_,
+        xor_,
+        minus,
+        dot_scaled,
+    )
+    from triton import language
+
+    language.cast = cast
+    language.dot = dot
+    language.flip = flip
+    language.sigmoid = sigmoid
+    language.softmax = softmax
+    language.gather = gather
+    language.insert_slice = insert_slice
+    language.extract_slice = extract_slice
+    language.get_element = get_element
+    language.tensor.__add__ = __add__
+    language.tensor.__radd__ = __radd__
+    language.tensor.__sub__ = __sub__
+    language.tensor.__rsub__ = __rsub__
+    language.tensor.__mul__ = __mul__
+    language.tensor.__rmul__ = __rmul__
+    language.tensor.__lshift__ = __lshift__
+    language.tensor.__rshift__ = __rshift__
+    language.trans = trans
+    language.parallel = parallel
+    language.compile_hint = compile_hint
+    language.sort = sort
+    language.multibuffer = multibuffer
+    language.sync_block_all = sync_block_all
+    language.sync_block_set = sync_block_set
+    language.sync_block_wait = sync_block_wait
+    language.make_tensor_descriptor = make_tensor_descriptor
+    language.tensor_descriptor = tensor_descriptor
+    language.tensor_descriptor_type = tensor_descriptor_type
+    language.load_tensor_descriptor = load_tensor_descriptor
+    language.store_tensor_descriptor = store_tensor_descriptor
+
+    language.semantic.arange = arange
+    language.semantic.floordiv = floordiv
+    language.semantic.atom_red_typechecking_impl = atom_red_typechecking_impl
+    language.semantic.atomic_cas = atomic_cas
+    language.semantic.atomic_max = atomic_max
+    language.semantic.atomic_min = atomic_min
+    language.semantic._load_legacy = _load_legacy
+    language.semantic.maximum = maximum
+    language.semantic.minimum = minimum
+    language.semantic.invert = invert
+    language.semantic.logical_and = logical_and
+    language.semantic.logical_or = logical_or
+    language.semantic.mod = mod
+    language.semantic.not_ = not_
+    language.semantic.and_ = and_
+    language.semantic.or_ = or_
+    language.semantic.xor_ = xor_
+    language.semantic.minus = minus
+    language.semantic.dot_scaled = dot_scaled
+
+    language.umulhi = umulhi
+    language.exp = exp
+    language.exp2 = exp2
+    language.log = log
+    language.log2 = log2
+    language.cos = cos
+    language.sin = sin
+    language.sqrt = sqrt
+    language.sqrt_rn = sqrt_rn
+    language.rsqrt = rsqrt
+    language.div_rn = div_rn
+    language.erf = erf
+    language.tanh = tanh
+    language.floor = floor
+    language.ceil = ceil
+    language.core.dtype.to_ir = dtype_to_ir
+    language.fma = fma
+    language.math.umulhi = umulhi
+    language.math.exp = exp
+    language.math.exp2 = exp2
+    language.math.log = log
+    language.math.log2 = log2
+    language.math.cos = cos
+    language.math.sin = sin
+    language.math.sqrt = sqrt
+    language.math.sqrt_rn = sqrt_rn
+    language.math.rsqrt = rsqrt
+    language.math.div_rn = div_rn
+    language.math.erf = erf
+    language.math.tanh = tanh
+    language.math.floor = floor
+    language.math.ceil = ceil
+    language.math._check_dtype = _check_dtype
+    language.math.fma = fma
+    language.math.isnan = language.extra.ascend.libdevice.isnan
+    language.math.isinf = language.extra.ascend.libdevice.isinf
+    language.math.reciprocal = language.extra.ascend.libdevice.reciprocal
+    language.math.log1p = language.extra.ascend.libdevice.log1p
+    language.math.relu = language.extra.ascend.libdevice.relu
+    language.math.tan = language.extra.ascend.libdevice.tan
+    language.math.atan = language.extra.ascend.libdevice.atan
+    language.math.tanh = language.extra.ascend.libdevice.tanh
+    language.math.ilogb = language.extra.ascend.libdevice.ilogb
+    language.math.ldexp = language.extra.ascend.libdevice.ldexp
+    language.math.pow = language.extra.ascend.libdevice.pow
+    language.math.flip = language.extra.ascend.libdevice.flip
+    language.math.atan2 = language.extra.ascend.libdevice.atan2
+    language.math.div_rz = language.extra.ascend.libdevice.div_rz
+    language.math.fmod = language.extra.ascend.libdevice.fmod
+    language.math.trunc = language.extra.ascend.libdevice.trunc
+    language.math.round = language.extra.ascend.libdevice.round
+    language.math.finitef = finitef
+    language.math.isfinited = isfinited
+    language.math.rint = rint
+    language.math.atan2 = atan2
+    language.extra.ascend.libdevice.umulhi = language.math.umulhi
+    language.extra.ascend.libdevice.exp = language.math.exp
+    language.extra.ascend.libdevice.exp2 = language.math.exp2
+    language.extra.ascend.libdevice.log = language.math.log
+    language.extra.ascend.libdevice.log2 = language.math.log2
+    language.extra.ascend.libdevice.cos = language.math.cos
+    language.extra.ascend.libdevice.sin = language.math.sin
+    language.extra.ascend.libdevice.sqrt = language.math.sqrt
+    language.extra.ascend.libdevice.sqrt_rn = language.math.sqrt_rn
+    language.extra.ascend.libdevice.rsqrt = language.math.rsqrt
+    language.extra.ascend.libdevice.div_rn = language.math.div_rn
+    language.extra.ascend.libdevice.erf = language.math.erf
+    language.extra.ascend.libdevice.tanh = language.math.tanh
+    language.extra.ascend.libdevice.floor = language.math.floor
+    language.extra.ascend.libdevice.ceil = language.math.ceil
+    language.extra.ascend.libdevice.fdiv = language.math.fdiv
+    language.extra.ascend.libdevice.fma = language.math.fma
+    language.extra.ascend.libdevice.abs = language.math.abs
\ No newline at end of file
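Finally, a small sketch of how the monkey-patching above is expected to take effect. The exact call site is backend initialization (not shown in this hunk), so invoking it directly here is purely illustrative:

    from triton import language

    patch_triton_language()

    # The public entry points now resolve to the Ascend-specific implementations
    # wired up in the function above.
    from triton.language.math_ext import exp as ascend_exp
    assert language.exp is ascend_exp
    assert language.math.log1p is language.extra.ascend.libdevice.log1p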