Commit 8a5d1ee

[BACKEND] Add support for LinearSharedEncoding (triton-lang#8116)
For now we just add it to the gluon test.
1 parent a90ac86 commit 8a5d1ee

9 files changed, +302 -0 lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 30 additions & 0 deletions
@@ -493,6 +493,36 @@ For identity mappings a short form based on order and shape is used to increase
   let genVerifyDecl = 1;
 }
 
+def SharedLinearEncodingAttr
+    : TritonGPU_Attr<"SharedLinearEncoding", "shared_linear_encoding",
+                     [SharedEncodingTrait, DeclareLayoutEncodingMethods]> {
+  let mnemonic = "shared_linear";
+
+  let description = [{
+    Linear shared encodings mirror LinearEncodingAttr but operate on shared
+    memory layouts. The LinearLayout parameter captures how shared memory
+    offsets (and optionally blocks) map to logical tensor indices.
+  }];
+
+  let parameters = (ins LinearLayoutParam:$linearLayout, "unsigned":$layoutAlignment);
+
+  let extraClassDeclaration = [{
+    SmallVector<unsigned> basesPerDim(StringAttr dimName,
+                                      bool skipBroadcast = true) const;
+    SmallVector<unsigned> orderPerDim(StringAttr dimName,
+                                      ArrayRef<unsigned> defaultOrder) const;
+
+    SmallVector<unsigned> getOrder() const;
+
+    LinearLayout toLinearLayout(ArrayRef<int64_t> shape) const;
+
+    int32_t getAlignment() const { return static_cast<int32_t>(getLayoutAlignment()); }
+  }];
+
+  let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 def NVMMASharedEncodingAttr : TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [DeclareSharedEncodingMethods, LayoutEncodingTrait]> {
   let mnemonic = "nvmma_shared";
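
To make the description above concrete, here is a small sketch (not part of this commit) of how offset bases map a shared-memory offset to a logical tensor index, assuming Triton's usual LinearLayout semantics in which each set bit of the input selects one basis and the output is the XOR of the selected bases:

# Hypothetical helper, for illustration only: apply offset bases to a
# shared-memory offset.  Bit i of `offset` selects offset_bases[i]; the
# logical tensor index is the elementwise XOR of the selected bases.
def apply_offset_bases(offset_bases, offset):
    rank = len(offset_bases[0]) if offset_bases else 0
    index = [0] * rank
    for bit, basis in enumerate(offset_bases):
        if (offset >> bit) & 1:
            index = [i ^ b for i, b in zip(index, basis)]
    return index

# Bases from the round-trip test added in ops.mlir below: a 4x4 layout whose
# last basis ([2, 2]) introduces a swizzle relative to a plain row-major layout.
bases = [[0, 1], [0, 2], [1, 0], [2, 2]]
for offset in range(16):
    print(offset, "->", apply_offset_bases(bases, offset))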

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 178 additions & 0 deletions
@@ -187,6 +187,9 @@ SmallVector<unsigned> getOrder(SharedEncodingTrait layout,
   if (auto paddedEnc = dyn_cast<PaddedSharedEncodingAttr>(layout)) {
     return paddedEnc.getOrder();
   }
+  if (auto linearEnc = dyn_cast<SharedLinearEncodingAttr>(layout)) {
+    return linearEnc.getOrder();
+  }
   if (auto sharedLayout = dyn_cast<NVMMASharedEncodingAttr>(layout)) {
     if (shape.size() == 1) {
       return {0};
@@ -1566,6 +1569,181 @@ void SwizzledSharedEncodingAttr::print(AsmPrinter &printer) const {
   printer << "}>";
 }
 
+//===----------------------------------------------------------------------===//
+// SharedLinear encoding
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+SharedLinearEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                                 LinearLayout linearLayout,
+                                 unsigned layoutAlignment) {
+  if (layoutAlignment == 0 || !llvm::isPowerOf2_32(layoutAlignment)) {
+    return emitError() << "alignment must be a positive power of two";
+  }
+  static const auto expectedInDims =
+      SmallVector<std::string>({"offset", "block"});
+  for (const auto &[index, dims] : llvm::enumerate(
+           llvm::zip(linearLayout.getInDimNames(), expectedInDims))) {
+    const auto &[dim, expected] = dims;
+    if (dim.str() != expected) {
+      return emitError() << "Expected input dimension " << index << " to be '"
+                         << expected << "'. Got " << dim;
+    }
+  }
+
+  for (auto [i, dim] : llvm::enumerate(linearLayout.getOutDimNames())) {
+    if (dim.str() != ("dim" + llvm::Twine(i)).str()) {
+      return emitError()
+             << "Expected output dimensions to be ['dim0', 'dim1', ...]. Got "
+             << dim << " at position " << i;
+    }
+  }
+
+  SmallVector<StringAttr> outDimNames =
+      llvm::to_vector(linearLayout.getOutDimNames());
+  if (outDimNames.empty()) {
+    return emitError()
+           << "SharedLinearEncodingAttr requires at least one output"
+              " dimension.";
+  }
+
+  auto *ctx = outDimNames.front().getContext();
+  auto kOffset = StringAttr::get(ctx, "offset");
+  auto kBlock = StringAttr::get(ctx, "block");
+
+  if (!linearLayout.isSurjective()) {
+    return emitError() << "The layout must be surjective";
+  }
+
+  LinearLayout withoutBroadcast =
+      linearLayout.removeZeroBasesAlongDim(kOffset).removeZeroBasesAlongDim(
+          kBlock);
+  if (!withoutBroadcast.isInvertible()) {
+    return emitError()
+           << "After removing the zero bases the layout must be bijective";
+  }
+
+  return success();
+}
+
+void SharedLinearEncodingAttr::print(AsmPrinter &printer) const {
+  printer << "<{";
+  auto layout = getLinearLayout();
+  auto kBlock = StringAttr::get(getContext(), "block");
+  auto kOffset = StringAttr::get(getContext(), "offset");
+  if (layout.getBases().lookup(kBlock).empty()) {
+    layout =
+        layout.sublayout({kOffset}, llvm::to_vector(layout.getOutDimNames()));
+  }
+  printLinearLayout(printer, layout);
+  printer << "}, alignment = " << getAlignment() << "}>";
+}
+
+Attribute SharedLinearEncodingAttr::parse(AsmParser &parser, Type type) {
+  if (parser.parseLess().failed())
+    return {};
+
+  DictionaryAttr layoutDictRaw;
+  if (parser.parseAttribute(layoutDictRaw).failed())
+    return {};
+
+  if (layoutDictRaw.get("alignment")) {
+    parser.emitError(parser.getCurrentLocation())
+        << "alignment must be specified outside of the linear layout braces";
+    return {};
+  }
+
+  NamedAttrList layoutAttrList(layoutDictRaw.getValue());
+  auto *ctx = parser.getContext();
+  auto kBlock = StringAttr::get(ctx, "block");
+  if (!layoutAttrList.get(kBlock)) {
+    layoutAttrList.push_back({kBlock, ArrayAttr::get(ctx, {})});
+  }
+
+  DictionaryAttr layoutDict = layoutAttrList.getDictionary(ctx);
+
+  // Parse alignment
+  unsigned layoutAlignment;
+  if (parser.parseComma().failed())
+    return {};
+  if (parser.parseKeyword("alignment").failed() || parser.parseEqual().failed())
+    return {};
+  if (parser.parseInteger(layoutAlignment).failed())
+    return {};
+
+  if (parser.parseGreater().failed())
+    return {};
+
+  std::vector<std::string> inDimNames = {"offset", "block"};
+  auto maybeLL = parseLinearLayout(layoutDict, parser, inDimNames);
+  if (!maybeLL.has_value())
+    return {};
+
+  // Special case for cleaner errors
+  if (layoutDict.get("alignment")) {
+    parser.emitError(parser.getCurrentLocation())
+        << "alignment must be specified outside of the linear layout braces";
+    return {};
+  }
+
+  if (layoutDict.size() != 2) {
+    parser.emitError(parser.getCurrentLocation())
+        << "SharedLinearEncodingAttr must have exactly two attributes: offset "
+           "and block";
+    return {};
+  }
+
+  return parser.getChecked<SharedLinearEncodingAttr>(
+      parser.getContext(), std::move(*maybeLL), layoutAlignment);
+}
+
+SmallVector<unsigned>
+SharedLinearEncodingAttr::basesPerDim(StringAttr dimName,
+                                      bool skipBroadcast) const {
+  auto ll = getLinearLayout();
+  auto rank = ll.getNumOutDims();
+  return basesPerDimImpl(ll.getBases(), dimName, rank, skipBroadcast);
+}
+
+SmallVector<unsigned>
+SharedLinearEncodingAttr::orderPerDim(StringAttr dimName,
+                                      ArrayRef<unsigned> defaultOrder) const {
+  return orderPerDimImpl(getLinearLayout(), dimName, defaultOrder);
+}
+
+SmallVector<unsigned> SharedLinearEncodingAttr::getOrder() const {
+  auto ll = getLinearLayout();
+  auto rank = ll.getNumOutDims();
+  SmallVector<unsigned> defaultOrder(rank);
+  std::iota(defaultOrder.rbegin(), defaultOrder.rend(), 0);
+  return orderPerDim(StringAttr::get(getContext(), "offset"), defaultOrder);
+}
+
+SmallVector<unsigned> SharedLinearEncodingAttr::getCTAsPerCGA() const {
+  return basesPerDim(StringAttr::get(getContext(), "block"),
+                     /*skipBroadcast=*/false);
+}
+
+SmallVector<unsigned> SharedLinearEncodingAttr::getCTAOrder() const {
+  return orderPerDim(StringAttr::get(getContext(), "block"), getOrder());
+}
+
+SmallVector<unsigned> SharedLinearEncodingAttr::getCTASplitNum() const {
+  return basesPerDim(StringAttr::get(getContext(), "block"));
+}
+
+LinearLayout
+SharedLinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
+  auto ll = getLinearLayout();
+  auto outDimNames = llvm::to_vector(ll.getOutDimNames());
+  assert(shape.size() == outDimNames.size());
+  // We don't support automatic broadcasting for shared linear layouts
+  for (auto [size, llSize] : llvm::zip(shape, ll.getOutDimSizes())) {
+    assert(size == llSize);
+  }
+  return ll;
+}
+
 //===----------------------------------------------------------------------===//
 // PaddedShared encoding
 //===----------------------------------------------------------------------===//
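
For reference, a rough Python mirror (not part of this commit; brute force instead of the linear-algebra checks used in the C++ above) of the constraints the verifier enforces on the offset and block bases:

from itertools import product

# Hypothetical sketch of the verifier's structural checks: alignment must be a
# positive power of two, the layout must cover every tensor index (surjective),
# and it must be one-to-one once broadcast (all-zero) bases are dropped.
def xor_map(bases, coord, rank):
    index = [0] * rank
    for bit, basis in enumerate(bases):
        if (coord >> bit) & 1:
            index = [i ^ b for i, b in zip(index, basis)]
    return tuple(index)

def verify_shared_linear(offset_bases, block_bases, shape, alignment):
    assert alignment > 0 and (alignment & (alignment - 1)) == 0, \
        "alignment must be a positive power of two"
    rank = len(shape)
    # Surjectivity: every tensor index must be hit by some (offset, block) pair.
    images = set()
    for block in range(1 << len(block_bases)):
        for offset in range(1 << len(offset_bases)):
            b = xor_map(block_bases, block, rank)
            o = xor_map(offset_bases, offset, rank)
            images.add(tuple(x ^ y for x, y in zip(b, o)))
    assert images == set(product(*(range(s) for s in shape))), \
        "the layout must be surjective"
    # Bijectivity after dropping broadcast bases: distinct inputs must map to
    # distinct indices once all-zero bases are removed.
    nonzero = [b for b in offset_bases + block_bases if any(b)]
    seen = set()
    for coord in range(1 << len(nonzero)):
        idx = xor_map(nonzero, coord, rank)
        assert idx not in seen, \
            "after removing the zero bases the layout must be bijective"
        seen.add(idx)

# The 4x4 layout from the round-trip test in ops.mlir passes these checks.
verify_shared_linear([[0, 1], [0, 2], [1, 0], [2, 2]], [], [4, 4], 16)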

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 2 additions & 0 deletions
@@ -1253,6 +1253,8 @@ LinearLayout TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape,
          "shape must be a postive power of 2");
   if (auto shared = dyn_cast<SwizzledSharedEncodingAttr>(layout)) {
     result = swizzledSharedToLinearLayout(shape, shared);
+  } else if (auto shared = dyn_cast<SharedLinearEncodingAttr>(layout)) {
+    result = shared.toLinearLayout(shape);
   } else if (auto shared = dyn_cast<NVMMASharedEncodingAttr>(layout)) {
     result = nvmmaSharedToLinearLayout(shape, shared);
   } else if (auto sbl = dyn_cast<AMDRotatingSharedEncodingAttr>(layout)) {

python/src/gluon_ir.cc

Lines changed: 14 additions & 0 deletions
@@ -395,6 +395,20 @@ void init_gluon_ir(py::module &&m) {
             return ttg::PaddedSharedEncodingAttr::get(ctx, intervals, paddings,
                                                       ll);
           })
+      .def("get_shared_linear_layout",
+           [](GluonOpBuilder &self, std::vector<std::vector<int>> &offsetBases,
+              std::vector<std::vector<int>> &blockBases,
+              std::vector<int64_t> &shape, unsigned alignment) -> Attribute {
+             auto ctx = self.getContext();
+             auto kOffset = mlir::StringAttr::get(ctx, "offset");
+             auto kBlock = mlir::StringAttr::get(ctx, "block");
+             auto ll = tt::LinearLayout(
+                 {{kOffset, offsetBases}, {kBlock, blockBases}},
+                 tt::standardOutDimPairs(ctx, shape),
+                 /*requireSurjective=*/true);
+             return self.getChecked<ttg::SharedLinearEncodingAttr>(ctx, ll,
+                                                                   alignment);
+           })
       .def("get_nvmma_shared_layout",
           [](GluonOpBuilder &self, unsigned swizzleByteWidth,
              unsigned elementBitwidth, bool transposed, bool fp4Padded,

python/test/gluon/test_lowerings.py

Lines changed: 15 additions & 0 deletions
@@ -721,6 +721,7 @@ def kernel(x_ptr, y_ptr, M: ttgl.constexpr, N: ttgl.constexpr, src_layout: ttgl.
     ttgl.SwizzledSharedLayout(vec=4, per_phase=2, max_phase=4, order=[0, 1]),
     ttgl.SwizzledSharedLayout(vec=8, per_phase=1, max_phase=8, order=[1, 0]),
     ttgl.SwizzledSharedLayout(vec=16, per_phase=1, max_phase=16, order=[1, 0]),
+    "shared_linear_layout",
 ])
 
 
@@ -733,6 +734,20 @@ def kernel(x_ptr, y_ptr, M: ttgl.constexpr, N: ttgl.constexpr, src_layout: ttgl.
 @pytest.mark.parametrize("dist_layout", _ld_st_dot_layouts + _ld_st_mma_layouts)
 @pytest.mark.parametrize("shared_layout", _ld_st_shared_layouts)
 def test_local_load_store_2d_layouts(shape, dtype, dist_layout, shared_layout, device):
+    if shared_layout == "shared_linear_layout":
+        rank = len(shape)
+        assert rank == 2
+        offset_bases = []
+        for dim, size in enumerate(shape):
+            assert size > 0 and (size & (size - 1)) == 0
+            stride = 1
+            while stride < size:
+                basis = [0] * rank
+                basis[dim] = stride
+                offset_bases.append(basis)
+                stride <<= 1
+        shared_layout = ttgl.SharedLinearLayout(offset_bases=offset_bases, block_bases=[], shape=list(shape))
+
     if isinstance(shared_layout, ttgl.NVMMASharedLayout):
         contig_dim = 0 if shared_layout.transposed else 1
         if shape[contig_dim] < (8 * shared_layout.swizzle_byte_width) / shared_layout.element_bitwidth:

python/triton/experimental/gluon/language/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@
     NVMMASharedLayout,
     SwizzledSharedLayout,
     PaddedSharedLayout,
+    SharedLinearLayout,
 )
 from ._math import (
     umulhi,

python/triton/experimental/gluon/language/_layouts.py

Lines changed: 39 additions & 0 deletions
@@ -587,6 +587,45 @@ def __hash__(self):
                      tuple(map(tuple, self.block_bases)), tuple(self.shape)))
 
 
+@dataclass(frozen=True)
+class SharedLinearLayout(SharedLayout):
+    """Represents a shared memory layout defined via an explicit LinearLayout."""
+
+    offset_bases: List[List[int]]
+    block_bases: List[List[int]]
+    shape: List[int]
+    alignment: int = 16
+
+    def __post_init__(self):
+        super().__setattr__("offset_bases", _unwrap_shape(self.offset_bases))
+        super().__setattr__("block_bases", _unwrap_shape(self.block_bases))
+        super().__setattr__("shape", _unwrap_shape(self.shape))
+        super().__setattr__("alignment", _unwrap_if_constexpr(self.alignment))
+
+        rank = len(self.shape)
+        assert rank > 0, "SharedLinearLayout shape must not be empty"
+        for basis in self.offset_bases:
+            assert len(basis) == rank
+        for basis in self.block_bases:
+            assert len(basis) == rank
+        assert self.alignment > 0 and (self.alignment & (self.alignment - 1)) == 0, \
+            "SharedLinearLayout alignment must be a positive power of two"
+
+    def _to_ir(self, builder):
+        return builder.get_shared_linear_layout(self.offset_bases, self.block_bases, self.shape, self.alignment)
+
+    def mangle(self) -> str:
+        return f"SharedLinear_{self.offset_bases}_{self.block_bases}_{self.shape}_{self.alignment}_SharedLinear"
+
+    def __hash__(self):
+        return hash((
+            tuple(map(tuple, self.offset_bases)),
+            tuple(map(tuple, self.block_bases)),
+            tuple(self.shape),
+            self.alignment,
+        ))
+
+
 # Python impl of LinearEncodingAttr::basesPerDim
 def bases_per_dim(bases, rank, skip_broadcast=True):
     result = [1] * rank
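
A minimal usage sketch of the new SharedLinearLayout class (not part of this commit, and assuming a Triton build that includes it), constructing the same identity-style offset bases the gluon test above builds for a power-of-two 2D shape:

from triton.experimental.gluon import language as ttgl

# Row-major-style offset bases: each basis sets a single power-of-two stride
# along one dimension, mirroring the loop in test_local_load_store_2d_layouts.
shape = [4, 8]
offset_bases = []
for dim, size in enumerate(shape):
    stride = 1
    while stride < size:
        basis = [0] * len(shape)
        basis[dim] = stride
        offset_bases.append(basis)
        stride <<= 1

layout = ttgl.SharedLinearLayout(offset_bases=offset_bases, block_bases=[], shape=shape)
print(layout.mangle())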

test/TritonGPU/invalid.mlir

Lines changed: 6 additions & 0 deletions
@@ -466,3 +466,9 @@ tt.func @async_copy_invalid_other_type(%input: tensor<64x64x!tt.ptr<f16>, #block
 #shared = #ttg.padded_shared<[4:+4] {offset=[[1, 0], [2, 0], [0, 1], [0, 2]], block=[]}>
 // expected-error @below {{Mismatch in expected shape for dimension 0. Expected: 4, got: 8}}
 !out_dim_too_large = !ttg.memdesc<8x8xf32, #shared, #ttg.shared_memory>
+
+// -----
+
+// expected-error @below {{alignment must be specified outside of the linear layout braces}}
+#shared = #ttg.shared_linear<{offset = [[0, 1], [0, 2], [1, 0], [2, 0]], block = [], alignment = 16}>
+!alignment_in_layout = !ttg.memdesc<4x4xf32, #shared, #ttg.shared_memory>

test/TritonGPU/ops.mlir

Lines changed: 17 additions & 0 deletions
@@ -257,3 +257,20 @@ tt.func @async_commit_group(%arg0: !ttg.async.token) {
   %1 = ttg.async_commit_group
   tt.return
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 2], warpsPerCTA = [1, 1], order = [1, 0]}>
+#shared = #ttg.shared_linear<{offset = [[0, 1], [0, 2], [1, 0], [2, 2]]}, alignment = 16>
+#smem = #ttg.shared_memory
+
+module attributes {"ttg.threads-per-warp" = 4 : i32, "ttg.num-warps" = 1 : i32} {
+  tt.func @round_trip(%arg0: tensor<4x4xf32, #blocked>) -> tensor<4x4xf32, #blocked> {
+    // CHECK: ttg.local_alloc
+    // CHECK-SAME: !ttg.memdesc<4x4xf32, #shared
+    %alloc = ttg.local_alloc %arg0 : (tensor<4x4xf32, #blocked>) -> !ttg.memdesc<4x4xf32, #shared, #smem, mutable>
+    ttg.local_store %arg0, %alloc : tensor<4x4xf32, #blocked> -> !ttg.memdesc<4x4xf32, #shared, #smem, mutable>
+    %loaded = ttg.local_load %alloc : !ttg.memdesc<4x4xf32, #shared, #smem, mutable> -> tensor<4x4xf32, #blocked>
+    tt.return %loaded : tensor<4x4xf32, #blocked>
+  }
+}
