Skip to content

Commit 741a71f

Browse files
Merge commit 'de1f346aa6737fa2e3e6a8a64dae118fcfab9995'
2 parents 1b88a41 + de1f346 commit 741a71f

File tree

14 files changed

+785
-197
lines changed

14 files changed

+785
-197
lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,71 @@ triton::gpu::BlockedEncodingAttr
151151
getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
152152
int numWarps, int threadsPerWarp, int numCTAs);
153153

154+
// For each output dimension d, ensure that the layout's output size (i.e., its
155+
// codomain) does not exceed shape[d]. Do this without changing the size of the
156+
// layout's inputs (i.e., leave its domain unchanged).
157+
//
158+
// This function is invariant to the order of the layout's input and output
159+
// dimensions.
160+
//
161+
// We achieve this by setting the largest value in each output dimension d to 0
162+
// because bases that map to a location larger than shape[d]
163+
// effectively duplicate along that dimension. For example, consider a layout
164+
// with an output dimension size of 32, and we call ensureLayoutNotLargerThan to
165+
// shrink the output dimension size to 8:
166+
//
167+
// L(register=1) = 8
168+
// L(register=2) = 4
169+
// L(register=4) = 1
170+
// L(lane=1) = 2
171+
// L(lane=2) = 16
172+
//
173+
// In the first step, we shrink the output dimension size to 16 by setting
174+
// L(lane=2) to 0:
175+
//
176+
// L(register=1) = 8
177+
// L(register=2) = 4
178+
// L(register=4) = 1
179+
// L(lane=1) = 2
180+
// L(lane=2) = 0
181+
//
182+
// This means that lane=2 has the same data as lane=0.
183+
//
184+
// Now the output dimension of this layout has a size of 16, which is still
185+
// larger than 8. We find the current largest value in the output dimension,
186+
// which is L(register=1) = 8, and we set L(register=1) to 0:
187+
//
188+
// L(register=1) = 0
189+
// L(register=2) = 4
190+
// L(register=4) = 1
191+
// L(lane=1) = 2
192+
// L(lane=2) = 0
193+
//
194+
// Now the output dimension of this layout has a size of 8, which is the desired
195+
// size. Note that this method works only because the bases are powers of two,
196+
// which is the case for DistributedLayouts. If broadcastRegisters is false, we
197+
// remove any register that's larger than the desired shape. In the example
198+
// above we would have:
199+
// L(register=1) = 4
200+
// L(register=2) = 1
201+
// L(lane=1) = 2
202+
// L(lane=2) = 0
203+
LinearLayout
204+
ensureLayoutNotLargerThan(const LinearLayout &layout,
205+
const llvm::SmallDenseMap<StringAttr, int64_t> &shape,
206+
bool broadcastRegisters = true);
207+
208+
// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no
209+
// smaller than shape[d]. Do this by increasing the size of the layout's inputs
210+
// along its most-minor dimension ("register" for register layouts, "offset" for
211+
// shared layouts).
212+
//
213+
// This function is invariant to the order of the layout's input dimensions, but
214+
// it cares about the order of the output dims, which should be minor-to-major.
215+
LinearLayout ensureLayoutNotSmallerThan(
216+
const LinearLayout &layout,
217+
const llvm::SmallDenseMap<StringAttr, int64_t> &shape);
218+
154219
// Dump information about which threads/registers contain each of the tensor
155220
// elements.
156221
void dumpLayout(RankedTensorType tensorType);

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ Right now, Triton implements two main classes of layouts: shared, and distribute
5656
code extraBaseClassDeclaration = [{
5757
unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
5858
SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
59-
::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
6059
}];
6160
}
6261

@@ -147,7 +146,6 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
147146
let genVerifyDecl = 1;
148147
let skipDefaultBuilders = 1;
149148
}
150-
151149
//===----------------------------------------------------------------------===//
152150
// Shared Layout Encoding
153151
//===----------------------------------------------------------------------===//
@@ -571,6 +569,34 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
571569
}];
572570
}
573571

572+
//===----------------------------------------------------------------------===//
573+
// Linear Layout Encoding
574+
//===----------------------------------------------------------------------===//
575+
576+
def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> {
577+
let mnemonic = "linear";
578+
579+
let description = [{
580+
See the docs in LinearLayout.h for the definition of linear layouts.
581+
}];
582+
583+
let parameters = (ins "LinearLayout":$linearLayout);
584+
585+
let extraClassDeclaration = extraDistributedDeclaration # [{
586+
SmallVector<unsigned> getContigPerThread() const;
587+
SmallVector<unsigned> getOrder() const;
588+
}];
589+
590+
let genVerifyDecl = 1;
591+
// Example of assembly format:
592+
// <{register = [[0, 1], [8, 0], [0, 8], [64, 0]],
593+
// lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]],
594+
// warp = [[16, 0], [32, 0]],
595+
// block = []}>
596+
let hasCustomAssemblyFormat = 1;
597+
}
598+
599+
574600
//===----------------------------------------------------------------------===//
575601
// Blocked Layout Encoding
576602
//===----------------------------------------------------------------------===//

include/triton/Tools/LinearLayout.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <vector>
1010

1111
#include "mlir/IR/BuiltinAttributes.h"
12+
#include "llvm/ADT/Hashing.h"
1213
#include "llvm/ADT/MapVector.h"
1314
#include "llvm/ADT/STLExtras.h"
1415
#include "llvm/ADT/SetVector.h"
@@ -432,6 +433,7 @@ class LinearLayout {
432433
// (e.g. by reshaping) then the order doesn't really affect anything.
433434
auto getInDimNames() const { return llvm::make_first_range(bases); }
434435
auto getOutDimNames() const { return llvm::make_first_range(outDims); }
436+
auto getOutDimSizes() const { return llvm::make_second_range(outDims); }
435437

436438
// Gets the position that this outDim occupies in getOutDimNames(). Asserts
437439
// if the dim is not present.
@@ -693,6 +695,7 @@ class LinearLayout {
693695
return !(lhs == rhs);
694696
}
695697
bool equalIgnoringOutDimSizes(const LinearLayout &other) const;
698+
friend size_t hash_value(const LinearLayout &layout);
696699

697700
private:
698701
// Factory function that gracefully fails rather than asserts if the layout is

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
397397
if (isa<BlockedEncodingAttr>(layout)) {
398398
return true;
399399
}
400+
if (isa<LinearEncodingAttr>(layout)) {
401+
return true;
402+
}
400403
if (auto slice = dyn_cast<SliceEncodingAttr>(layout)) {
401404
return layoutIsOK(slice.getParent());
402405
}

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
165165
Attribute srcLayout = srcTy.getEncoding();
166166
Attribute dstLayout = dstTy.getEncoding();
167167
if (isa<SharedEncodingAttr>(srcLayout) &&
168-
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
169-
dstLayout) ||
168+
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
169+
LinearEncodingAttr>(dstLayout) ||
170170
isSupportedDotOpLayout(dstTy))) {
171171
return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
172172
rewriter);
@@ -206,7 +206,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
206206
auto dstTy = op.getResult().getType();
207207
auto dstShape = dstTy.getShape();
208208
auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
209-
auto dstLayout = dstTy.getEncoding();
210209
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
211210
"Unexpected rank of ConvertLayout(shared->distributed)");
212211

0 commit comments

Comments
 (0)