Skip to content

Commit 3c8c0cf

Browse files
authored
[gluon] Basic multi-cta support (#8468)
Currently `num_ctas` is not propagated to the kernel launcher, so just need to add some minimal metadata.
1 parent 24bd281 commit 3c8c0cf

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

python/test/gluon/test_core.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,18 @@ def test_copy_kernel(layout, XBLOCK):
6767
torch.testing.assert_close(out, inp)
6868

6969

70+
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper")
71+
def test_copy_kernel_multi_cta():
72+
XBLOCK = 2048
73+
layout = ttgl.BlockedLayout(size_per_thread=[8], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[8], order=[0],
74+
ctas_per_cga=[2], cta_split_num=[2])
75+
76+
inp = torch.randn(XBLOCK * 4 - 7, device="cuda")
77+
out = torch.empty_like(inp)
78+
copy_kernel[(4, )](out, inp, inp.numel(), XBLOCK, layout, num_warps=layout.warps_per_cta[0], num_ctas=2)
79+
torch.testing.assert_close(out, inp)
80+
81+
7082
@gluon.jit
7183
def tma_kernel(desc):
7284
layout: ttgl.constexpr = ttgl.BlockedLayout([1, 2], [4, 8], [4, 1], [1, 0])

third_party/nvidia/backend/compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,7 @@ def make_ttgir(mod, metadata, opt, capability):
323323

324324
pm.run(mod, 'make_ttgir')
325325
metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
326-
tensordesc_meta = mod.get_tensordesc_metadata()
327-
metadata["tensordesc_meta"] = tensordesc_meta
326+
metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
328327
return mod
329328

330329
def gluon_to_ttgir(self, src, metadata, options, capability):
@@ -341,6 +340,7 @@ def gluon_to_ttgir(self, src, metadata, options, capability):
341340
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
342341

343342
pm.run(mod, 'gluon_to_ttgir')
343+
metadata["cluster_dims"] = (options.num_ctas, 1, 1)
344344
metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
345345
return mod
346346

0 commit comments

Comments
 (0)