Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 46 additions & 46 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,12 @@ def visit_Return(self, node: ast.Return) -> bool:

def visit_Assign(self, node: ast.Assign) -> bool:
    """Report whether a plain assignment (``x = ...``) holds an early return.

    An assignment statement cannot contain a ``return``, so the answer is
    always ``False``.
    """
    return False

def visit_AugAssign(self, node: ast.AugAssign) -> bool:
    """Report whether an augmented assignment (``x += ...``) holds an early return.

    An augmented assignment cannot contain a ``return``, so the answer is
    always ``False``.
    """
    return False

def visit_Module(self, node: ast.Module) -> bool:
Expand All @@ -168,6 +170,13 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> bool:
return self._visit_stmts(node.body)

def visit_If(self, node: ast.If) -> bool:
# TODO: optimize the following case in which we actually don't have
# a return when static_cond is false:
# if dynamic_cond
# if static_cond
# func_with_return
# else
# func_without_return
ret = self._visit_stmts(node.body)
if node.orelse:
ret = ret or self._visit_stmts(node.orelse)
Expand All @@ -192,6 +201,9 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n
self.begin_line = begin_line - 1
self.builder.set_loc(file_name, begin_line, 0)
self.builder.options = options
# Dict of functions provided by the backend. Below is the list of possible functions:
# Convert custom types not natively supported on HW.
# convert_custom_types(input_tensor, dtype, fp_downcast_rounding=None, _builder=None)
self.builder.codegen_fns = codegen_fns
self.builder.module_map = {} if module_map is None else module_map
self.module = self.builder.create_module() if module is None else module
Expand Down Expand Up @@ -474,7 +486,10 @@ def visit_AnnAssign(self, node):
return self.visit_Assign(node)

def visit_Assign(self, node):
# flagtree: First, do normal assignment processing
# flagtree backend specialization
from triton.runtime.driver import flagtree_backend_specialization
flagtree_backend_specialization("anno_CodeGenerator_visit_Assign")

_names = []
if isinstance(node, ast.AnnAssign):
_names += [self.visit(node.target)]
Expand All @@ -498,30 +513,10 @@ def visit_Assign(self, node):
not isinstance(value, native_nontensor_types):
value = language.semantic.to_tensor(value, self.builder)
self.set_value(name, value)

# flagtree: After normal processing, check if we need to add hint annotation
if hasattr(node, 'lineno') and hasattr(self, 'jit_fn'):
line_num = node.lineno
# TODO: reparse needed in case we need to deal with complex cases, will be redesigned later
function_def = self.jit_fn.parse()
line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {})
flagtree_hints = line_flagtree_hints.get(line_num)

# Check if this is a tl.load call with dot_pad_only_k hint
if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and
isinstance(node.value, ast.Call) and
isinstance(node.value.func, ast.Attribute) and
isinstance(node.value.func.value, ast.Name) and
node.value.func.value.id == 'tl' and
node.value.func.attr == 'load'):

# Add hint annotation to the loaded tensor(s)
for name, value in zip(names, values):
if _is_triton_value(value):
# print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}")
# Create hint annotation
hint_val = self.builder.get_unit_attr()
self.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val)

#flagtree backend specialization
from triton.runtime.driver import flagtree_backend_specialization
flagtree_backend_specialization("ext_CodeGenerator_visit_Assign_hint_anno", self, node, names, values)

def visit_AugAssign(self, node):
name = node.target.id
Expand Down Expand Up @@ -828,6 +823,8 @@ def visit_While(self, node):
liveins, insert_block = sr
ip, last_loc = self._get_insertion_point_and_loc()

# loop body (the after region)
# loop_block = self.builder.create_block()
dummy = self.builder.create_block()
self.builder.set_insertion_point_to_start(dummy)
self.scf_stack.append(node)
Expand Down Expand Up @@ -921,8 +918,11 @@ def visit_For(self, node):
return
num_stages = None
loop_unroll_factor = None
bind_sub_block = None
if IteratorClass in [language.range, language.parallel]:

#flagtree backend specialization
from triton.runtime.driver import flagtree_backend_specialization
bind_sub_block = flagtree_backend_specialization("init_bind_sub_block")
if IteratorClass in [language.range] + ([language.parallel] if flagtree_backend_specialization("is_visit_For_support_parallel") else []):
iterator = IteratorClass(*iter_args, **iter_kwargs)
# visit iterator arguments
# note: only `range` iterator is supported now
Expand All @@ -932,8 +932,9 @@ def visit_For(self, node):
step = iterator.step
num_stages = iterator.num_stages
loop_unroll_factor = iterator.loop_unroll_factor
if (IteratorClass is language.parallel):
bind_sub_block = iterator.bind_sub_block

#flagtree backend specialization
bind_sub_block = flagtree_backend_specialization("set_bind_sub_block_when_parallel", IteratorClass, iterator, bind_sub_block)
elif IteratorClass is range:
# visit iterator arguments
# note: only `range` iterator is supported now
Expand All @@ -943,20 +944,10 @@ def visit_For(self, node):
step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1))
else:
raise RuntimeError('Only `range` and `static_range` iterators are currently supported')

# flagtree: After normal processing, check if we need to override bind_sub_block
if hasattr(node, 'lineno') and hasattr(self, 'jit_fn'):
line_num = node.lineno
# TODO: reparse needed in case we need to deal with complex cases, will be redesigned later
function_def = self.jit_fn.parse()
line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {})
flagtree_hints = line_flagtree_hints.get(line_num)

# Check if this is a range/for loop with bind_sub_block hint
if flagtree_hints and 'bind_sub_block' in flagtree_hints:
bind_sub_block = True
# print(f"[FLAGTREE] Found bind_sub_block hint at line {line_num}")


#flagtree backend specialization
bind_sub_block = flagtree_backend_specialization("check_override_bind_sub_block", self, node, bind_sub_block)

# handle negative constant step (not supported by scf.for in MLIR)
negative_step = False
if _is_constexpr(step) and step.value < 0:
Expand Down Expand Up @@ -1021,7 +1012,8 @@ def visit_For(self, node):
if loop_unroll_factor is not None:
for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor))
if (bind_sub_block is not None) and bind_sub_block:
for_op.set_attr("bind_sub_block", self.builder.get_bool_attr(bind_sub_block))
#flagtree backend specialization
flagtree_backend_specialization("forop_setattr_for_bind_sub_block", self, for_op, bind_sub_block)

self.scf_stack.append(node)
self.builder.set_insertion_point_to_start(for_op.get_body(0))
Expand Down Expand Up @@ -1105,7 +1097,11 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
generator.visit(fn.parse())
except Exception as e:
# Wrap the error in the callee with the location of the call.
raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from e

#flagtree backend specialization
from triton.runtime.driver import flagtree_backend_specialization
raise CompilationError(self.jit_fn.src, self.cur_node,
repr(e) if flagtree_backend_specialization('need_repr_in_CodeGenerator_CompilationError') else None) from e

callee_ret_type = generator.ret_type
self.function_ret_types[fn_name] = callee_ret_type
Expand Down Expand Up @@ -1149,7 +1145,11 @@ def visit_Call(self, node):
# itself). But when calling a function, we raise as `from e` to
# preserve the traceback of the original error, which may e.g.
# be in core.py.
raise CompilationError(self.jit_fn.src, node, repr(e)) from e

#flagtree backend specialization
from triton.runtime.driver import flagtree_backend_specialization
raise CompilationError(self.jit_fn.src, node,
repr(e) if flagtree_backend_specialization('need_repr_in_CodeGenerator_CompilationError') else None) from e

if fn in self.builtin_namespace.values():
args = map(_unwrap_if_constexpr, args)
Expand Down
46 changes: 23 additions & 23 deletions python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@
from .._C.libtriton import get_cache_invalidating_env_vars, ir
from ..backends import backends
from ..backends.compiler import GPUTarget, AttrsDescriptor
from ..backends.ascend.compiler import AscendAttrsDescriptor
from .. import __version__
from ..runtime.autotuner import OutOfResources
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
from ..runtime.driver import driver
from ..tools.disasm import get_sass
# TODO: this shouldn't be here
from .code_generator import ast_to_ttir
from .errors import MLIRCompilationError
from pathlib import Path
import re
import functools
Expand Down Expand Up @@ -87,8 +85,9 @@ def __init__(self, fn, signature, constants=None, attrs=None) -> None:
for k in self.constants.keys():
if not isinstance(k, str):
raise TypeError("Constants keys must be string")
if self.attrs is None:
self.attrs = AscendAttrsDescriptor()
# flagtree backend specialization
from triton.runtime.driver import flagtree_backend_specialization
flagtree_backend_specialization("ext_ASTSource_attrs", self)

def hash(self):
sorted_sig = [v for k, v in sorted(self.signature.items())]
Expand Down Expand Up @@ -252,12 +251,11 @@ def compile(src, target=None, options=None):
# cache hit!
metadata = json.loads(Path(metadata_path).read_text())
return CompiledKernel(src, metadata_group, hash)
compile_speed_opt = os.getenv("TRITON_ASCEND_COMPILE_SPEED_OPT", 'false').lower() in ('true', '1')
if (compile_speed_opt):
ttir_path = f"{file_name}.ttir"
if (metadata_path is None) and (fn_cache_manager.has_file(ttir_path)):
# Already compiled once and failed, so raise immediately instead of retrying.
raise Exception("already failed once")

# flagtree backend specialization
from triton.runtime.driver import flagtree_backend_specialization
flagtree_backend_specialization("opt_ascend_compile_speed", file_name, metadata_path, fn_cache_manager)

# initialize metadata
metadata = {
"hash": hash,
Expand Down Expand Up @@ -287,14 +285,10 @@ def compile(src, target=None, options=None):
try:
next_module = compile_ir(module, metadata)
except Exception as e:
if (ext == "ttadapter"):
stage_name = "ConvertTritonIRToLinalgIR"
elif (ext == "npubin"):
stage_name = "ConvertLinalgRToBinary"
else:
stage_name = "MLIRCompile"
error_detail = e.stderr.decode('utf-8') if hasattr(e, 'stderr') and e.stderr else str(e)
raise MLIRCompilationError(stage_name, error_detail)
# flagtree backend specialization
from triton.runtime.driver import flagtree_backend_specialization
flagtree_backend_specialization("handle_compile_error", e, ext)

ir_filename = f"{file_name}.{ext}"
if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None):
print(f"\nOverriding kernel with file {full_name}")
Expand Down Expand Up @@ -406,9 +400,12 @@ def _init_handles(self):
# TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
self.name, self.kernel, self.metadata.shared, device)

# This mechanism introduces heavy runtime overhead.
# Commenting __getattribute__ requires explicitly calling _init_handles()
def __getattribute__(self, name):
    """Intercept attribute access so handles can be initialized before ``run`` is used.

    When ``name`` is ``'run'`` and the
    ``is_CompiledKernel_getattribute_need_init_handles`` specialization hook
    returns a truthy value, ``self._init_handles()`` is called before the
    attribute is looked up. Every other attribute goes straight to the
    default lookup.

    Args:
        name: Name of the attribute being accessed.

    Returns:
        The attribute value as resolved by the default ``__getattribute__``.
    """
    if name == 'run':
        # This hook runs on every attribute access, so the import is kept on
        # the ``run`` path only; other lookups no longer pay for it.
        # flagtree backend specialization
        from triton.runtime.driver import flagtree_backend_specialization
        if flagtree_backend_specialization("is_CompiledKernel_getattribute_need_init_handles"):
            self._init_handles()
    return super().__getattribute__(name)

def launch_metadata(self, grid, stream, *args):
if CompiledKernel.launch_enter_hook is None:
Expand All @@ -431,8 +428,11 @@ def __getitem__(self, grid):
self._init_handles()

def runner(*args, stream=None):
if stream is None:
stream = self.metadata.stream

# flagtree backend specialization
from triton.runtime.driver import flagtree_backend_specialization
flagtree_backend_specialization("set_CompiledKernel_metadata_stream", self, stream)

if stream is None:
device = driver.active.get_current_device()
stream = driver.active.get_current_stream(device)
Expand Down
18 changes: 3 additions & 15 deletions python/triton/compiler/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,6 @@ class UnsupportedLanguageConstruct(CompilationError):
pass


class MLIRCompilationError(TritonError):
    """Error describing a failure in one stage of MLIR compilation.

    The rendered text places the stage name and the filtered error output
    between ``[ERROR][Triton][BEG]`` and ``[ERROR][Triton][END]`` delimiter
    lines.

    Args:
        stage_name: Name of the failing stage, shown in square brackets.
        message: Raw error text. ``None`` (the default) is treated as empty
            text instead of raising ``AttributeError`` during formatting.
    """

    # Everything from this marker onwards is dropped from the reported message.
    _STACK_DUMP_MARKER = "Stack dump without symbol names"

    def __init__(self, stage_name: Optional[str], message: Optional[str] = None):
        self.stage_name = stage_name
        self.message = "\n" \
                       f"{self.format_line_delim('[ERROR][Triton][BEG]')}" \
                       f"[{self.stage_name}] encounters error:\n" \
                       f"{self.filter_message(message)}" \
                       f"{self.format_line_delim('[ERROR][Triton][END]')}"

    def __str__(self):
        return self.message

    def filter_message(self, message):
        """Return ``message`` truncated at the stack-dump marker ("" for ``None``)."""
        if message is None:
            return ""
        # Content starting from "Stack dump without symbol names" means nothing to the users
        return message.partition(self._STACK_DUMP_MARKER)[0]

    def format_line_delim(self, keyword):
        """Return one delimiter line containing ``keyword``, newline-terminated."""
        return f"///------------------{keyword}------------------\n"
# flagtree backend specialization
from triton.runtime.driver import flagtree_backend_class_specialization
MLIRCompilationError = flagtree_backend_class_specialization("MLIRCompilationError")
39 changes: 20 additions & 19 deletions python/triton/language/_utils.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,40 @@
from __future__ import annotations

from typing import List, TYPE_CHECKING, Any, Union, Dict
from typing import List

# flagtree backend specialization
from triton.runtime.driver import flagtree_backend_specialization
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .language import core
IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type]
ObjPath = tuple[int, ...]
IterableType, ObjPath = flagtree_backend_specialization('get_language_utils_IterableType_ObjPath')


TRITON_MAX_TENSOR_NUMEL = flagtree_backend_specialization('get_triton_max_tensor_numel')


def is_power_of_two(x):
    """Return True when ``x`` has at most one bit set.

    Note that ``0`` also satisfies this test.
    """
    without_lowest_bit = x & (x - 1)
    return without_lowest_bit == 0

TRITON_MAX_TENSOR_NUMEL = 1048576

def validate_block_shape(shape: List[int]):
    """Validate a block shape and return its total element count.

    Args:
        shape: Block dimensions; each element must be a Python ``int``.

    Returns:
        The product of all dimensions (``1`` for an empty shape).

    Raises:
        TypeError: If a dimension is not an ``int``.
        ValueError: If the ``is_block_shape_check_power_of_two``
            specialization hook is truthy and a dimension fails
            ``is_power_of_two``, or if the element count exceeds
            ``TRITON_MAX_TENSOR_NUMEL``.
    """
    numel = 1
    for i, d in enumerate(shape):
        if not isinstance(d, int):
            # The closing backtick after `constexpr[...]` was missing in the message.
            raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]`")
        # The power-of-two requirement is gated by a specialization hook.
        if flagtree_backend_specialization('is_block_shape_check_power_of_two') and not is_power_of_two(d):
            raise ValueError(f"Shape element {i} must be a power of 2")
        numel *= d

    if numel > TRITON_MAX_TENSOR_NUMEL:
        raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")
    return numel


BITWIDTH_DICT: Dict[str, int] = {
**{f"u{n}": n
for n in (1, 8, 16, 32, 64)},
**{f"i{n}": n
for n in (1, 8, 16, 32, 64)},
**{f"fp{n}": n
for n in (16, 32, 64)},
**{f"fp8{suffix}": 8
for suffix in ("e4nv", "e4b15", "e4b8", "e5", "e5b16")},
"bf16": 16,
"void": 0,
}
# flagtree backend specialization
from triton.runtime.driver import flagtree_backend_specialization
BITWIDTH_DICT = flagtree_backend_specialization('get_language_utils_BITWIDTH_DICT')


def get_primitive_bitwidth(dtype: str) -> int:
    """Look up the bit width registered for ``dtype`` in ``BITWIDTH_DICT``.

    Raises:
        KeyError: If ``dtype`` is not a key of ``BITWIDTH_DICT``.
    """
    bitwidth = BITWIDTH_DICT[dtype]
    return bitwidth
# flagtree backend specialization
from triton.runtime.driver import flagtree_backend_func_specialization
get_primitive_bitwidth = flagtree_backend_func_specialization("get_primitive_bitwidth")
Loading