
Commit 818e892

[Gluon] Expose PaddedSharedLayout to Gluon (#7766)
Expose PaddedSharedLayout to Gluon.
1 parent 286e91f commit 818e892
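
For context, the newly exposed layout is constructed from Gluon roughly as in the test added by this commit. The sketch below is a minimal, non-authoritative example based on that test; the import paths mirror the existing Gluon tests and the kernel body is illustrative only.

    from triton.experimental import gluon
    from triton.experimental.gluon import language as ttgl

    @gluon.jit
    def padded_smem_example_kernel():
        # Pad by 1 element after every 2 elements, by 2 after every 4, and by 4 after every 8.
        layout: ttgl.constexpr = ttgl.PaddedSharedLayout(interval_padding_pairs=[[2, 1], [4, 2], [8, 4]],
                                                         order=[1, 0], ctas_per_cga=[1, 1],
                                                         cta_split_num=[1, 1], cta_order=[1, 0])
        ttgl.allocate_shared_memory(ttgl.int32, [64, 64], layout)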

File tree

lib/Dialect/TritonGPU/IR/Dialect.cpp
python/src/gluon_ir.cc
python/test/gluon/test_frontend.py
python/triton/experimental/gluon/language/_layouts.py

4 files changed (+194, -1 lines)

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 12 additions & 0 deletions
@@ -2297,6 +2297,18 @@ struct TritonGPUInferLayoutInterface
       return success();
     }
 
+    if (auto enc = dyn_cast<PaddedSharedEncodingAttr>(operandEncoding)) {
+      if (failed(checkRank(enc.getRank())))
+        return failure();
+
+      CTALayoutAttr ctaLayout =
+          permuteCTALayout(ctx, enc.getCTALayout(), order);
+      resultEncoding = PaddedSharedEncodingAttr::get(
+          ctx, enc.getIntervals(), enc.getPaddings(),
+          applyPermutation(invOrderUnsigned, enc.getOrder()), ctaLayout);
+      return success();
+    }
+
     auto ll = toLinearLayout(shape, operandEncoding);
     auto transposedLl = transposeLinearLayout(ll, order);
     resultEncoding = LinearEncodingAttr::get(ctx, std::move(transposedLl));
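
The new PaddedSharedEncodingAttr branch remaps the encoding's order through the inverse of the transpose permutation (applyPermutation(invOrderUnsigned, enc.getOrder())). The plain-Python sketch below only models that remapping; the helper names are hypothetical and it is not part of the change.

    def inverse_permutation(perm):
        # inv[old_dim] = new_dim for a transpose permutation listing old dims in new order
        inv = [0] * len(perm)
        for new_dim, old_dim in enumerate(perm):
            inv[old_dim] = new_dim
        return inv

    def transpose_padded_shared_order(src_order, transpose_order):
        # Mirrors applyPermutation(invOrderUnsigned, enc.getOrder()):
        # every dimension named in the source order is renamed to its new index.
        inv = inverse_permutation(transpose_order)
        return [inv[d] for d in src_order]

    # Matches the worked example in the new test below:
    # order [2, 0, 1] transposed by (1, 0, 2) becomes [2, 1, 0].
    assert transpose_padded_shared_order([2, 0, 1], [1, 0, 2]) == [2, 1, 0]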

python/src/gluon_ir.cc

Lines changed: 28 additions & 1 deletion
@@ -97,6 +97,7 @@ struct GluonLayouts {
   py::handle NVMMASharedLayout;
   py::handle SwizzledSharedLayout;
   py::handle AMDMFMALayout;
+  py::handle PaddedSharedLayout;
   py::handle GluonDType;
 
   GluonLayouts() {
@@ -116,6 +117,8 @@ struct GluonLayouts {
     SwizzledSharedLayout =
         py::object(layouts.attr("SwizzledSharedLayout")).release();
     AMDMFMALayout = py::object(amdLayouts.attr("AMDMFMALayout")).release();
+    PaddedSharedLayout =
+        py::object(layouts.attr("PaddedSharedLayout")).release();
 
     auto core = py::module::import("triton.language.core");
     GluonDType = py::object(core.attr("dtype")).release();
@@ -199,7 +202,6 @@ py::object layoutToGluon(Attribute layout) {
   } else if (auto amdMfma = dyn_cast<ttg::AMDMfmaEncodingAttr>(layout)) {
     auto ctaLayout = amdMfma.getCTALayout();
     std::vector<unsigned> instrShape{amdMfma.getMDim(), amdMfma.getNDim()};
-
     auto elemTypeOpt = amdMfma.getElementType();
     const char *typeName = "fp32";
     if (elemTypeOpt.has_value()) {
@@ -222,6 +224,19 @@
                                      toStdVector(ctaLayout.getCTAsPerCGA()),
                                      toStdVector(ctaLayout.getCTASplitNum()),
                                      toStdVector(ctaLayout.getCTAOrder()));
+  } else if (auto paddedShared =
+                 dyn_cast<ttg::PaddedSharedEncodingAttr>(layout)) {
+    auto ctaLayout = paddedShared.getCTALayout();
+    std::vector<std::pair<unsigned, unsigned>> intervalPaddingPairs;
+    for (auto [interval, padding] :
+         llvm::zip(paddedShared.getIntervals(), paddedShared.getPaddings())) {
+      intervalPaddingPairs.push_back({interval, padding});
+    }
+    return layouts.PaddedSharedLayout(intervalPaddingPairs,
+                                      toStdVector(paddedShared.getOrder()),
+                                      toStdVector(ctaLayout.getCTAsPerCGA()),
+                                      toStdVector(ctaLayout.getCTASplitNum()),
+                                      toStdVector(ctaLayout.getCTAOrder()));
   }
 
   throw py::value_error("Unhandled encoding encountered");
@@ -338,6 +353,18 @@ void init_gluon_ir(py::module &&m) {
                  ctx, version, warpsPerCta, tilesPerWarp, instrShape[0],
                  instrShape[1], transposed, ctaLayout, elemType);
            })
+      .def("get_padded_shared_layout",
+           [](GluonOpBuilder &self, std::vector<unsigned> &intervals,
+              std::vector<unsigned> &paddings, std::vector<unsigned> &order,
+              std::vector<unsigned> &ctasPerCga,
+              std::vector<unsigned> &ctaSplitNum,
+              std::vector<unsigned> &ctaOrder) -> Attribute {
+             auto ctx = self.getContext();
+             auto ctaLayout = self.getChecked<ttg::CTALayoutAttr>(
+                 ctx, ctasPerCga, ctaSplitNum, ctaOrder);
+             return ttg::PaddedSharedEncodingAttr::get(ctx, intervals, paddings,
+                                                       order, ctaLayout);
+           })
       .def("get_nvmma_shared_layout",
           [](GluonOpBuilder &self, unsigned swizzleByteWidth,
              unsigned elementBitwidth, bool transposed, bool fp4Padded,
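
A note on the two representations used across this binding: the frontend dataclass stores [interval, padding] pairs, while get_padded_shared_layout and the MLIR attribute take two parallel lists, and layoutToGluon zips them back into pairs. The plain-Python snippet below only illustrates that round trip; it is not part of the bindings.

    pairs = [[2, 1], [4, 2], [8, 4]]

    # What PaddedSharedLayout._to_ir does before calling builder.get_padded_shared_layout:
    intervals, paddings = zip(*pairs)        # (2, 4, 8) and (1, 2, 4)

    # What layoutToGluon effectively does when rebuilding the Gluon object:
    round_tripped = [list(p) for p in zip(intervals, paddings)]
    assert round_tripped == pairs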

python/test/gluon/test_frontend.py

Lines changed: 69 additions & 0 deletions
@@ -2123,3 +2123,72 @@ def kernel():
 }
 }
 """)
+
+
+@gluon.jit
+def padded_shared_layout_kernel():
+    padded_shared_layout: ttgl.constexpr = ttgl.PaddedSharedLayout(interval_padding_pairs=[[2, 1], [4, 2], [8, 4]],
+                                                                   order=[1, 0], ctas_per_cga=[1, 1],
+                                                                   cta_split_num=[1, 1], cta_order=[1, 0])
+
+    ttgl.allocate_shared_memory(ttgl.int32, [64, 64], padded_shared_layout)
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
+def test_padded_shared_layout(target):
+    # Tests the construction of PaddedSharedEncodingAttr from Gluon.
+    module = run_parser(padded_shared_layout_kernel, target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(module.str_nodebug()), """\
+#shared = #ttg.padded_shared<[2:+1, 4:+2, 8:+4] {order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @padded_shared_layout_kernel() attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<64x64xi32, #shared, #smem, mutable>
+    tt.return
+  }
+}
+""")
+
+
+@gluon.jit
+def infer_layout_for_padded_shared_kernel():
+    layout: ttgl.constexpr = ttgl.PaddedSharedLayout(interval_padding_pairs=[[2, 1], [4, 2], [8, 4]], order=[2, 0, 1])
+    smem = ttgl.allocate_shared_memory(ttgl.int32, [32, 4, 32], layout)
+
+    reshaped = smem.permute((1, 0, 2))
+    """
+    The permutation is [1, 0, 2], which maps
+      old dim 1 to new dim 0
+      old dim 0 to new dim 1
+      old dim 2 to new dim 2
+    so inverseMapping[0] = 1, inverseMapping[1] = 0, inverseMapping[2] = 2.
+
+    The order in srcEnc is [2, 0, 1], so the order in dstEnc is:
+      newOrder[0] = inverseMapping[srcEncOrder[0]] = 2
+      newOrder[1] = inverseMapping[srcEncOrder[1]] = 1
+      newOrder[2] = inverseMapping[srcEncOrder[2]] = 0
+    """
+    ttgl.static_assert(
+        reshaped.layout == ttgl.PaddedSharedLayout(interval_padding_pairs=[(2, 1), (4, 2), (8, 4)], order=[2, 1, 0]))
+
+
+@pytest.mark.parametrize("target", ALL_TARGETS)
+def test_infer_layout_for_padded_shared(target):
+    # Tests the conversion from PaddedSharedEncodingAttr back to the Gluon PaddedSharedLayout object.
+    # This conversion happens in layoutToGluon, which ttgl.permute ultimately relies on.
+    module = run_parser(infer_layout_for_padded_shared_kernel, target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(module.str_nodebug()), """\
+#shared = #ttg.padded_shared<[2:+1, 4:+2, 8:+4] {order = [2, 0, 1]}>
+#shared1 = #ttg.padded_shared<[2:+1, 4:+2, 8:+4] {order = [2, 1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @infer_layout_for_padded_shared_kernel() attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<32x4x32xi32, #shared, #smem, mutable>
+    %1 = ttg.memdesc_trans %0 {order = array<i32: 1, 0, 2>} : !ttg.memdesc<32x4x32xi32, #shared, #smem, mutable> -> !ttg.memdesc<4x32x32xi32, #shared1, #smem, mutable>
+    tt.return
+  }
+}
+""")

python/triton/experimental/gluon/language/_layouts.py

Lines changed: 85 additions & 0 deletions
@@ -12,6 +12,7 @@
     "NVMMADistributedLayout",
     "NVMMASharedLayout",
     "SwizzledSharedLayout",
+    "PaddedSharedLayout",
 ]
 
 
@@ -428,6 +429,90 @@ def stringify(x):
         return f"SSS_{self.vec}_{self.per_phase}_{self.max_phase}_{stringify(self.order)}_{stringify(self.ctas_per_cga)}_{stringify(self.cta_split_num)}_{stringify(self.cta_order)}_SSS"
 
 
+@dataclass(frozen=True, eq=True)
+class PaddedSharedLayout(SharedLayout):
+    """
+    Represents a layout for accessing shared memory. Compared to SwizzledSharedLayout,
+    it uses padding to avoid shared memory bank conflicts: after every `interval` tensor
+    elements, the corresponding number of padding elements is inserted.
+    If a position corresponds to multiple intervals, the padding amounts are summed.
+
+    In the following example of a tensor, `eM` represents an original tensor element
+    and `pN` represents a padding element.
+
+    Before padding, the shared memory looks like:
+    [e0, e1,
+     e2, e3,
+     e4, e5,
+     e6, e7,
+     ...]
+
+    After padding with interval-padding list [[2, 1], [4, 2]],
+    the shared memory will be
+    [e0, e1, p0,
+     e2, e3, p1, p2, p3,
+     e4, e5, p4,
+     e6, e7, p5, p6, p7,
+     ...]
+
+    Args:
+        interval_padding_pairs (List[List[int]]): List of [interval, padding] pairs; both interval and padding must be powers of 2.
+        order (List[int]): Order of logical tensor dimensions; fastest-varying first.
+        ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
+        cta_split_num (Optional[List[int]]): Split factors for CTAs.
+        cta_order (Optional[List[int]]): CTA ordering.
+    """
+    interval_padding_pairs: List[List[int]]
+    order: List[int]
+    ctas_per_cga: Optional[List[int]] = None
+    cta_split_num: Optional[List[int]] = None
+    cta_order: Optional[List[int]] = None
+
+    def __post_init__(self):
+        super().__setattr__("interval_padding_pairs", _unwrap_shape(self.interval_padding_pairs))
+        super().__setattr__("order", _unwrap_if_constexpr(self.order))
+        super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga))
+        super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num))
+        super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order))
+
+        self.verify()
+
+    def _to_ir(self, builder):
+        intervals, paddings = zip(*self.interval_padding_pairs)
+        return builder.get_padded_shared_layout(intervals, paddings, self.order, self.ctas_per_cga, self.cta_split_num,
+                                                self.cta_order)
+
+    def mangle(self) -> str:
+
+        def stringify(x):
+            if x is None:
+                return ""
+            return "_".join(map(str, x))
+
+        return f"PaddedShared_{stringify(self.interval_padding_pairs)}_{stringify(self.order)}_{stringify(self.ctas_per_cga)}_{stringify(self.cta_split_num)}_{stringify(self.cta_order)}_PaddedShared"
+
+    def verify(self):
+        pairs = self.interval_padding_pairs
+        assert len(pairs) > 0, "PaddedSharedLayout interval_padding_pairs must have at least one interval-padding pair"
+        assert all(len(pair) == 2 for pair in pairs)
+        intervals, paddings = zip(*pairs)
+
+        unique_intervals = list(set(intervals))
+        assert len(unique_intervals) == len(intervals)
+
+        is_power_of_2 = lambda n: n > 0 and n & (n - 1) == 0
+        assert all(is_power_of_2(n) for n in intervals), "PaddedSharedLayout interval values must all be power of two"
+        assert all(is_power_of_2(n) for n in paddings), "PaddedSharedLayout padding values must all be power of two"
+
+        rank = len(self.order)
+        assert rank > 0, "PaddedSharedLayout order must not be empty"
+        _realize_cta_layout(self, rank)
+
+        assert len(self.ctas_per_cga) == rank
+        assert len(self.cta_split_num) == rank
+        assert len(self.cta_order) == rank
+
+
 # Python impl of LinearEncodingAttr::basesPerDim
 def bases_per_dim(bases, rank, skip_broadcast=True):
     result = [1] * rank
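
To make the docstring's padding rule concrete: the padded offset of an element can be modeled by adding, for each [interval, padding] pair, `padding` elements for every full `interval` preceding it. The standalone sketch below (plain Python, not part of the layout class) reproduces the docstring's [[2, 1], [4, 2]] example.

    def padded_offset(unpadded_offset, interval_padding_pairs):
        # Sum the padding contributed by every interval crossed before this element.
        extra = 0
        for interval, padding in interval_padding_pairs:
            extra += (unpadded_offset // interval) * padding
        return unpadded_offset + extra

    # With pairs [[2, 1], [4, 2]] the elements e0..e7 land at offsets
    # [0, 1, 3, 4, 8, 9, 11, 12], matching the padded picture in the docstring.
    assert [padded_offset(i, [[2, 1], [4, 2]]) for i in range(8)] == [0, 1, 3, 4, 8, 9, 11, 12]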
