Skip to content

Commit 28eba22

Browse files
agron911meta-codesync[bot]
authored andcommitted
[Cherry-pick][RESOLVED] [FRONTEND] Wrap non-jit function call results in constexpr automatically (#8002) (#544)
Summary: ⚠️ **MERGE CONFLICTS DETECTED** ⚠️ This cherry-pick contains merge conflicts that require manual resolution. Original Commit: b93eefd Original Author: peterbell10 Original Date: 2025-08-28 22:30:24 +0100 **Action Required:** 1. Check out this branch locally 2. Resolve the merge conflicts in the affected files 3. Commit the resolved changes 4. Update this PR Original commit message: ``` [FRONTEND] Wrap non-jit function call results in constexpr automatically (#8002) Fixes triton-lang/triton#8001 (comment) ``` This PR was automatically cherry-picked from the upstream triton-lang/triton repository. The conflicts have been committed with conflict markers for easier resolution. Pull Request resolved: #544 Reviewed By: dshi7 Differential Revision: D86147298 Pulled By: agron911 fbshipit-source-id: 7987bddfa82532da67243016570ad14f91abb2ce
1 parent a97a266 commit 28eba22

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

python/test/gluon/test_frontend.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2250,3 +2250,11 @@ def test_infer_layout_for_padded_shared(target):
22502250
}
22512251
}
22522252
""")
2253+
2254+
2255+
@filecheck_test
2256+
@gluon.jit
2257+
def test_layout_zeros():
2258+
# CHECK: #blocked = #ttg.blocked
2259+
# CHECK: arith.constant dense<0.000000e+00> : tensor<128xf32, #blocked>
2260+
ttgl.zeros([128], ttgl.float32, layout=ttgl.BlockedLayout([1], [32], [4], [0]))

python/triton/compiler/code_generator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ast
2+
import builtins
23
import contextlib
34
import copy
45
import inspect
@@ -1373,7 +1374,15 @@ def call_Function(self, node, fn, args, kws):
13731374
if fn in self.builtin_namespace.values():
13741375
args = map(_unwrap_if_constexpr, args)
13751376
ret = fn(*args, **kws)
1376-
return _apply_to_tuple_values(ret, lambda x: x) if _is_namedtuple(type(ret)) else ret
1377+
1378+
def wrap_constexpr(x):
1379+
if _is_triton_value(x):
1380+
return x
1381+
return constexpr(x)
1382+
1383+
if isinstance(ret, (builtins.tuple, language.tuple)):
1384+
return _apply_to_tuple_values(ret, wrap_constexpr)
1385+
return wrap_constexpr(ret)
13771386

13781387
def call_Method(self, node, fn, fn_self, args, kws):
13791388
if isinstance(fn, JITFunction):

python/triton/language/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3183,7 +3183,7 @@ def kernel(A, B, C, D, BLOCK: tl.constexpr):
31833183
# -----------------------
31843184

31853185

3186-
class static_range:
3186+
class static_range(base_value):
31873187
"""
31883188
Iterator that counts upward forever.
31893189
@@ -3223,7 +3223,7 @@ def __next__(self):
32233223
raise RuntimeError("static_range can only be used in @triton.jit'd functions")
32243224

32253225

3226-
class range:
3226+
class range(base_value):
32273227
"""
32283228
Iterator that counts upward forever.
32293229
@@ -3293,7 +3293,7 @@ def __next__(self):
32933293
raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
32943294

32953295

3296-
class condition:
3296+
class condition(base_value):
32973297
"""
32983298
While loop condition wrapper.
32993299

0 commit comments

Comments
 (0)