Skip to content

Commit 5bbce9e

Browse files
Revert "Revert "[LAYOUTS] Implement IR support for LinearLayouts (#5170)""
This reverts commit 7b5daa4.
1 parent b5a791e commit 5bbce9e

File tree

14 files changed

+785
-197
lines changed

14 files changed

+785
-197
lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,71 @@ triton::gpu::BlockedEncodingAttr
149149
getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
150150
int numWarps, int threadsPerWarp, int numCTAs);
151151

152+
// For each output dimension d, ensure that the layout's output size (i.e., its
153+
// codomain) does not exceed shape[d]. Do this without changing the size of the
154+
// layout's inputs (i.e., leave its domain unchanged).
155+
//
156+
// This function is invariant to the order of the layout's input and output
157+
// dimensions.
158+
//
159+
// We achieve this by setting the largest value in each output dimension d to 0
160+
// because bases that map to a location larger than shape[d]
161+
// effectively duplicate along that dimension. For example, consider a layout
162+
// with an output dimension size of 32, and we call ensureLayoutNotLargerThan to
163+
// shrink the output dimension size to 8:
164+
//
165+
// L(register=1) = 8
166+
// L(register=2) = 4
167+
// L(register=4) = 1
168+
// L(lane=1) = 2
169+
// L(lane=2) = 16
170+
//
171+
// In the first step, we shrink the output dimension size to 16 by setting
172+
// L(lane=2) to 0:
173+
//
174+
// L(register=1) = 8
175+
// L(register=2) = 4
176+
// L(register=4) = 1
177+
// L(lane=1) = 2
178+
// L(lane=2) = 0
179+
//
180+
// This means that lane=2 has the same data as lane=0.
181+
//
182+
// Now the output dimension of this layout has a size of 16, which is still
183+
// larger than 8. We find the current largest value in the output dimension,
184+
// which is L(register=1) = 8, and we set L(register=1) to 0:
185+
//
186+
// L(register=1) = 0
187+
// L(register=2) = 4
188+
// L(register=4) = 1
189+
// L(lane=1) = 2
190+
// L(lane=2) = 0
191+
//
192+
// Now the output dimension of this layout has a size of 8, which is the desired
193+
// size. Note that this method works only because the bases are powers of two,
194+
// which is the case for DistributedLayouts. If broadcastRegisters is false, we
195+
// remove any register that's larger than the desired shape. In the example
196+
// above we would have
197+
// L(register=1) = 4
198+
// L(register=2) = 1
199+
// L(lane=1) = 2
200+
// L(lane=2) = 0
201+
LinearLayout
202+
ensureLayoutNotLargerThan(const LinearLayout &layout,
203+
const llvm::SmallDenseMap<StringAttr, int64_t> &shape,
204+
bool broadcastRegisters = true);
205+
206+
// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no
207+
// smaller than shape[d]. Do this by increasing the size of the layout's inputs
208+
// along its most-minor dimension ("register" for register layouts, "offset" for
209+
// shared layouts).
210+
//
211+
// This function is invariant to the order of the layout's input dimensions, but
212+
// it cares about the order of the output dims, which should be minor-to-major.
213+
LinearLayout ensureLayoutNotSmallerThan(
214+
const LinearLayout &layout,
215+
const llvm::SmallDenseMap<StringAttr, int64_t> &shape);
216+
152217
// Dump information about which threads/registers contain each of the tensor
153218
// elements.
154219
void dumpLayout(RankedTensorType tensorType);

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ Right now, Triton implements two main classes of layouts: shared, and distribute
5656
code extraBaseClassDeclaration = [{
5757
unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
5858
SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
59-
::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
6059
}];
6160
}
6261

@@ -147,7 +146,6 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
147146
let genVerifyDecl = 1;
148147
let skipDefaultBuilders = 1;
149148
}
150-
151149
//===----------------------------------------------------------------------===//
152150
// Shared Layout Encoding
153151
//===----------------------------------------------------------------------===//
@@ -565,6 +563,34 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
565563
}];
566564
}
567565

566+
//===----------------------------------------------------------------------===//
567+
// Linear Layout Encoding
568+
//===----------------------------------------------------------------------===//
569+
570+
def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> {
571+
let mnemonic = "linear";
572+
573+
let description = [{
574+
See the docs in LinearLayout.h for the definition of linear layouts.
575+
}];
576+
577+
let parameters = (ins "LinearLayout":$linearLayout);
578+
579+
let extraClassDeclaration = extraDistributedDeclaration # [{
580+
SmallVector<unsigned> getContigPerThread() const;
581+
SmallVector<unsigned> getOrder() const;
582+
}];
583+
584+
let genVerifyDecl = 1;
585+
// Example of assembly format:
586+
// <{register = [[0, 1], [8, 0], [0, 8], [64, 0]],
587+
// lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]],
588+
// warp = [[16, 0], [32, 0]],
589+
// block = []}>
590+
let hasCustomAssemblyFormat = 1;
591+
}
592+
593+
568594
//===----------------------------------------------------------------------===//
569595
// Blocked Layout Encoding
570596
//===----------------------------------------------------------------------===//

include/triton/Tools/LinearLayout.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <vector>
1010

1111
#include "mlir/IR/BuiltinAttributes.h"
12+
#include "llvm/ADT/Hashing.h"
1213
#include "llvm/ADT/MapVector.h"
1314
#include "llvm/ADT/STLExtras.h"
1415
#include "llvm/ADT/SetVector.h"
@@ -432,6 +433,7 @@ class LinearLayout {
432433
// (e.g. by reshaping) then the order doesn't really affect anything.
433434
auto getInDimNames() const { return llvm::make_first_range(bases); }
434435
auto getOutDimNames() const { return llvm::make_first_range(outDims); }
436+
auto getOutDimSizes() const { return llvm::make_second_range(outDims); }
435437

436438
// Gets the position that this outDim occupies in getOutDimNames(). Asserts
437439
// if the dim is not present.
@@ -693,6 +695,7 @@ class LinearLayout {
693695
return !(lhs == rhs);
694696
}
695697
bool equalIgnoringOutDimSizes(const LinearLayout &other) const;
698+
friend size_t hash_value(const LinearLayout &layout);
696699

697700
private:
698701
// Factory function that gracefully fails rather than asserts if the layout is

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
397397
if (isa<BlockedEncodingAttr>(layout)) {
398398
return true;
399399
}
400+
if (isa<LinearEncodingAttr>(layout)) {
401+
return true;
402+
}
400403
if (auto slice = dyn_cast<SliceEncodingAttr>(layout)) {
401404
return layoutIsOK(slice.getParent());
402405
}

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
165165
Attribute srcLayout = srcTy.getEncoding();
166166
Attribute dstLayout = dstTy.getEncoding();
167167
if (isa<SharedEncodingAttr>(srcLayout) &&
168-
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
169-
dstLayout) ||
168+
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
169+
LinearEncodingAttr>(dstLayout) ||
170170
isSupportedDotOpLayout(dstTy))) {
171171
return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
172172
rewriter);
@@ -206,7 +206,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
206206
auto dstTy = op.getResult().getType();
207207
auto dstShape = dstTy.getShape();
208208
auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
209-
auto dstLayout = dstTy.getEncoding();
210209
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
211210
"Unexpected rank of ConvertLayout(shared->distributed)");
212211

0 commit comments

Comments
 (0)