
Commit 559237b

[Gluon] Add constexpr_function and static_range (#7531)
These are used in the attention tutorial, but ideally everything should come from the `gluon.language` module.
1 parent 1031dc7 · commit 559237b
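
As a usage sketch (not part of the commit; the helper name `buffers_for` is hypothetical, and it assumes a Triton build that includes this change), tutorial-style code can now take both helpers from the Gluon language namespace instead of `triton.language`:

```python
# Sketch only: assumes a Triton build that contains this commit.
from triton.experimental.gluon import language as gl


@gl.constexpr_function  # previously spelled tl.constexpr_function in the tutorial
def buffers_for(block_m):
    # Plain Python, evaluated at compile time; the result is usable as a constexpr.
    return 2 if block_m <= 64 else 4


# Inside a Gluon kernel, gl.static_range replaces tl.static_range for
# compile-time-unrolled loops, e.g.:
#     for i in gl.static_range(num_buffers):
#         mbarrier.init(ready_bars.index(i), count=1)
```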

File tree

2 files changed, +13 −10 lines changed


python/triton/experimental/gluon/language/_core.py

Lines changed: 4 additions & 0 deletions
@@ -12,6 +12,7 @@
 import triton.language.core as tl_core
 from triton.language.core import (
     constexpr,
+    constexpr_function,
     base_value,
     base_type,
     dtype,
@@ -38,6 +39,7 @@
     float64,
     _unwrap_if_constexpr,
     _unwrap_shape,
+    static_range,
     tensor,
     tuple,
     tuple_type,
@@ -68,6 +70,7 @@

 __all__ = [
     "constexpr",
+    "constexpr_function",
     "base_value",
     "base_type",
     "dtype",
@@ -105,6 +108,7 @@
     "allocate_shared_memory",
     "set_auto_layout",
     "shared_memory_descriptor",
+    "static_range",
     "warp_specialize",
     *_IMPORT_FROM_TRITON,
 ]
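
A quick, hypothetical way to sanity-check the re-exports added above (assumes the package is importable from a build containing this commit):

```python
# Hypothetical check: both names should now resolve through gluon.language.
from triton.experimental.gluon import language as gl

assert hasattr(gl, "constexpr_function")
assert hasattr(gl, "static_range")
```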

python/tutorials/gluon/01-attention-forward.py

Lines changed: 9 additions & 10 deletions
@@ -1,6 +1,5 @@
 import torch
 import triton
-import triton.language as tl
 import pytest
 import itertools

@@ -25,7 +24,7 @@
 # ===-----------------------------------------------------------------------===#


-@tl.constexpr_function
+@gl.constexpr_function
 def get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps):
     assert len(shape) == 2, "expected a 2D tensor"
     assert num_warps in [4, 8], "expected 4 or 8 warps"
@@ -61,15 +60,15 @@ def get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps):
     )


-@tl.constexpr_function
+@gl.constexpr_function
 def get_mma_instr_shape(shape, element_ty):
     m = 128 if shape[0] >= 128 else 64
     n = 256 if shape[1] >= 256 else shape[1]
     k = 256 // element_ty.primitive_bitwidth
     return (m, n, k)


-@tl.constexpr_function
+@gl.constexpr_function
 def get_nvmma_layout(shape, element_ty, order=[1, 0], fp4_padded=False):
     packing_factor = 2 if fp4_padded else 1

@@ -99,7 +98,7 @@ def get_nvmma_layout(shape, element_ty, order=[1, 0], fp4_padded=False):
     )


-@tl.constexpr_function
+@gl.constexpr_function
 def get_mma_reg_layout(shape, num_warps, dtype=gl.float32):
     instr_shape = get_mma_instr_shape(shape, dtype)
     return get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps)
@@ -133,7 +132,7 @@ def alloc(shape: gl.constexpr, dtype: gl.constexpr, layout: gl.constexpr, num_bu
     mem = alloc_fn(dtype, [num_buffers] + shape, layout)
     ready_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
     empty_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
-    for i in tl.static_range(num_buffers):
+    for i in gl.static_range(num_buffers):
         mbarrier.init(ready_bars.index(i), count=1)
         mbarrier.init(empty_bars.index(i), count=num_consumers)
         mbarrier.arrive(empty_bars.index(i), count=num_consumers)
@@ -179,7 +178,7 @@ def create_consumer(self):
     def release(self):
         if isinstance(self.mem, gl.shared_memory_descriptor):
             self.mem._keep_alive()
-        for i in tl.static_range(self.num_buffers):
+        for i in gl.static_range(self.num_buffers):
             mbarrier.invalidate(self.ready_bars.index(i))
             mbarrier.invalidate(self.empty_bars.index(i))

@@ -847,7 +846,7 @@ def _attn_fwd_correction_rescale(config, s_tmem, corr_consumer, o_consumer):
     mbarrier.arrive(corr_bar, count=1)
     alpha = gl.convert_layout(alpha.reshape([config.SPLIT_M]), alpha_layout)

-    for i in tl.static_range(config.SPLIT_D_FACTOR):
+    for i in gl.static_range(config.SPLIT_D_FACTOR):
         o_ref = o_tmem.slice(i * config.SPLIT_D, config.SPLIT_D)
         o = o_ref.load(config.o_splitn_layout)
         o = _mul_f32x2(o, alpha[:, None])
@@ -882,7 +881,7 @@ def _attn_fwd_correction_epilogue(config, prog, s_tmem, M, corr_consumer, epi_pr
     SPLIT_N: gl.constexpr = o_smem.type.shape[1] // SPLIT_N_FACTOR

     scale = 1 / l_i
-    for i in tl.static_range(SPLIT_N_FACTOR):
+    for i in gl.static_range(SPLIT_N_FACTOR):
         o_ref = o_tmem.slice(i * SPLIT_N, SPLIT_N)
         o = o_ref.load(config.o_splitn_layout)
         o = _mul_f32x2(o, scale[:, None])
@@ -992,7 +991,7 @@ def attention_kernel( #
 def torch_dtype_to_triton(dtype):
     if dtype == torch.float8_e5m2:
         return gl.float8e5
-    return getattr(tl, str(dtype).split('.')[1])
+    return getattr(gl, str(dtype).split('.')[1])


 def make_tensor_desc(x, shape, strides, block_shape):
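
For reference, a hedged standalone sketch of the updated dtype helper from the tutorial diff above; the final assert is an added assumption for illustration, not part of the commit:

```python
# Mirrors the tutorial's torch_dtype_to_triton after this commit (sketch only).
import torch
from triton.experimental.gluon import language as gl


def torch_dtype_to_triton(dtype):
    if dtype == torch.float8_e5m2:
        return gl.float8e5
    # e.g. str(torch.float16) == "torch.float16" -> gl.float16
    return getattr(gl, str(dtype).split('.')[1])


assert torch_dtype_to_triton(torch.float16) is gl.float16
```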
