Commit 526d168

[LAYOUTS] Add a shared layout for padding (#7212)
This commit adds a new shared memory layout for padding. Padding cannot be represented with a linear layout, so we need to define it at a parallel level with the swizzled shared layout. Intermediate lowering steps don't actually need to concern themselves with the exact padding; only when we create the 1-D physical allocation and build pointers for indexing do we need to factor the padding in. This means we can leverage existing linear layout facilities to reason about the element mapping.
1 parent 72764da commit 526d168
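
Before the diffs, the padding rule the commit message describes can be sketched as plain scalar arithmetic. This is an illustrative stand-alone sketch, not code from the commit; `paddedIndex` and its signature are hypothetical, and intervals and pads are assumed to be powers of two, as the new encoding requires.

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Map a logical 1-D element index to its physical shared memory slot by
// inserting `pad` extra slots after every `interval` elements, summed over
// all interval-padding pairs (they compose additively).
int64_t paddedIndex(int64_t i,
                    const std::vector<std::pair<int64_t, int64_t>> &intervalPads) {
  int64_t offset = i;
  for (auto [interval, pad] : intervalPads)
    offset += (i / interval) * pad; // integer division
  return offset;
}

int main() {
  // With one pair [2:+2], the layout is [e0, e1, p0, p1, e2, e3, p2, p3, ...].
  assert(paddedIndex(0, {{2, 2}}) == 0);
  assert(paddedIndex(1, {{2, 2}}) == 1);
  assert(paddedIndex(2, {{2, 2}}) == 4); // e2 lands after the p0, p1 slots
  return 0;
}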

12 files changed: +460 additions, -81 deletions

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 0 additions & 6 deletions
@@ -537,12 +537,6 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
     const TargetInfoBase &target,
     std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);

-[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
-    LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
-    std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
-    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
-    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
-
 [[nodiscard]] bool emitTransferBetweenRegistersAndShared(
     LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
     std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 91 additions & 8 deletions
@@ -167,21 +167,22 @@ def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
   ];
 }

-def SwizzledSharedEncodingAttr :
-    TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> {
+def SwizzledSharedEncodingAttr
+    : TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding",
+                     [SharedEncodingTrait, LayoutEncodingTrait]> {
   let mnemonic = "swizzled_shared";

   let description = [{
     An encoding for tensors whose elements may be simultaneously accessed by
-    different cuda threads in the programs, via shared memory. In other words,
+    different GPU threads in the programs, via shared memory. In other words,
     for all indices i \in Z^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.

     In order to avoid shared memory bank conflicts, elements may be swizzled.
     Here are some examples. In all cases, the input tensor is [0, 1, ..., n-1].

     1. Basic swizzling

-    #shared<{vec=1, perPhase=1, maxPhase=4, order=[1,0]}>
+    #ttg.swizzled_shared<{vec=1, perPhase=1, maxPhase=4, order=[1,0]}>
     [ 0,  1,  2,  3], // xor with 0
     [ 5,  4,  7,  6], // xor with 1
     [10, 11,  8,  9], // xor with 2
@@ -192,7 +193,7 @@ out[r][c^r]).

     2. Multiple rows per phase

-    #shared<{vec=1, perPhase=2, maxPhase=4, order=[1,0]}>
+    #ttg.swizzled_shared<{vec=1, perPhase=2, maxPhase=4, order=[1,0]}>
     [ 0,  1,  2,  3], // phase 0 (xor with 0)
     [ 4,  5,  6,  7],
     [ 9,  8, 11, 10], // phase 1 (xor with 1)
@@ -203,7 +204,7 @@ means that pairs of 2 rows get the same swizzling.

     3. Max-phase applied

-    $shared<{vec=1, perPhase=1, maxPhase=2, order=[1,0]}>
+    #ttg.swizzled_shared<{vec=1, perPhase=1, maxPhase=2, order=[1,0]}>
     [ 0,  1,  2,  3], // phase 0 (xor with 0)
     [ 5,  4,  7,  6], // phase 1 (xor with 1)
     [ 8,  9, 10, 11], // phase 0
@@ -218,7 +219,7 @@ effect of limiting the maximum value of the xor to m-1.

     4. Max-phase and per-phase

-    #shared<{vec=1, perPhase=2, maxPhase=2, order=[1,0]}>
+    #ttg.swizzled_shared<{vec=1, perPhase=2, maxPhase=2, order=[1,0]}>
     [ 0,  1,  2,  3], // phase 0 (xor with 0)
     [ 4,  5,  6,  7], // phase 0
     [ 9,  8, 11, 10], // phase 1 (xor with 1)
@@ -234,7 +235,7 @@ maximum value of maxPhase-1. In other words, elements of row r are xor'ed with

     5. Adding vec

-    #shared<{vec=2, perPhase=1, maxPhase=4, order=[1,0]}>
+    #ttg.swizzled_shared<{vec=2, perPhase=1, maxPhase=4, order=[1,0]}>
     [ 0,  1,  2,  3,  4,  5,  6,  7],
     [10, 11,  8,  9, 14, 15, 12, 13],
     [20, 21, 22, 23, 16, 17, 18, 19],
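
The XOR-based swizzling in the examples above can be checked with a few lines of arithmetic. A minimal sketch, assuming the mapping in[r][c] -> out[r][((c/vec) ^ phase)*vec + c%vec] with phase = (r / perPhase) % maxPhase; `swizzledCol` is a hypothetical helper, not part of the commit.

#include <cassert>

// Physical column of the element at logical [r][c]: columns are exchanged in
// groups of `vec`, and the xor phase advances every `perPhase` rows, wrapping
// at `maxPhase`.
int swizzledCol(int r, int c, int vec, int perPhase, int maxPhase) {
  int phase = (r / perPhase) % maxPhase;
  return ((c / vec) ^ phase) * vec + c % vec;
}

int main() {
  // Example 1 (vec=1, perPhase=1, maxPhase=4): row 1 is [5, 4, 7, 6], so
  // element 4 (logical column 0) sits at physical column 1.
  assert(swizzledCol(1, 0, 1, 1, 4) == 1);
  // Example 5 (vec=2, perPhase=1, maxPhase=4): row 1 is [10, 11, 8, 9, ...],
  // so element 8 (logical column 0) sits at physical column 2.
  assert(swizzledCol(1, 0, 2, 1, 4) == 2);
  return 0;
}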
@@ -383,6 +384,88 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
   let genVerifyDecl = 1;
 }

+def PaddedSharedEncodingAttr
+    : TritonGPU_Attr<"PaddedSharedEncoding", "padded_shared_encoding",
+                     [SharedEncodingTrait, LayoutEncodingTrait]> {
+  let mnemonic = "padded_shared";
+
+  let description = [{
+    An encoding for tensors whose elements may be simultaneously accessed by
+    different GPU threads in the programs, via shared memory. In other words,
+    for all indices i \in Z^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.
+    Compared to SwizzledSharedEncodingAttr, this encoding uses padding to avoid
+    shared memory bank conflicts.
+
+    Formally, given a layout:
+      padded_shared<[<interval_0>:+<pad_0>, <interval_1>:+<pad_1>, ...]>
+    we insert a padding of `<pad_i>` elements after every `<interval_i>`
+    elements. Multiple interval-padding pairs are supported for flexibility in
+    multi-tiered padding schemes; they compose additively. So for a 1-D tensor
+    element at index i, the corresponding shared memory location index is
+      i + \sum_k (i / interval_k) * pad_k
+    where `/` is integer division. `<interval_i>` and `<pad_i>` must all be
+    powers of two.
+
+    Some concrete examples, using `eM` to mean tensor elements and `pN` to mean
+    padding:
+
+    1. Single interval-padding pair:
+
+      #ttg.padded_shared<[2:+2]>
+      [e0, e1, p0, p1,
+       e2, e3, p2, p3,
+       ...]
+
+    2. Double interval-padding pairs:
+
+      #ttg.padded_shared<[2:+1, 4:+2]>
+      [e0, e1, p0,
+       e2, e3, p1, p2, p3,
+       e4, e5, p4,
+       e6, e7, p5, p6, p7,
+       ...]
+
+    In addition to interval-padding pairs, this encoding requires an `order` to
+    specify the logical tensor dimensions, from fastest- to slowest-varying.
+    Like other encoding attributes, it may optionally carry CGA-level
+    organization, for example:
+      #ttg.padded_shared<[2:+1, 4:+2] {
+        order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1],
+        CTAOrder = [0, 1]}>
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"unsigned">:$intervals,
+    ArrayRefParameter<"unsigned">:$paddings,
+    // Order of logical tensor dimensions; fastest-varying first.
+    ArrayRefParameter<"unsigned">:$order,
+    "CTALayoutAttr":$CTALayout
+  );
+
+  let builders = [
+    AttrBuilder<(ins "ArrayRef<std::pair<unsigned, unsigned>>":$intervalPads,
+                     "ArrayRef<unsigned>":$order, "CTALayoutAttr":$ctaLayout)>,
+  ];
+
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
+    unsigned getRank() const { return getOrder().size(); }
+    int32_t getAlignment() const { return 16; }
+
+    unsigned getMinInterval() const {
+      return *llvm::min_element(getIntervals());
+    }
+
+    // Returns the total number of elements including padding given the input
+    // tensor shape.
+    int64_t getPaddedSize(ArrayRef<int64_t> shape) const;
+
+    SmallVector<unsigned> getCTAsPerCGA() const;
+    SmallVector<unsigned> getCTAOrder() const;
+    SmallVector<unsigned> getCTASplitNum() const;
+  }];
+  let hasCustomAssemblyFormat = 1;
+  let genVerifyDecl = 1;
+}
+
 def NVMMASharedEncodingAttr :
   TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> {
   let mnemonic = "nvmma_shared";
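
As a sanity check on the new attribute's description, the additive composition of the double-pair example can be replayed numerically. A minimal stand-alone sketch using only the formula from the description; it is not part of the commit.

#include <cstdint>
#include <cstdio>

int main() {
  // Replay #ttg.padded_shared<[2:+1, 4:+2]>: one pad slot every 2 elements
  // plus two more pad slots every 4 elements, composed additively.
  for (int64_t i = 0; i < 8; ++i) {
    int64_t slot = i + (i / 2) * 1 + (i / 4) * 2;
    std::printf("e%lld -> slot %lld\n", (long long)i, (long long)slot);
  }
  return 0;
}

This prints slots 0, 1, 3, 4, 8, 9, 11, 12 for e0 through e7, matching the second example above: p0 occupies slot 2, p1-p3 occupy slots 5-7, p4 slot 10, and so on.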

lib/Analysis/Allocation.cpp

Lines changed: 10 additions & 5 deletions
@@ -260,12 +260,17 @@ class AllocationAnalysis {
     auto alloc = dyn_cast<gpu::LocalAllocOp>(op);
     if (!alloc || !alloc.isSharedMemoryAlloc())
       return;
-    // Bytes could be a different value once we support padding or other
-    // allocation policies.
     auto allocType = alloc.getType();
-    auto shapePerCTA = gpu::getAllocationShapePerCTA(allocType);
-    auto bytes =
-        product<int64_t>(shapePerCTA) * allocType.getElementTypeBitWidth() / 8;
+    int64_t numElems = 0;
+    if (auto paddedLayout =
+            dyn_cast<gpu::PaddedSharedEncodingAttr>(allocType.getEncoding())) {
+      SmallVector<int64_t> unpaddedShape = gpu::getShapePerCTA(allocType);
+      numElems = paddedLayout.getPaddedSize(unpaddedShape);
+    } else {
+      auto shapePerCTA = gpu::getAllocationShapePerCTA(allocType);
+      numElems = product<int64_t>(shapePerCTA);
+    }
+    int64_t bytes = numElems * allocType.getElementTypeBitWidth() / 8;

     auto alignment = alloc.getAlignmentOrDefault();
     allocation->addBuffer<BufferT::BufferKind::Explicit>(alloc, bytes,
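
`getPaddedSize` is only declared in the TableGen diff above; its definition lives elsewhere in the commit and is not shown here. As a rough mental model (an assumption about the behavior, not the commit's actual implementation), the padded element count can be derived by applying the interval/padding formula to the flattened per-CTA element count:

#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical sketch of a padded-size computation: the logical element count
// plus one padding run per completed interval, over all pairs. This also pads
// after the final interval, so it may over-allocate slightly compared to a
// tighter implementation.
int64_t paddedSizeSketch(int64_t numElems,
                         const std::vector<std::pair<int64_t, int64_t>> &intervalPads) {
  int64_t size = numElems;
  for (auto [interval, pad] : intervalPads)
    size += (numElems / interval) * pad;
  return size;
}

For 8 elements under [2:+1, 4:+2] this yields 8 + 4 + 4 = 16 slots, comfortably covering the last element at slot 12.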

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 49 additions & 23 deletions
@@ -8,7 +8,9 @@
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Tools/LayoutUtils.h"
+#include "triton/Tools/LinearLayout.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/MathExtras.h"

 #if defined(_MSC_VER) && !defined(__clang__)
 // from https://gist.github.com/pps83/3210a2f980fd02bb2ba2e5a1fc4a2ef0
@@ -407,6 +409,10 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
   // We propose case 2 (see comments below), which provides a more general
   // solution for all swizzled shared memory scenarios, including the edge case
   // mentioned above.
+  //
+  // Padded shared layout falls into case 1--we can rely on the logic for case 1
+  // to get the 1-D offset into shared memory. Then we just need to add the
+  // padding offset.
   if (isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) { // Case 1
     smemOffset = applyLinearLayout(loc, rewriter, regToSharedLayout,
                                    {{kRegister, regId},
@@ -435,6 +441,18 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
       smemOffset = dot(rewriter, loc, smemOffsets,
                        applyPermutation(smemStrides, smemOrder));
     }
+    if (auto paddedLayout =
+            dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(sharedEnc)) {
+      // Apply the offset needed for padding.
+      Value padOffset = b.i32_val(0);
+      for (auto [interval, padding] : llvm::zip_equal(
+               paddedLayout.getIntervals(), paddedLayout.getPaddings())) {
+        Value iVal = b.i32_val(llvm::Log2_32(interval));
+        Value pVal = b.i32_val(llvm::Log2_32(padding));
+        padOffset = b.add(padOffset, b.shl(b.ashr(smemOffset, iVal), pVal));
+      }
+      smemOffset = b.add(smemOffset, padOffset);
+    }
   } else { // Case 2 -> rank-reduced swizzling
     assert(rank >= 2 && "Swizzling only applies to tensors with rank >= 2");
     assert((isa<triton::gpu::SwizzledSharedEncodingAttr,
@@ -627,17 +645,6 @@ SmallVector<Value> lowerLocalLdSt(Location loc, MLIRContext *ctx,
                       rewriter, targetInfo);
 }

-bool emitTransferBetweenRegistersAndShared(
-    LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
-    std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
-    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
-    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
-  auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
-  return emitTransferBetweenRegistersAndShared(
-      regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter,
-      target, laneId, warpId, perVectorCallback);
-}
-
 bool emitTransferBetweenRegistersAndShared(
     LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
     std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
@@ -651,11 +658,19 @@ bool emitTransferBetweenRegistersAndShared(
   StringAttr kRegister = str_attr("register");
   StringAttr kLane = str_attr("lane");
   StringAttr kWarp = str_attr("warp");
+  StringAttr kOffset = str_attr("offset");

   auto shape = sharedTy.getShape();
-  LinearLayout sharedLayout =
-      triton::gpu::toLinearLayout(shape, sharedTy.getEncoding());
-  LinearLayout regToSharedLayout = regLayout.invertAndCompose(sharedLayout);
+  auto paddedLayout =
+      dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(sharedTy.getEncoding());
+  LinearLayout regToSharedLayout = LinearLayout::empty();
+  if (paddedLayout) {
+    regToSharedLayout =
+        regLayout.reshapeOuts({{kOffset, regLayout.getTotalOutDimSize()}});
+  } else {
+    auto sharedLL = triton::gpu::toLinearLayout(shape, sharedTy.getEncoding());
+    regToSharedLayout = regLayout.invertAndCompose(sharedLL);
+  }

   // TODO(jlebar): We don't currently support loading from shared memory in a
   // different CTA. We'd need to emit `mapa.shared::cluster` instructions.
@@ -680,9 +695,12 @@ bool emitTransferBetweenRegistersAndShared(
   //
   // It's OK if the vector width we choose here is wider than the hardware
   // supports; LLVM will legalize it.
-  const int vecElems =
-      std::min(regToSharedLayout.getNumConsecutiveInOut(),
-               maxVecElems.value_or(std::numeric_limits<int>::max()));
+  int vecElems =
+      std::min({regToSharedLayout.getNumConsecutiveInOut(),
+                maxVecElems.value_or(std::numeric_limits<int>::max())});
+  if (paddedLayout) {
+    vecElems = std::min(vecElems, int(paddedLayout.getMinInterval()));
+  }

   auto withCTAOffset = triton::gpu::getNumCTAs(sharedTy.getEncoding()) > 1;
   Value blockId =
@@ -696,10 +714,14 @@ bool emitTransferBetweenRegistersAndShared(
   // take out the "block" dimension.
   // Thus we use `pseudoinvert` instead of `invert` here for simplicity.
   auto allocShape = sharedTy.getAllocShape();
-  LinearLayout invertAllocSharedLayout =
-      triton::gpu::toLinearLayout(allocShape.take_back(sharedTy.getRank()),
-                                  sharedTy.getEncoding())
-          .pseudoinvert();
+  auto invertAllocSharedLayout = LinearLayout::empty();
+  if (!paddedLayout) {
+    // For now this is only needed for the cases where we have swizzling.
+    invertAllocSharedLayout =
+        triton::gpu::toLinearLayout(allocShape.take_back(sharedTy.getRank()),
+                                    sharedTy.getEncoding())
+            .pseudoinvert();
+  }

   int numElems = regToSharedLayout.getInDimSize(kRegister);
   auto vecTy = vec_ty(elemLlvmTy, vecElems);
@@ -722,9 +744,10 @@ bool emitTransferBetweenRegistersAndShared(
     std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
   auto regLayout = triton::gpu::toLinearLayout(registerTy.getShape(),
                                                registerTy.getEncoding());
+  auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
   return emitTransferBetweenRegistersAndShared(
       regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter,
-      target, perVectorCallback);
+      target, laneId, warpId, perVectorCallback);
 }

 SmallVector<Value> loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp,
@@ -912,10 +935,13 @@ bool isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
                                 ArrayRef<int64_t> allocShape,
                                 triton::gpu::SharedEncodingTrait sharedEnc) {
   auto rank = shape.size();
+  auto paddedLayout =
+      dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(sharedEnc);
   auto swizzledLayout =
       dyn_cast<triton::gpu::SwizzledSharedEncodingAttr>(sharedEnc);
   auto nvmmaLayout = dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(sharedEnc);
-  bool noSwizzling = (swizzledLayout && swizzledLayout.getMaxPhase() == 1) ||
+  bool noSwizzling = paddedLayout ||
+                     (swizzledLayout && swizzledLayout.getMaxPhase() == 1) ||
                      (nvmmaLayout && nvmmaLayout.getSwizzlingByteWidth() == 0);
   return /*no swizzling*/ noSwizzling ||
          /*swizzling but same shape*/ shape == allocShape ||
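
Two details of the lowering above are worth spelling out. First, since the encoding requires intervals and paddings to be powers of two, `getSmemVecAddr` computes (offset / interval) * pad with an arithmetic shift right by log2(interval) followed by a shift left by log2(pad); that is what the `ashr`/`shl` pair does. Second, clamping `vecElems` to `getMinInterval()` keeps each vectorized access within one contiguous run of unpadded elements, so a vector load or store never straddles a padding region. A stand-alone check of the shift identity (illustrative only):

#include <cassert>
#include <cstdint>

int main() {
  const int32_t interval = 4, pad = 2;       // both powers of two
  const int32_t logInterval = 2, logPad = 1; // log2 of interval and pad
  for (int32_t offset = 0; offset < 1024; ++offset)
    assert((offset / interval) * pad == ((offset >> logInterval) << logPad));
  return 0;
}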
