
Commit 518b26e

[Frontend] Separate tensor and ir value into two different concepts (#4854)
Currently in the frontend all IR values are tensors, but for the TMA work I would like the tensor descriptor to be its own distinct type. This just adds a new base class representing a generic IR value and makes tensor inherit from it.
1 parent: 3a9ddea · commit: 518b26e
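For orientation, the sketch below mirrors the hierarchy this commit introduces. The `_value` and `tensor` stand-ins and the two isinstance helpers follow the diff below; `tensor_descriptor` is purely hypothetical, illustrating the kind of non-tensor IR value (such as a future TMA tensor descriptor) that motivates the split, and is not part of this commit.

# Minimal sketch (stand-in classes, not the real Triton ones).
# `tensor_descriptor` is hypothetical and only illustrates the motivation.
class _value:
    """Anything backed by a handle in the Triton IR (i.e. not a constexpr)."""

    def __init__(self, handle):
        self.handle = handle


class tensor(_value):
    """An N-dimensional block of values or pointers."""


class tensor_descriptor(_value):
    """Hypothetical non-tensor IR value, e.g. a TMA tensor descriptor."""


def _is_triton_value(o) -> bool:
    return isinstance(o, _value)   # any IR-backed value


def _is_triton_tensor(o) -> bool:
    return isinstance(o, tensor)   # specifically a tensor


desc = tensor_descriptor(handle=object())
assert _is_triton_value(desc) and not _is_triton_tensor(desc)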

2 files changed: +37 −22 lines

python/triton/compiler/code_generator.py

Lines changed: 27 additions & 20 deletions
@@ -9,7 +9,7 @@
 from .. import language
 from .._C.libtriton import ir
 from ..language import constexpr, tensor, str_to_ty
-from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type
+from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, _value
 from ..runtime.jit import _normalize_ty, get_jit_fn_file_line
 # ideally we wouldn't need any runtime component
 from ..runtime import JITFunction
@@ -47,6 +47,10 @@ def mangle_fn(name, arg_tys, constants):
     return ret


+def _is_triton_value(o: Any) -> bool:
+    return isinstance(o, _value)
+
+
 def _is_triton_tensor(o: Any) -> bool:
     return isinstance(o, tensor)

@@ -501,7 +505,7 @@ def visit_Assign(self, node):
             # by default, constexpr are assigned into python variable
             value = _unwrap_if_constexpr(value)
             if value is not None and \
-                    not _is_triton_tensor(value) and \
+                    not _is_triton_value(value) and \
                     not isinstance(value, native_nontensor_types):
                 value = language.semantic.to_tensor(value, self.builder)
             self.set_value(name, value)
@@ -802,6 +806,15 @@ def visit_UnaryOp(self, node):
         ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'
     }

+    def _verify_loop_carried_variable(self, name, loop_val, live_val):
+        assert _is_triton_value(loop_val), f'cannot reassign constxpr {name} in the loop'
+        assert _is_triton_value(live_val), f'cannot reasign constexpr {name} in the loop'
+        assert type(loop_val) == type(live_val), f'Loop carried variable {name} changed type'
+        assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \
+            f'Loop-carried variable {name} has initial type {live_val.type} '\
+            f'but is re-assigned to {loop_val.type} in loop! '\
+            f'Please make sure that the type stays consistent.'
+
     def visit_While(self, node):
         with enter_sub_region(self) as sr:
             liveins, insert_block = sr
@@ -824,17 +837,14 @@ def visit_While(self, node):
             for name in loop_defs:
                 if name in liveins:
                     # We should not def new constexpr
-                    assert _is_triton_tensor(loop_defs[name]), f'cannot reassign constxpr {name} in the loop'
-                    assert _is_triton_tensor(liveins[name]), f'cannot reasign constexpr {name} in the loop'
-                    assert loop_defs[name].type == liveins[name].type, \
-                        f'Loop-carried variable {name} has initial type {liveins[name].type} '\
-                        f'but is re-assigned to {loop_defs[name].type} in loop! '\
-                        f'Please make sure that the type stays consistent.'
+                    loop_val = loop_defs[name]
+                    live_val = liveins[name]
+                    self._verify_loop_carried_variable(name, loop_val, live_val)

                     # these are loop-carried values
                     names.append(name)
-                    ret_types.append(loop_defs[name].type)
-                    init_args.append(liveins[name])
+                    ret_types.append(loop_val.type)
+                    init_args.append(live_val)

             self._set_insertion_point_and_loc(ip, last_loc)
             while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
@@ -972,16 +982,13 @@ def visit_For(self, node):
             names = []
             for name in self.local_defs:
                 if name in liveins:
-                    assert _is_triton_tensor(self.local_defs[name]), f'cannot reassign constxpr {name} in the loop'
-                    assert _is_triton_tensor(liveins[name]), f'cannot reassign constxpr {name} in the loop'
-                    assert self.local_defs[name].type == liveins[name].type, \
-                        f'Loop-carried variable {name} has initial type {liveins[name].type} '\
-                        f'but is re-assigned to {self.local_defs[name].type} in loop! '\
-                        f'Please make sure that the type stays consistent.'
+                    loop_val = self.local_defs[name]
+                    live_val = liveins[name]
+                    self._verify_loop_carried_variable(name, loop_val, live_val)

                     names.append(name)
-                    init_args.append(language.semantic.to_tensor(liveins[name], self.builder))
-                    yields.append(language.semantic.to_tensor(self.local_defs[name], self.builder))
+                    init_args.append(live_val)
+                    yields.append(loop_val)

             # create ForOp
             self._set_insertion_point_and_loc(ip, last_loc)
@@ -1051,7 +1058,7 @@ def visit_Assert(self, node) -> Any:
     def call_JitFunction(self, fn: JITFunction, args, kwargs):
         args = inspect.getcallargs(fn.fn, *args, **kwargs)
         args = [args[name] for name in fn.arg_names]
-        args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args]
+        args = [arg if _is_triton_value(arg) else constexpr(arg) for arg in args]
         # generate function def
         attributes = {}
         constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
@@ -1110,7 +1117,7 @@ def visit_Call(self, node):
         if isinstance(fn, JITFunction):
             _check_fn_args(node, fn, args)
             return self.call_JitFunction(fn, args, kws)
-        if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn):
+        if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn):
             extra_kwargs = {"_builder": self.builder}
             sig = inspect.signature(fn)
             if '_generator' in sig.parameters:
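To make the consolidated check above concrete, here is a standalone sketch of the rules `_verify_loop_carried_variable` enforces, using stand-in classes (assumed names and a simplified `.type` attribute, not the real CodeGenerator or triton.language objects): both sides must be IR values rather than constexprs, their Python types must match, and a tensor's type may not change across loop iterations.

# Standalone sketch of the loop-carried-variable rules; stand-in classes only.
from typing import Any


class _value:
    def __init__(self, handle):
        self.handle = handle


class tensor(_value):
    def __init__(self, handle, type):
        super().__init__(handle)
        self.type = type  # stands in for the tensor's block/element type


def _is_triton_value(o: Any) -> bool:
    return isinstance(o, _value)


def _is_triton_tensor(o: Any) -> bool:
    return isinstance(o, tensor)


def verify_loop_carried_variable(name, loop_val, live_val):
    # Neither side may be a constexpr: both must live in the IR.
    assert _is_triton_value(loop_val), f'cannot reassign constexpr {name} in the loop'
    assert _is_triton_value(live_val), f'cannot reassign constexpr {name} in the loop'
    # The kind of IR value must stay the same (tensor stays tensor, etc.).
    assert type(loop_val) == type(live_val), f'Loop carried variable {name} changed type'
    # For tensors, the tensor type itself must also stay the same.
    assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \
        f'Loop-carried variable {name} changed type from {live_val.type} to {loop_val.type}'


# A loop-carried tensor with a consistent type passes; a type change would raise.
verify_loop_carried_variable('acc', tensor(object(), 'fp32'), tensor(object(), 'fp32'))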

python/triton/language/core.py

Lines changed: 10 additions & 2 deletions
@@ -701,12 +701,20 @@ def get_int_dtype(bitwidth: int, signed: bool) -> dtype:
     raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}')


+class _value:
+    """Base class of values that exist in the triton IR (i.e. not constexprs).
+    """
+
+    def __init__(self, handle):
+        self.handle = handle
+
+
 # -----------------------
 # tensor
 # -----------------------


-class tensor:
+class tensor(_value):
     """Represents an N-dimensional array of values or pointers.

     :code:`tensor` is the fundamental data structure in Triton programs. Most
@@ -729,7 +737,7 @@ class tensor:
     def __init__(self, handle, type: dtype):
         """Not called by user code."""
         # IR handle
-        self.handle = handle
+        super().__init__(handle)
         # Block shape
         self.shape = type.shape if type.is_block() else ()
         self.numel = 1
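A quick, hedged sanity check one could run against a Triton build that already contains this commit (module paths as they appear in the diff above):

# Requires a Triton build that includes this commit.
from triton.language import core as tl_core

# tensor now derives from the generic IR-value base class...
assert issubclass(tl_core.tensor, tl_core._value)
# ...while constexpr stays outside the IR-value hierarchy.
assert not issubclass(tl_core.constexpr, tl_core._value)
print("tensor is a _value; constexpr is not")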
