Skip to content

Commit 494b203

Browse files
agron911 authored and meta-codesync[bot] committed
[Cherry-pick][RESOLVED] [GLUON] Fix getting layout from a SwizzledSharedLayout (#8003) (#549)
Summary: ⚠️ **MERGE CONFLICTS DETECTED** ⚠️ This cherry-pick contains merge conflicts that require manual resolution. Original Commit: 642d59c Original Author: Pengzhan Zhao Original Date: 2025-08-29 10:11:54 -0700 **Action Required:** 1. Check out this branch locally 2. Resolve the merge conflicts in the affected files 3. Commit the resolved changes 4. Update this PR Original commit message: ``` [GLUON] Fix getting layout from a SwizzledSharedLayout (#8003) `layoutToGluon` will seg fault when taking a `SwizzledSharedLayout` attribute. Found this issue while using `permute` on a shared memory with this attribute. ``` This PR was automatically cherry-picked from the upstream triton-lang/triton repository. The conflicts have been committed with conflict markers for easier resolution. Pull Request resolved: #549 Reviewed By: njriasan Differential Revision: D86216032 Pulled By: agron911 fbshipit-source-id: 54a04ff4999bb25c9b45d890636771968678518a
1 parent 9a0e0cb commit 494b203

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

python/src/gluon_ir.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,11 @@ py::object layoutToGluon(Attribute layout) {
191191
toStdVector(ctaLayout.getCTAOrder()));
192192
} else if (auto swizzled =
193193
dyn_cast<ttg::SwizzledSharedEncodingAttr>(layout)) {
194-
auto ctaLayout = nvmma.getCTALayout();
194+
auto ctaLayout = swizzled.getCTALayout();
195195
return layouts.SwizzledSharedLayout(
196196
swizzled.getVec(), swizzled.getPerPhase(), swizzled.getMaxPhase(),
197-
swizzled.getOrder(), toStdVector(ctaLayout.getCTAsPerCGA()),
197+
toStdVector(swizzled.getOrder()),
198+
toStdVector(ctaLayout.getCTAsPerCGA()),
198199
toStdVector(ctaLayout.getCTASplitNum()),
199200
toStdVector(ctaLayout.getCTAOrder()));
200201
} else if (auto autoEnc = dyn_cast<gluon::AutoEncodingAttr>(layout)) {

python/test/gluon/test_frontend.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,32 @@ def test_shared_memory_index(target):
317317
""")
318318

319319

320+
@gluon.jit
321+
def shared_memory_permute_kernel():
322+
layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0])
323+
smem = ttgl.allocate_shared_memory(ttgl.float16, [4, 128], layout)
324+
perm = smem.permute((1, 0))
325+
ttgl.static_assert(perm.layout == ttgl.SwizzledSharedLayout(1, 1, 1, [0, 1]))
326+
327+
328+
@pytest.mark.parametrize("target", ALL_TARGETS)
329+
def test_shared_memory_permute(target):
330+
mod = run_parser(shared_memory_permute_kernel, target=target)
331+
expecttest.assert_expected_inline(
332+
anonymize_ir(mod.str_nodebug()), """\
333+
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
334+
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
335+
#smem = #ttg.shared_memory
336+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
337+
tt.func public @shared_memory_permute_kernel() attributes {noinline = false} {
338+
%0 = ttg.local_alloc : () -> !ttg.memdesc<4x128xf16, #shared, #smem, mutable>
339+
%1 = ttg.memdesc_trans %0 {order = array<i32: 1, 0>} : !ttg.memdesc<4x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x4xf16, #shared1, #smem, mutable>
340+
tt.return
341+
}
342+
}
343+
""")
344+
345+
320346
@gluon.jit
321347
def shared_memory_cast_kernel():
322348
layout_a: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=64, transposed=False, element_bitwidth=8,
@@ -809,6 +835,24 @@ def kernel():
809835
assert "order must be a permutation of 0..(rank-1), but was [1]" in str(e.value.__cause__)
810836

811837

838+
def test_tensor_layout_type_changed():
839+
840+
@gluon.jit
841+
def kernel():
842+
layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[1, 32],
843+
warps_per_cta=[1, 4], order=[1, 0])
844+
x = ttgl.zeros([128], ttgl.float32)
845+
y = ttgl.zeros([128, 128], ttgl.float32, layout=layout)
846+
c = ttgl.to_tensor(True)
847+
while c:
848+
x = x + y.sum(axis=0)
849+
850+
with pytest.raises(CompilationError) as e:
851+
run_parser(kernel)
852+
853+
assert "Loop-carried variable x has initial type" in str(e.value)
854+
855+
812856
@gluon.jit
813857
def tmem_index_kernel():
814858
layout: ttgl.constexpr = TensorMemoryLayout(block=[128, 128], unpacked=True)

0 commit comments

Comments (0)