Skip to content

Commit 2bd6c6f

Browse files
authored
[Frontend] Treat JITFunction as constexpr type (#6988)
This teaches a few more places in the frontend to treat JITFunction as constexprs. This allows, for example, passing lists of functions as constexprs.
1 parent 1143c03 commit 2bd6c6f

File tree

4 files changed

+30
-1
lines changed

4 files changed

+30
-1
lines changed

python/test/unit/language/test_frontend.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,26 @@ def test_aggregate_initializers():
185185
# CHECK: call @"anchor{{.*}}"([[RANGE]])
186186
value.modify(tl.arange(4, 8))
187187
anchor(value)
188+
189+
190+
@triton.jit
191+
def forward(arg):
192+
return arg
193+
194+
195+
@triton.jit
196+
def list_of_functions_constexpr(arg, fns: tl.constexpr):
197+
for i in tl.static_range(len(fns)):
198+
fns[i](arg)
199+
200+
201+
@filecheck_test
202+
@triton.jit
203+
def test_list_of_functions():
204+
# CHECK-LABEL: test_list_of_functions
205+
# CHECK: call @"list_of_functions_constexpr{{.*}}cJITFunction(test_frontend:anchor){{.*}}cJITFunction(test_frontend:forward)"
206+
207+
# CHECK-LABEL: tt.func private @"list_of_functions_constexpr
208+
# CHECK-NEXT: call @anchor
209+
# CHECK-NEXT: call @forward
210+
list_of_functions_constexpr(tl.arange(0, 4), [anchor, forward])

python/triton/compiler/code_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def _is_triton_tensor(o: Any) -> bool:
4848

4949

5050
def _is_constexpr(o: Any) -> bool:
51-
return o is None or isinstance(o, (constexpr, language.core.dtype))
51+
return o is None or isinstance(o, (constexpr, language.core.dtype, JITFunction))
5252

5353

5454
def _is_non_scalar_tensor(o: Any) -> bool:

python/triton/language/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1576,6 +1576,7 @@ def type(self):
15761576
aggregate_value.__name__ = cls.__name__
15771577
aggregate_value.__module__ = cls.__module__
15781578
aggregate_value.__qualname__ = cls.__qualname__
1579+
aggregate_value.__doc__ = cls.__doc__
15791580

15801581
return aggregate_value
15811582

python/triton/runtime/jit.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,11 @@ def cache_key(self):
711711
self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items()))
712712
return self.hash
713713

714+
@property
715+
def type(self):
716+
from triton.language.core import constexpr
717+
return constexpr
718+
714719
def warmup(self, *args, grid, **kwargs):
715720
return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs)
716721

0 commit comments

Comments
 (0)