
Commit a9c3322

[Gluon] Fix a few things in the translator (#8569)
* Add missing scatter conversion
* Add mma_v2 path
* Fix TMEM scales register layout
* Change `convert_triton_to_gluon` to accept multiple root kernels, allowing a single source to be generated from them (which reuses functions across them)
* Add missing `fence_async_shared` in TMA store
* Add missing APIs for scales layout class
* Fix ttgl.store broadcasting of scalars
* Fix CTA layout canonicalization in Gluon
1 parent: 1514139

8 files changed: +127 -18 lines

python/src/gluon_ir.cc

Lines changed: 17 additions & 0 deletions
@@ -104,6 +104,8 @@ struct GluonLayouts {
   py::handle DistributedLinearLayout;
   py::handle DotOperandLayout;
   py::handle NVMMADistributedLayout;
+  py::handle TensorMemoryScalesLayout;
+  py::handle TensorMemoryLayout;
   py::handle NVMMASharedLayout;
   py::handle SwizzledSharedLayout;
   py::handle SharedLinearLayout;
@@ -116,6 +118,8 @@ struct GluonLayouts {
         py::module::import("triton.experimental.gluon.language._layouts");
     auto amdLayouts =
         py::module::import("triton.experimental.gluon.language.amd._layouts");
+    auto blackwellLayouts = py::module::import(
+        "triton.experimental.gluon.language.nvidia.blackwell");
     AutoLayout = py::object(layouts.attr("AutoLayout")).release();
     BlockedLayout = py::object(layouts.attr("BlockedLayout")).release();
     SliceLayout = py::object(layouts.attr("SliceLayout")).release();
@@ -124,6 +128,10 @@ struct GluonLayouts {
     DotOperandLayout = py::object(layouts.attr("DotOperandLayout")).release();
     NVMMADistributedLayout =
         py::object(layouts.attr("NVMMADistributedLayout")).release();
+    TensorMemoryScalesLayout =
+        py::object(blackwellLayouts.attr("TensorMemoryScalesLayout")).release();
+    TensorMemoryLayout =
+        py::object(blackwellLayouts.attr("TensorMemoryLayout")).release();
     NVMMASharedLayout = py::object(layouts.attr("NVMMASharedLayout")).release();
     SwizzledSharedLayout =
         py::object(layouts.attr("SwizzledSharedLayout")).release();
@@ -256,6 +264,15 @@ py::object layoutToGluon(Attribute layout) {
     return layouts.PaddedSharedLayout(intervalPaddingPairs,
                                       ll.getBases().lookup(kOffset),
                                       ll.getBases().lookup(kBlock), shape);
+  } else if (auto tmemScales =
+                 dyn_cast<ttng::TensorMemoryScalesEncodingAttr>(layout)) {
+    return layouts.TensorMemoryScalesLayout(std::vector<unsigned>{
+        tmemScales.getCTASplitM(), tmemScales.getCTASplitN()});
+  } else if (auto tmem = dyn_cast<ttng::TensorMemoryEncodingAttr>(layout)) {
+    return layouts.TensorMemoryLayout(
+        std::vector<unsigned>{tmem.getBlockM(), tmem.getBlockN()},
+        tmem.getColStride(),
+        std::vector<unsigned>{tmem.getCTASplitM(), tmem.getCTASplitN()});
   }
 
   throw py::value_error("Unhandled encoding encountered");
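
For reference, a minimal Python-side sketch of what the two new layoutToGluon branches hand back; the positional fields (block, col_stride, cta_split_num) are inferred from the constructor calls above and the __hash__ additions further down, so treat the exact signature as an assumption rather than the definitive API:

# Illustrative only: the Gluon layout classes that Blackwell TMEM encodings
# now map onto, constructed with the same argument order as the C++ calls.
from triton.experimental.gluon.language.nvidia.blackwell import (
    TensorMemoryLayout,
    TensorMemoryScalesLayout,
)

tmem = TensorMemoryLayout((128, 64), 1, (1, 1))   # block, col_stride, cta_split_num
scales = TensorMemoryScalesLayout((1, 1))         # cta_split_num
print(tmem.mangle())    # mangles block/stride/CTA-split, per mangle() shown below
print(scales.mangle())  # "TLSCS1x1TLS" per the f-string in the scales class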

python/test/unit/tools/test_triton_to_gluon.py

Lines changed: 28 additions & 9 deletions
@@ -13,7 +13,7 @@
 
 
 def convert_kernel(kernel, kernel_name, tmp_path):
-    converted = convert_triton_to_gluon(kernel)
+    converted = convert_triton_to_gluon([kernel])
 
     # Write converted kernel to a file so @gluon.jit can retrieve source
     mod_path = tmp_path / "converted_kernel.py"
@@ -52,7 +52,7 @@ def test_simple_kernel(tmp_path):
     ref = torch.empty_like(x)
     add_kernel[grid](x, y, ref, n, BLOCK)
 
-    torch.testing.assert_close(out, ref)
+    torch.testing.assert_close(out, ref, atol=0, rtol=0)
 
 
 @triton.jit
@@ -85,7 +85,7 @@ def test_triton_to_gluon_dot_minimal(tmp_path):
 
     ref = torch.empty_like(c)
     matmul_tile_kernel[grid](a, b, ref, M, N, K, num_warps=8)
-    torch.testing.assert_close(c, ref)
+    torch.testing.assert_close(c, ref, atol=0, rtol=0)
 
 
 @triton.jit
@@ -153,7 +153,7 @@ def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K,
     ref = torch.empty_like(output)
     matmul_kernel[grid](a, b, ref, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), output.stride(0),
                         output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K)
-    torch.testing.assert_close(output, ref)
+    torch.testing.assert_close(output, ref, atol=0, rtol=0)
 
 
 @triton.jit
@@ -177,7 +177,7 @@ def test_triton_to_gluon_descriptor_roundtrip(tmp_path):
     y_ref = torch.zeros((M, N), device="cuda", dtype=torch.float16)
     desc_ref = TensorDescriptor(y_ref, y_ref.shape, y_ref.stride(), block_shape)
     descriptor_store_kernel[grid](desc_ref, M, N, 1.0)
-    torch.testing.assert_close(y, y_ref)
+    torch.testing.assert_close(y, y_ref, atol=0, rtol=0)
 
 
 @triton.jit
@@ -204,7 +204,7 @@ def test_triton_to_gluon_descriptor_load_roundtrip(tmp_path):
     y_ref = torch.zeros((M, N), device="cuda", dtype=torch.float16)
     desc_ref = TensorDescriptor(y_ref, y_ref.shape, y_ref.stride(), block_shape)
     descriptor_copy_kernel[grid](in_desc, desc_ref, M, N)
-    torch.testing.assert_close(y, y_ref)
+    torch.testing.assert_close(y, y_ref, atol=0, rtol=0)
 
 
 @triton.jit
@@ -232,7 +232,7 @@ def test_triton_reshape_trans(tmp_path):
     kernel[grid](x, y, out, n, BLOCK)
     ref = torch.empty_like(x)
     reshape_trans_kernel[grid](x, y, ref, n, BLOCK)
-    torch.testing.assert_close(out, ref)
+    torch.testing.assert_close(out, ref, atol=0, rtol=0)
 
 
 BLOCK_SPLIT = tl.constexpr(256)
@@ -262,7 +262,7 @@ def test_split(tmp_path):
     kernel[grid](x, out)
     ref = torch.empty_like(x[:n])
     split_kernel[grid](x, ref)
-    torch.testing.assert_close(out, ref)
+    torch.testing.assert_close(out, ref, atol=0, rtol=0)
 
 
 @triton.jit
@@ -281,4 +281,23 @@ def test_reduce_to_scalar(tmp_path):
     kernel[grid](out)
     ref = torch.empty_like(out)
     reduce_to_scalar_kernel[grid](ref)
-    torch.testing.assert_close(out, ref)
+    torch.testing.assert_close(out, ref, atol=0, rtol=0)
+
+
+@triton.jit
+def num_threads_kernel(out_ptr):
+    num_threads: tl.constexpr = tl.extra.cuda.num_threads()
+    offs = tl.arange(0, num_threads)
+    tl.store(out_ptr + offs, 1)
+
+
+@pytest.mark.skipif(not (is_blackwell()), reason="Requires Blackwell")
+def test_num_threads(tmp_path):
+    kernel = convert_kernel(num_threads_kernel, "num_threads_kernel", tmp_path)
+
+    num_threads = 256
+    out = torch.empty(num_threads, dtype=torch.int32, device="cuda")
+    kernel[(1, )](out, num_warps=num_threads // 32)
+    ref = torch.empty_like(out)
+    num_threads_kernel[(1, )](ref, num_warps=num_threads // 32)
+    torch.testing.assert_close(out, ref, atol=0, rtol=0)

python/triton/experimental/gluon/language/_layouts.py

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,9 @@ def _realize_cta_layout(layout, rank):
     ctas_per_cga = layout.ctas_per_cga or [1] * rank
     cta_split_num = layout.cta_split_num or [1] * rank
     cta_order = layout.cta_order or list(reversed(range(rank)))
+    # Canonicalize CTA order to [n,n-1,...,0] if CTAsPerCGA is [1...1]. This matches logic in C++.
+    if all(num_cta == 1 for num_cta in ctas_per_cga):
+        cta_order = list(range(rank - 1, -1, -1))
     object.__setattr__(layout, "ctas_per_cga", ctas_per_cga)
     object.__setattr__(layout, "cta_split_num", cta_split_num)
     object.__setattr__(layout, "cta_order", cta_order)

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 5 additions & 0 deletions
@@ -416,6 +416,11 @@ def _check_same_layout(xs):
         _check(all(l == l0 for l in layouts[1:]),
                lambda: f"Expected inputs to have matching layouts, but got: {layouts}")
 
+    def _store_legacy(self, ptr, val, mask, boundary_check, cache, eviction):
+        if ptr.type.is_block() and not val.type.is_block():
+            val = self.splat(val, ptr.type.get_block_shapes(), ptr.type.layout)
+        return super()._store_legacy(ptr, val, mask, boundary_check, cache, eviction)
+
     def associative_scan(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn,
                          reverse: bool) -> Tuple[TensorTy, ...]:
         shape = inputs[0].type.shape
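
To show what the _store_legacy override enables, here is a hedged Gluon kernel sketch (layout, shapes, and kernel name are illustrative): a scalar value stored through a block of pointers is now splatted to the pointer block's shape and layout instead of being rejected:

# Sketch: ttgl.store with a block of pointers and a scalar RHS.
import torch
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl

@gluon.jit
def fill_ones_kernel(out_ptr, BLOCK: ttgl.constexpr, layout: ttgl.constexpr):
    offs = ttgl.arange(0, BLOCK, layout=layout)
    ttgl.store(out_ptr + offs, 1)  # scalar value, block of pointers: splatted by the override above

out = torch.zeros(128, dtype=torch.int32, device="cuda")
blocked = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0])
fill_ones_kernel[(1, )](out, BLOCK=128, layout=blocked)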

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -68,6 +68,9 @@ def mangle(self) -> str:
         cta_split_str = (f"CS{self.cta_split_num[0]}x{self.cta_split_num[1]}" if self.cta_split_num else "")
         return f"TL{block_str}{stride_str}{cta_split_str}TL"
 
+    def __hash__(self):
+        return hash((self.block, self.col_stride, self.cta_split_num))
+
 
 @dataclass(frozen=True, eq=True)
 class TensorMemoryScalesLayout:
@@ -91,6 +94,9 @@ def mangle(self) -> str:
         cta_split_str = f"CS{self.cta_split_num[0]}x{self.cta_split_num[1]}" if self.cta_split_num else ""
         return f"TLS{cta_split_str}TLS"
 
+    def __hash__(self):
+        return hash(self.cta_split_num)
+
 
 @constexpr_function
 def get_tmem_reg_layout(
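
A small usage sketch of the added __hash__ methods (the constructor signatures are assumed from the hash fields above): being hashable lets the TMEM layouts serve as set members or dict keys, e.g. for deduplicating layouts during translation:

# Illustrative only: equal layouts now hash equally, so a lookup with a
# fresh-but-equal instance succeeds.
from triton.experimental.gluon.language.nvidia.blackwell import (
    TensorMemoryLayout,
    TensorMemoryScalesLayout,
)

cache = {
    TensorMemoryLayout((128, 128), 1, (1, 1)): "acc layout",
    TensorMemoryScalesLayout((1, 1)): "scales layout",
}
assert cache[TensorMemoryScalesLayout((1, 1))] == "scales layout"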

python/triton/experimental/gluon/language/nvidia/hopper/tma.py

Lines changed: 1 addition & 0 deletions
@@ -116,6 +116,7 @@ def make_tensor_descriptor(
     _semantic=None,
 ) -> tensor_descriptor:
     padding_option = _unwrap_if_constexpr(padding_option)
+    block_shape = _unwrap_if_constexpr(block_shape)
 
     ndim = len(shape)
     if not (1 <= ndim <= 5):
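
For context, a hedged sketch of the call pattern the extra unwrap supports: inside a @gluon.jit kernel the block shape is typically a ttgl.constexpr value, and make_tensor_descriptor now accepts it directly. The kernel body and layout choice are illustrative, with get_default_for used as in the translator helpers of this commit:

# Sketch: device-side descriptor creation with a constexpr block_shape.
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.nvidia.hopper import tma

@gluon.jit
def make_desc_kernel(ptr, M, N):
    block_shape: ttgl.constexpr = [64, 64]
    layout: ttgl.constexpr = ttgl.NVMMASharedLayout.get_default_for(block_shape, ttgl.float16)
    desc = tma.make_tensor_descriptor(ptr, [M, N], [N, 1], block_shape, layout)
    # ... use desc with TMA copies as usual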

python/triton/tools/triton_to_gluon_translater/translator.py

Lines changed: 3 additions & 3 deletions
@@ -159,7 +159,7 @@ def visit_Call(self, node: ast.Call) -> ast.AST:
         if resolved_callable is triton.language.core.static_range:
             return self.forward_call(node, self.ttgl_attr("static_range"))
         else:
-            if isinstance(node.func, ast.Attribute) and node.func.attr in ["store", "load", "gather"]:
+            if isinstance(node.func, ast.Attribute) and node.func.attr in ["store", "load", "gather", "scatter"]:
                 helper_name = "tl_obj_" + node.func.attr
                 return ast.Call(
                     func=ast.Name(id=helper_name, ctx=ast.Load()),
@@ -378,10 +378,10 @@ def visit_Call(self, call_node: ast.Call) -> ast.AST:
         return results
 
 
-def convert_triton_to_gluon(src: triton.runtime.jit.JITCallable) -> str:
+def convert_triton_to_gluon(src: list[triton.runtime.jit.JITCallable]) -> str:
     """Convert a Triton JIT entry point into a Gluon source string."""
     shared_jit_set: set = set()
-    function_queue: list = [src]
+    function_queue: list = list(src)
     constexpr_globals: dict = {}
     out = ""
     # Process discovered callee JITFunctions, converting and appending them
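
A short usage sketch of the updated entry point (the import path follows the file location above, and the call mirrors the updated tests): several root kernels can now be converted into one Gluon source string in which shared callees are emitted only once:

# Sketch: convert two Triton JIT kernels into a single Gluon source string.
import triton
import triton.language as tl
from triton.tools.triton_to_gluon_translater.translator import convert_triton_to_gluon

@triton.jit
def add_one_kernel(x_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) + 1, mask=mask)

@triton.jit
def zero_kernel(x_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    tl.store(x_ptr + offs, 0, mask=offs < n)

src = convert_triton_to_gluon([add_one_kernel, zero_kernel])
print(src)  # write this to a file so @gluon.jit can retrieve the source, as in the tests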

python/triton/tools/triton_to_gluon_translater/translator_helpers.py

Lines changed: 64 additions & 6 deletions
@@ -10,10 +10,32 @@
     tcgen05_mma_scaled,
     tcgen05_commit,
 )
-from triton.experimental.gluon.language.nvidia.hopper import tma
+from triton.experimental.gluon.language.nvidia.ampere import mma_v2
+from triton.experimental.gluon.language.nvidia.hopper import tma, fence_async_shared
 from triton.experimental.gluon.language.nvidia.blackwell import tma as tma_blackwell
 
 
+@gluon.jit
+def tl_dot_mma_sync(a, b, acc_init=None, input_precision=None):
+    mma_layout: ttgl.constexpr = ttgl.NVMMADistributedLayout(
+        version=[2, 0],
+        warps_per_cta=[ttgl.num_warps(), 1],
+        instr_shape=[16, 8],
+    )
+    a_layout: ttgl.constexpr = ttgl.DotOperandLayout(parent=mma_layout, operand_index=0, k_width=2)
+    b_layout: ttgl.constexpr = ttgl.DotOperandLayout(parent=mma_layout, operand_index=1, k_width=2)
+    a = ttgl.convert_layout(a, a_layout)
+    b = ttgl.convert_layout(b, b_layout)
+    if acc_init is not None:
+        acc = ttgl.convert_layout(acc_init, mma_layout)
+    else:
+        acc = ttgl.full([a.shape[0], a.shape[1], b.shape[2]], 0.0, ttgl.float32, layout=mma_layout)
+    result = mma_v2(a, b, acc, input_precision)
+    if acc is not None:
+        result = ttgl.convert_layout(result, acc_init.type.layout)
+    return result
+
+
 @gluon.constexpr_function
 def get_swizzle_byte_width(bitwidth):
     swizzle = min(bitwidth, 128)
@@ -22,8 +44,8 @@ def get_swizzle_byte_width(bitwidth):
 
 
 @gluon.jit
-def tl_dot(a, b, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=ttgl.float32):
-    # TODO: check if MMAv5 cannot be used and fallback to mmav2
+def tl_dot_blackwell(a, b, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None,
+                     out_dtype=ttgl.float32):
     M: ttgl.constexpr = a.type.shape[0]
     N: ttgl.constexpr = b.type.shape[1]
     K: ttgl.constexpr = a.type.shape[1]
@@ -59,6 +81,19 @@ def tl_dot(a, b, acc=None, input_precision=None, allow_tf32=None, max_num_imprec
     return out
 
 
+@gluon.jit
+def tl_dot(a, b, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=ttgl.float32):
+    if ttgl.num_warps() < 4:
+        return tl_dot_mma_sync(a, b, acc, input_precision)
+    else:
+        return tl_dot_blackwell(a, b, acc, input_precision, allow_tf32, max_num_imprecise_acc, out_dtype)
+
+
+@gluon.constexpr_function
+def _constexpr_min(a, b):
+    return min(a, b)
+
+
 @gluon.jit
 def tl_dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, lhs_k_pack=True,
                   rhs_k_pack=True, out_dtype=ttgl.float32):
@@ -114,9 +149,9 @@ def get_num_threads_per_warp() -> ttgl.constexpr:
     return ttgl.constexpr(32)
 
 
-@gluon.constexpr_function
-def get_num_threads_per_program():
-    return ttgl.num_warps() * get_num_threads_per_warp()
+@ttgl._core.builtin
+def get_num_threads_per_program(_semantic=None, _generator=None):
+    return ttgl.num_warps(_semantic=_semantic, _generator=_generator) * get_num_threads_per_warp(_semantic=_semantic)
 
 
 @gluon.constexpr_function
@@ -180,9 +215,32 @@ def tl_obj_gather(obj, x_offsets, y_offset):
     return obj.gather(x_offsets, y_offset)
 
 
+@gluon.jit
+def tl_obj_scatter(obj, value, x_offsets, y_offset):
+    if isinstance(obj, ttgl.nvidia.hopper.tma.tensor_descriptor):
+        desc = obj
+        desc_shape: ttgl.constexpr = [x_offsets.shape[0], desc.block_shape[1]]
+        alloc = ttgl.allocate_shared_memory(desc.dtype, desc_shape, desc.layout, value)
+        fence_async_shared()
+        x_offsets_layout: ttgl.constexpr = ttgl.SliceLayout(
+            0, ttgl.BlockedLayout([1, 4], [get_num_threads_per_warp(), 1], [1, ttgl.num_warps()], [1, 0]))
+        x_offsets = ttgl.convert_layout(x_offsets, x_offsets_layout)
+        tma_blackwell.async_scatter(desc, x_offsets, y_offset, alloc)
+        tma.store_wait(0)
+    else:
+        obj.scatter(value, x_offsets, y_offset)
+
+
+@ttgl._core.builtin
+def tl_make_tensor_descriptor(base, shape, strides, block_shape, padding_option="zero", _semantic=None):
+    layout = ttgl.NVMMASharedLayout.get_default_for(block_shape, base.dtype.element_ty)
+    return tma.make_tensor_descriptor(base, shape, strides, block_shape, layout, padding_option, _semantic=_semantic)
+
+
 @gluon.jit
 def tl_store_tensor_descriptor(desc, offsets, value):
     alloc = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout, value)
+    fence_async_shared()
     tma.async_copy_shared_to_global(desc, offsets, alloc)
     tma.store_wait(0)
     alloc._keep_alive()
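
Finally, a hedged sketch of one of the reworked helpers in use: because get_num_threads_per_program is now a builtin, it can size a range at compile time inside a Gluon kernel. The kernel, layout, and launch parameters below are illustrative only:

# Sketch: compile-time thread count via the builtin helper above.
import torch
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.tools.triton_to_gluon_translater.translator_helpers import get_num_threads_per_program

@gluon.jit
def thread_count_kernel(out_ptr):
    n: ttgl.constexpr = get_num_threads_per_program()
    layout: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [ttgl.num_warps()], [0])
    offs = ttgl.arange(0, n, layout=layout)
    ttgl.store(out_ptr + offs, offs)

out = torch.zeros(256, dtype=torch.int32, device="cuda")
thread_count_kernel[(1, )](out, num_warps=8)  # 8 warps * 32 threads = 256 lanes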
