Skip to content

Commit 7b5daa4

Browse files
Revert "[LAYOUTS] Implement IR support for LinearLayouts (#5170)"
This reverts commit de1f346.
1 parent 741a71f commit 7b5daa4

File tree

14 files changed

+197
-785
lines changed

14 files changed

+197
-785
lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -151,71 +151,6 @@ triton::gpu::BlockedEncodingAttr
151151
getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
152152
int numWarps, int threadsPerWarp, int numCTAs);
153153

154-
// For each output dimension d, ensure that the layout's output size (i.e., its
155-
// codomain) does not exceed shape[d]. Do this without changing the size of the
156-
// layout's inputs (i.e., leave its domain unchanged).
157-
//
158-
// This function is invariant to the order of the layout's input and output
159-
// dimensions.
160-
//
161-
// We achieve this by setting the largest value in each output dimension d to 0
162-
// because bases that map to a location larger than shape[d]
163-
// effectively duplicate along that dimension. For example, consider a layout
164-
// with an output dimension size of 32, and we call ensureLayoutNotLargerThan to
165-
// shrink the output dimension size to 8:
166-
//
167-
// L(register=1) = 8
168-
// L(register=2) = 4
169-
// L(register=4) = 1
170-
// L(lane=1) = 2
171-
// L(lane=2) = 16
172-
//
173-
// In the first step, we shrink the output dimension size to 16 by setting
174-
// L(lane=2) to 0:
175-
//
176-
// L(register=1) = 8
177-
// L(register=2) = 4
178-
// L(register=4) = 1
179-
// L(lane=1) = 2
180-
// L(lane=2) = 0
181-
//
182-
// This means that lane=2 has the same data as lane=0.
183-
//
184-
// Now the output dimension of this layout has a size of 16, which is still
185-
// larger than 8. We find the current largest value in the output dimension,
186-
// which is L(register=1) = 8, and we set L(register=1) to 0:
187-
//
188-
// L(register=1) = 0
189-
// L(register=2) = 4
190-
// L(register=4) = 1
191-
// L(lane=1) = 2
192-
// L(lane=2) = 0
193-
//
194-
// Now the output dimension of this layout has a size of 8, which is the desired
195-
// size. Note that this method works only because the bases are powers of two,
196-
// which is the case for DistributedLayouts If broadcastRegisters is false, we
197-
// remove any register that's larger than the desired shape. In the example
198-
// above we would have
199-
// L(register=1) = 4
200-
// L(register=2) = 1
201-
// L(lane=1) = 2
202-
// L(lane=2) = 0
203-
LinearLayout
204-
ensureLayoutNotLargerThan(const LinearLayout &layout,
205-
const llvm::SmallDenseMap<StringAttr, int64_t> &shape,
206-
bool broadcastRegisters = true);
207-
208-
// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no
209-
// smaller than shape[d]. Do this by increasing the size of the layout's inputs
210-
// along its most-minor dimension ("register" for register layouts, "offset" for
211-
// shared layouts).
212-
//
213-
// This function is invariant to the order of the layout's input dimensions, but
214-
// it cares about the order of the output dims, which should be minor-to-major.
215-
LinearLayout ensureLayoutNotSmallerThan(
216-
const LinearLayout &layout,
217-
const llvm::SmallDenseMap<StringAttr, int64_t> &shape);
218-
219154
// Dump information about which threads/registers contain each of the tensor
220155
// elements.
221156
void dumpLayout(RankedTensorType tensorType);

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ Right now, Triton implements two main classes of layouts: shared, and distribute
5656
code extraBaseClassDeclaration = [{
5757
unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
5858
SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
59+
::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
5960
}];
6061
}
6162

@@ -146,6 +147,7 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
146147
let genVerifyDecl = 1;
147148
let skipDefaultBuilders = 1;
148149
}
150+
149151
//===----------------------------------------------------------------------===//
150152
// Shared Layout Encoding
151153
//===----------------------------------------------------------------------===//
@@ -569,34 +571,6 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
569571
}];
570572
}
571573

572-
//===----------------------------------------------------------------------===//
573-
// Linear Layout Encoding
574-
//===----------------------------------------------------------------------===//
575-
576-
def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> {
577-
let mnemonic = "linear";
578-
579-
let description = [{
580-
See the docs in LinearLayout.h for the definition of linear layouts.
581-
}];
582-
583-
let parameters = (ins "LinearLayout":$linearLayout);
584-
585-
let extraClassDeclaration = extraDistributedDeclaration # [{
586-
SmallVector<unsigned> getContigPerThread() const;
587-
SmallVector<unsigned> getOrder() const;
588-
}];
589-
590-
let genVerifyDecl = 1;
591-
// Example of assembly format:
592-
// <{register = [[0, 1], [8, 0], [0, 8], [64, 0]],
593-
// lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]],
594-
// warp = [[16, 0], [32, 0]],
595-
// block = []}>
596-
let hasCustomAssemblyFormat = 1;
597-
}
598-
599-
600574
//===----------------------------------------------------------------------===//
601575
// Blocked Layout Encoding
602576
//===----------------------------------------------------------------------===//

include/triton/Tools/LinearLayout.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include <vector>
1010

1111
#include "mlir/IR/BuiltinAttributes.h"
12-
#include "llvm/ADT/Hashing.h"
1312
#include "llvm/ADT/MapVector.h"
1413
#include "llvm/ADT/STLExtras.h"
1514
#include "llvm/ADT/SetVector.h"
@@ -433,7 +432,6 @@ class LinearLayout {
433432
// (e.g. by reshaping) then the order doesn't really affect anything.
434433
auto getInDimNames() const { return llvm::make_first_range(bases); }
435434
auto getOutDimNames() const { return llvm::make_first_range(outDims); }
436-
auto getOutDimSizes() const { return llvm::make_second_range(outDims); }
437435

438436
// Gets the position that this outDim occupies in getOutDimNames(). Asserts
439437
// if the dim is not present.
@@ -695,7 +693,6 @@ class LinearLayout {
695693
return !(lhs == rhs);
696694
}
697695
bool equalIgnoringOutDimSizes(const LinearLayout &other) const;
698-
friend size_t hash_value(const LinearLayout &layout);
699696

700697
private:
701698
// Factory function that gracefully fails rather than asserts if the layout is

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,9 +397,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
397397
if (isa<BlockedEncodingAttr>(layout)) {
398398
return true;
399399
}
400-
if (isa<LinearEncodingAttr>(layout)) {
401-
return true;
402-
}
403400
if (auto slice = dyn_cast<SliceEncodingAttr>(layout)) {
404401
return layoutIsOK(slice.getParent());
405402
}

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
165165
Attribute srcLayout = srcTy.getEncoding();
166166
Attribute dstLayout = dstTy.getEncoding();
167167
if (isa<SharedEncodingAttr>(srcLayout) &&
168-
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
169-
LinearEncodingAttr>(dstLayout) ||
168+
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
169+
dstLayout) ||
170170
isSupportedDotOpLayout(dstTy))) {
171171
return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
172172
rewriter);
@@ -206,6 +206,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
206206
auto dstTy = op.getResult().getType();
207207
auto dstShape = dstTy.getShape();
208208
auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
209+
auto dstLayout = dstTy.getEncoding();
209210
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
210211
"Unexpected rank of ConvertLayout(shared->distributed)");
211212

0 commit comments

Comments
 (0)