Skip to content

Commit 4b1e177

Browse files
authored
[STDLIB][NFC] Use constexpr_function decorator on our own functions (#7621)
1 parent 70d83a4 commit 4b1e177

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

python/triton/language/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,8 @@ def wrapper(*args, _semantic=None, **kwargs):
351351
res = f(*args, **kwargs)
352352

353353
# convert result back to a Triton constexpr:
354+
if knobs.runtime.interpret:
355+
return res # No constexpr in interpreter
354356
return constexpr(res)
355357

356358
# disguise the function as a Triton builtin to avoid raising an error

python/triton/language/standard.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,19 @@
77
# constexpr utilities
88

99

10-
def _log2(i: core.constexpr):
10+
@core.constexpr_function
11+
def _log2(i):
1112
log2 = 0
12-
n = core.constexpr(i).value
13+
n = i
1314
while n > 1:
1415
n >>= 1
1516
log2 += 1
16-
return core.constexpr(log2)
17+
return log2
1718

1819

19-
def _is_power_of_two(i: core.constexpr):
20-
n = i.value
21-
return core.constexpr((n & (n - 1)) == 0 and n != 0)
20+
@core.constexpr_function
21+
def _is_power_of_two(i):
22+
return (i & (i - 1)) == 0 and i != 0
2223

2324

2425
# -----------------------
@@ -263,8 +264,8 @@ def _sum_combine(a, b):
263264
# sum
264265

265266

266-
def _pick_sum_dtype(in_dtype: core.constexpr, dtype: core.constexpr):
267-
dtype = core._unwrap_if_constexpr(dtype)
267+
@core.constexpr_function
268+
def _pick_sum_dtype(in_dtype, dtype):
268269
if dtype is not None:
269270
return dtype
270271

@@ -476,14 +477,13 @@ def bitonic_merge(x, dim: core.constexpr = None, descending: core.constexpr = co
476477
return _bitonic_merge(x, n_dims, descending, n_dims)
477478

478479

480+
@core.constexpr_function
479481
def _get_flip_dim(dim, shape):
480-
dim = core._unwrap_if_constexpr(dim)
481-
shape = core._unwrap_if_constexpr(shape)
482482
if dim is None:
483483
dim = len(shape) - 1
484484
if dim < 0: # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index
485485
dim += len(shape)
486-
return core.constexpr(dim)
486+
return dim
487487

488488

489489
@core._tensor_member_fn

0 commit comments

Comments
 (0)