[LAYOUTS] Implement IR support for LinearLayouts (#5170)
We also exercise this in scale_dot, where we enable support for warps of arbitrary shape (before we just allowed `[num_warps, 1]`).

With this infra in place, it should be rather easy to move from the legacy layouts to using LLs to represent all of our layouts.

Something I'm concerned about is the amount of recomputation that happens when calling methods like `getSizePerThread` and the like, where we keep recomputing the result. There might be an optimisation opportunity here where we cache the result of all these functions.

We choose the IR representation of an LL via its canonical form + a `repOrder` for several reasons:
- It's generally more compact.
- It's easier to CSE, so it's easier to see when two layouts are in fact the same.
- A technical reason: the `toLinearLayout` function returns a tensor with dimensions `dim0, ..., dim<rank-1>`; in other words, it "forgets" the repetition order. Without the repetition order, we cannot recover the tile size of the argument. In particular, we cannot recover `getSizePerThread`.

There is an argument to be made about whether `getSizePerThread` is useful on its own, or whether `getElemsPerThread` is the real useful abstraction here, but for now, we keep both for BC.
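If the recomputation mentioned above ever shows up in profiles, one option is to memoize the derived quantities per encoding attribute (attributes are uniqued, so they make a stable cache key). A minimal sketch of that idea, not part of this change; `computeSizePerThread` is a hypothetical stand-in for whatever derivation the encoding performs:

```cpp
#include <mutex>
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h"

// Hypothetical stand-in for the existing per-encoding derivation.
llvm::SmallVector<unsigned> computeSizePerThread(mlir::Attribute encoding);

llvm::SmallVector<unsigned> getSizePerThreadCached(mlir::Attribute encoding) {
  static llvm::DenseMap<mlir::Attribute, llvm::SmallVector<unsigned>> cache;
  static std::mutex mutex;
  std::lock_guard<std::mutex> lock(mutex);
  auto [it, inserted] = cache.try_emplace(encoding);
  if (inserted)
    it->second = computeSizePerThread(encoding); // compute once, reuse after
  return it->second;
}
```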
1 parent 66012fc commit de1f346

File tree: 14 files changed (+785, −197 lines)


include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 65 additions & 0 deletions
@@ -149,6 +149,71 @@ triton::gpu::BlockedEncodingAttr
 getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
                           int numWarps, int threadsPerWarp, int numCTAs);
 
+// For each output dimension d, ensure that the layout's output size (i.e., its
+// codomain) does not exceed shape[d]. Do this without changing the size of the
+// layout's inputs (i.e., leave its domain unchanged).
+//
+// This function is invariant to the order of the layout's input and output
+// dimensions.
+//
+// We achieve this by setting the largest value in each output dimension d to 0
+// because bases that map to a location larger than shape[d] effectively
+// duplicate along that dimension. For example, consider a layout with an
+// output dimension size of 32, and we call ensureLayoutNotLargerThan to shrink
+// the output dimension size to 8:
+//
+//   L(register=1) = 8
+//   L(register=2) = 4
+//   L(register=4) = 1
+//   L(lane=1) = 2
+//   L(lane=2) = 16
+//
+// In the first step, we shrink the output dimension size to 16 by setting
+// L(lane=2) to 0:
+//
+//   L(register=1) = 8
+//   L(register=2) = 4
+//   L(register=4) = 1
+//   L(lane=1) = 2
+//   L(lane=2) = 0
+//
+// This means that lane=2 has the same data as lane=0.
+//
+// Now the output dimension of this layout has a size of 16, which is still
+// larger than 8. We find the current largest value in the output dimension,
+// which is L(register=1) = 8, and we set L(register=1) to 0:
+//
+//   L(register=1) = 0
+//   L(register=2) = 4
+//   L(register=4) = 1
+//   L(lane=1) = 2
+//   L(lane=2) = 0
+//
+// Now the output dimension of this layout has a size of 8, which is the
+// desired size. Note that this method works only because the bases are powers
+// of two, which is the case for DistributedLayouts.
+//
+// If broadcastRegisters is false, we remove any register that's larger than
+// the desired shape. In the example above we would have
+//   L(register=1) = 4
+//   L(register=2) = 1
+//   L(lane=1) = 2
+//   L(lane=2) = 0
+LinearLayout
+ensureLayoutNotLargerThan(const LinearLayout &layout,
+                          const llvm::SmallDenseMap<StringAttr, int64_t> &shape,
+                          bool broadcastRegisters = true);
+
+// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no
+// smaller than shape[d]. Do this by increasing the size of the layout's inputs
+// along its most-minor dimension ("register" for register layouts, "offset"
+// for shared layouts).
+//
+// This function is invariant to the order of the layout's input dimensions,
+// but it cares about the order of the output dims, which should be
+// minor-to-major.
+LinearLayout ensureLayoutNotSmallerThan(
+    const LinearLayout &layout,
+    const llvm::SmallDenseMap<StringAttr, int64_t> &shape);
+
 // Dump information about which threads/registers contain each of the tensor
 // elements.
 void dumpLayout(RankedTensorType tensorType);
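A minimal usage sketch for the declarations above, assuming the usual Triton namespaces (`mlir::triton` for `LinearLayout`, `mlir::triton::gpu` for the helper) and an output dimension named "dim0"; treat the names as illustrative rather than verified:

```cpp
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// Cap the layout's codomain at 8 along "dim0", mirroring the 32 -> 8 example
// in the comment above. The layout's input dims (register/lane/...) keep
// their sizes; only bases that map outside the shape are zeroed out.
mlir::triton::LinearLayout capAtEight(const mlir::triton::LinearLayout &layout,
                                      mlir::MLIRContext *ctx) {
  llvm::SmallDenseMap<mlir::StringAttr, int64_t> shape;
  shape[mlir::StringAttr::get(ctx, "dim0")] = 8;
  return mlir::triton::gpu::ensureLayoutNotLargerThan(layout, shape);
}
```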

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 28 additions & 2 deletions
@@ -56,7 +56,6 @@ Right now, Triton implements two main classes of layouts: shared, and distribute
   code extraBaseClassDeclaration = [{
     unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
     SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
-    ::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
   }];
 }
 
@@ -147,7 +146,6 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
   let genVerifyDecl = 1;
   let skipDefaultBuilders = 1;
 }
-
 //===----------------------------------------------------------------------===//
 // Shared Layout Encoding
 //===----------------------------------------------------------------------===//
@@ -565,6 +563,34 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Linear Layout Encoding
+//===----------------------------------------------------------------------===//
+
+def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> {
+  let mnemonic = "linear";
+
+  let description = [{
+    See the docs in LinearLayout.h for the definition of linear layouts.
+  }];
+
+  let parameters = (ins "LinearLayout":$linearLayout);
+
+  let extraClassDeclaration = extraDistributedDeclaration # [{
+    SmallVector<unsigned> getContigPerThread() const;
+    SmallVector<unsigned> getOrder() const;
+  }];
+
+  let genVerifyDecl = 1;
+  // Example of assembly format:
+  // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]],
+  //   lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]],
+  //   warp = [[16, 0], [32, 0]],
+  //   block = []}>
+  let hasCustomAssemblyFormat = 1;
+}
+
+
 //===----------------------------------------------------------------------===//
 // Blocked Layout Encoding
 //===----------------------------------------------------------------------===//
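As a rough sketch of how the new attribute might be built and queried from C++; the `::get` signature below is inferred from the TableGen parameter list and the extraClassDeclaration above, so treat namespaces and signatures as assumptions:

```cpp
// Wrap an existing LinearLayout in a #ttg.linear encoding and query the
// derived geometry it exposes.
mlir::Attribute wrapAsLinearEncoding(mlir::MLIRContext *ctx,
                                     const mlir::triton::LinearLayout &ll) {
  auto attr = mlir::triton::gpu::LinearEncodingAttr::get(ctx, ll);
  llvm::SmallVector<unsigned> contig = attr.getContigPerThread();
  llvm::SmallVector<unsigned> order = attr.getOrder();
  (void)contig; // contiguity per dimension within a thread
  (void)order;  // minor-to-major order of the tensor dims
  return attr;
}
```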

include/triton/Tools/LinearLayout.h

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,7 @@
 #include <vector>
 
 #include "mlir/IR/BuiltinAttributes.h"
+#include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
@@ -432,6 +433,7 @@ class LinearLayout {
   // (e.g. by reshaping) then the order doesn't really affect anything.
   auto getInDimNames() const { return llvm::make_first_range(bases); }
   auto getOutDimNames() const { return llvm::make_first_range(outDims); }
+  auto getOutDimSizes() const { return llvm::make_second_range(outDims); }
 
   // Gets the position that this outDim occupies in getOutDimNames(). Asserts
   // if the dim is not present.
@@ -693,6 +695,7 @@ class LinearLayout {
     return !(lhs == rhs);
   }
   bool equalIgnoringOutDimSizes(const LinearLayout &other) const;
+  friend size_t hash_value(const LinearLayout &layout);
 
 private:
   // Factory function that gracefully fails rather than asserts if the layout is
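The new `hash_value` friend is what lets a `LinearLayout` participate in hashing, e.g. when it is a parameter of a uniqued attribute such as `LinearEncodingAttr` or part of an `llvm::hash_combine` chain. A small illustrative use, assuming nothing beyond the declaration above:

```cpp
#include "llvm/ADT/Hashing.h"
#include "triton/Tools/LinearLayout.h"

// Combine a layout's hash with other key fields, as an attribute storage key
// might do when the layout is one of several parameters.
llvm::hash_code hashLayoutAndRank(const mlir::triton::LinearLayout &layout,
                                  unsigned rank) {
  return llvm::hash_combine(hash_value(layout), rank); // ADL finds the friend
}
```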

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 3 additions & 0 deletions
@@ -397,6 +397,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       if (isa<BlockedEncodingAttr>(layout)) {
         return true;
       }
+      if (isa<LinearEncodingAttr>(layout)) {
+        return true;
+      }
       if (auto slice = dyn_cast<SliceEncodingAttr>(layout)) {
         return layoutIsOK(slice.getParent());
       }

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 2 additions & 3 deletions
@@ -165,8 +165,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
     if (isa<SharedEncodingAttr>(srcLayout) &&
-        (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
-             dstLayout) ||
+        (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
+             LinearEncodingAttr>(dstLayout) ||
          isSupportedDotOpLayout(dstTy))) {
       return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
                                       rewriter);
@@ -206,7 +206,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto dstTy = op.getResult().getType();
     auto dstShape = dstTy.getShape();
     auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
-    auto dstLayout = dstTy.getEncoding();
     assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
            "Unexpected rank of ConvertLayout(shared->distributed)");
 
