Commit 1e0e51c

[LAYOUTS] Remove HoistLayoutConversion in favour of backwardsRemat (#5788)
Reland of triton-lang/triton#5673
1 parent ca582a2 commit 1e0e51c

9 files changed: +468 −422 lines

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 3 additions & 1 deletion
@@ -581,6 +581,7 @@ def TT_TransOp : TT_Op<"trans", [Pure,
   let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";

   let hasFolder = 1;
+  let hasVerifier = 1;
 }

 //
@@ -830,7 +831,8 @@ def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> {
 def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [
     Elementwise,
     SameOperandsAndResultEncoding,
-    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    DeclareOpInterfaceMethods<ConditionallySpeculatable>
 ]> {
   let summary = "inline assembly applying an elementwise operation to a group of packed elements.";
   let description = [{

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 1 deletion
@@ -225,7 +225,7 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> {

 def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure,
                                                   TransposeOpInterface,
-                                                  DeclareOpInterfaceMethods<InferTypeOpInterface>,
+                                                  InferTypeOpWithLayoutEquivalence,
                                                   SameOperandsAndResultElementType]> {
   let summary = "transpose the descriptor";

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 23 additions & 0 deletions
@@ -209,6 +209,23 @@ OpFoldResult TransOp::fold(FoldAdaptor adaptor) {
   return {};
 }

+LogicalResult TransOp::verify() {
+  auto order = getOrder();
+  auto srcTy = cast<RankedTensorType>(getSrc().getType());
+  if (order.size() != srcTy.getShape().size()) {
+    return emitError("order must have the same size as the source tensor");
+  }
+  if (!isPermutationOfIota(order)) {
+    return emitError("order must be a permutation of 0..n-1");
+  }
+  SmallVector<int64_t> retShape = applyPermutation(srcTy.getShape(), order);
+  if (retShape != getType().getShape()) {
+    return emitError(
+        "result shape must match the permutation of the source shape");
+  }
+  return success();
+}
+
 LogicalResult TransOp::inferReturnTypes(
     MLIRContext *context, std::optional<Location> location,
     TransOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
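For illustration only (the tensors, shapes, and order attribute below are invented for this note, not taken from the commit), the new verifier accepts a transpose whose result shape is the permuted source shape and rejects one that is not; tt.trans prints with the assembly format declared in TritonOps.td above:

  // Accepted: order is a permutation of 0..1 and 64x32 is 32x64 permuted.
  %t = tt.trans %x {order = array<i32: 1, 0>} : tensor<32x64xf32> -> tensor<64x32xf32>

  // Rejected: "result shape must match the permutation of the source shape".
  %u = tt.trans %x {order = array<i32: 1, 0>} : tensor<32x64xf32> -> tensor<32x64xf32>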
@@ -1037,6 +1054,12 @@ void ElementwiseInlineAsmOp::getEffects(
                        SideEffects::DefaultResource::get());
 }

+Speculation::Speculatability ElementwiseInlineAsmOp::getSpeculatability() {
+  if (getPure())
+    return Speculation::Speculatable;
+  return Speculation::NotSpeculatable;
+}
+
 LogicalResult ElementwiseInlineAsmOp::verify() {
   if (getNumOperands() >= 1) {
     auto tensorType = dyn_cast<RankedTensorType>(getOperand(0).getType());
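A rough sketch of what the new interface buys (the asm string, constraints, and tensor types below are illustrative assumptions, not from this commit): an op with pure = true now reports Speculation::Speculatable, so transforms that move code across control flow, such as LICM or layout rematerialization, may hoist or duplicate it; with pure = false it stays NotSpeculatable and is left where the program put it.

  // Speculatable: pure = true, so the op may be moved or rematerialized.
  %y = tt.elementwise_inline_asm "shl.b32 $0, $1, 3;"
         {constraints = "=r,r", packed_element = 1 : i32, pure = true}
         %x : tensor<128xi32> -> tensor<128xi32>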

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 15 additions & 13 deletions
@@ -454,15 +454,17 @@ OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) {
   return {};
 }

-LogicalResult MemDescTransOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
+LogicalResult
+MemDescTransOp::inferReturnTypes(MLIRContext *context,
+                                 std::optional<Location> location,
+                                 MemDescTransOp::Adaptor adaptor,
+                                 SmallVectorImpl<Type> &inferredReturnTypes) {
+
   // type is the same as the input
-  auto argTy = cast<MemDescType>(operands[0].getType());
-  auto argShape = argTy.getShape();
-  auto order = properties.as<Properties *>()->order.asArrayRef();
-  SmallVector<int64_t> retShape = applyPermutation(argTy.getShape(), order);
+  auto argTy = cast<MemDescType>(adaptor.getSrc().getType());
+  auto shape = argTy.getShape();
+  auto order = adaptor.getOrder();
+  SmallVector<int64_t> retShape = applyPermutation(shape, order);

   auto retEltTy = argTy.getElementType();
   Attribute argEncoding = argTy.getEncoding();
@@ -471,17 +473,17 @@ LogicalResult MemDescTransOp::inferReturnTypes(
     Dialect &dialect = argEncoding.getDialect();
     auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
     if (inferLayoutInterface
-            ->inferTransOpEncoding(argEncoding, argShape, order, retEncoding)
+            ->inferTransOpEncoding(argEncoding, shape, order, retEncoding)
             .failed()) {
       return failure();
     }
   }
-  auto memDescTy = cast<MemDescType>(argTy);
-  inferredReturnTypes.push_back(MemDescType::get(
-      retShape, retEltTy, retEncoding, memDescTy.getMemorySpace(),
-      memDescTy.getMutableMemory()));
+  inferredReturnTypes.push_back(
+      MemDescType::get(retShape, retEltTy, retEncoding, argTy.getMemorySpace(),
+                       argTy.getMutableMemory()));
   return success();
 }
+
 // LocalAllocOp
 void LocalAllocOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
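As a purely illustrative sketch (the shapes, the #shared and #shared_T layout names, and the exact !ttg.memdesc spelling are assumptions), the inferred result type is the source descriptor with its shape permuted by order and its encoding produced by inferTransOpEncoding, keeping the memory space and mutability of the source:

  %t = ttg.memdesc_trans %a {order = array<i32: 1, 0>}
       : !ttg.memdesc<32x64xf16, #shared, #ttg.shared_memory>
       -> !ttg.memdesc<64x32xf16, #shared_T, #ttg.shared_memory>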

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

Lines changed: 0 additions & 142 deletions
@@ -24,36 +24,6 @@ namespace {
 // Roughly, whether op is elementwise and thus threads don't need
 // to exchange elements. But some ops are not currently supported even though
 // they meet that criterion.
-bool canHoistDotOpEncV2(Operation *op, DotOperandEncodingAttr &dotOpEnc) {
-  // Only consider custom conversions or arith ops.
-  // TODO(jlebar): Is this too restrictive?
-  if (!isa<FpToFpOp, BitcastOp>(op) && !isPureUnaryInlineAsm(op) &&
-      !isa<arith::ArithDialect>(op->getDialect()))
-    return false;
-
-  // Quick handling to fix loading issues when computing the original
-  // bitwidth is unable to realize that there is a mixed-precision dot
-  // (hence kWidth = 1) but wants to hoist through the type conversion.
-  if (isa<arith::ExtFOp>(op) && dotOpEnc.getKWidth() == 1)
-    return false;
-
-  // Currently, these instructions are not supported during lowering of
-  // shared -> dot_operand layout. Not all types and type conversions are
-  // supported.
-  if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(op))
-    return false;
-
-  // Don't hoist through u1 -> fp casts as they aren't supported in
-  // ElementwiseOpToLLVM::reorderValues().
-  if (isa<arith::UIToFPOp>(op)) {
-    Type opType = getElementTypeOrSelf(op->getOperand(0));
-    if (opType.isInteger(1))
-      return false;
-  }
-
-  return true;
-}
-
 // Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A
 // is in registers).
 bool canHoistDotOpEncV3(Operation *op) {
@@ -198,116 +168,6 @@ class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
   }
 };

-// Move convert-to-dot-operand "up" past elementwise ops:
-//
-// convert(elementwise(x)) #dot_operand ->
-// elementwise(convert(x, #dot_operand)).
-//
-// The goal is to put the convert right next to the originating load. If we can
-// accomplish this, then we can save a shmem round-trip:
-//
-// Before:
-//
-//  - Load from global into shmem using an async copy.
-//  - Load from shmem into a #blocked layout.
-//  - Do elementwise ops over #blocked layout.
-//  - Convert to #dot_operand (round-trip through shmem).
-//  - Do dot.
-//
-// After:
-//
-//  - Load from global into shmem using an async copy (same as before).
-//  - Load from shmem into a #dot_operand layout.
-//  - Do elementwise ops over #dot_operand layout.
-//  - Do dot.
-//
-// This can also be propagated when we have a constant, instead of a load.
-//
-// Eliminating the shmem round-trip is such a big win, we're willing to do it
-// even if this duplicates work because some of the elementwise ops have uses
-// that don't flow into the dot. On the other hand, we only want to do this if
-// we can in fact reduce shmem round-trips: For example, simply moving a convert
-// up above e.g. an `add` now means we have *two* converts. That's worse,
-// unless we can continue moving the converts upwards and eventually merge them.
-// So we try to check that this will be beneficial before making any changes.
-class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ConvertLayoutOp cvt,
-                                PatternRewriter &rewriter) const override {
-    // Only consider conversions to dot operand.
-    auto cvtTy = cast<RankedTensorType>(cvt.getType());
-    auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
-    if (!dotOpEnc)
-      return failure();
-
-    auto src = cvt.getSrc().getDefiningOp();
-    if (!src || src->getNumOperands() == 0 || src->getNumResults() != 1)
-      return failure();
-
-    auto srcTy = dyn_cast<RankedTensorType>(src->getResult(0).getType());
-    if (!srcTy)
-      return failure();
-
-    if (!all_of(src->getOperandTypes(),
-                [](Type ty) { return isa<RankedTensorType>(ty); }))
-      return failure();
-
-    if (!canHoistDotOpEncV2(src, dotOpEnc))
-      return failure();
-
-    // Check that the conversion is transitively dependent on a load or a
-    // constant, and all operations between it and the convert are layout
-    // preserving.
-    //
-    // TODO(jlebar): This is accidentally quadratic; we iterate over the whole
-    // slice but then at the end we only modify one op!
-    SetVector<Operation *> slice;
-    BackwardSliceOptions opt;
-    opt.omitBlockArguments = true;
-    getBackwardSlice(cvt.getOperation(), &slice, opt);
-
-    // TODO(jlebar): This is too conservative when there are multiple loads in
-    // the chain. If one of the loads has a non-layout-preserving op and the
-    // other does not, then we may or may not accept the chain, depending on
-    // which load gets hit first by getBackwardSlice. For example:
-    //   cvt(broadcast(load(x)) + load(y)) // accepted & load(y) will benefit.
-    //   cvt(load(y) + broadcast(load(x))) // rejected & load(y) will not benefit.
-    bool foundInitializer = false;
-    // Reverse the slice so that we start directly above the convert and check
-    // that every op allows hoisting until we find a load or a constant.
-    for (Operation *currOp : llvm::reverse(slice)) {
-      if (isa<LoadOp>(currOp) || isa<arith::ConstantOp>(currOp)) {
-        foundInitializer = true;
-        break;
-      }
-      if (!canHoistDotOpEncV2(currOp, dotOpEnc))
-        return failure();
-    }
-    if (!foundInitializer)
-      return failure();
-
-    SmallVector<ConvertLayoutOp> newOperands;
-    for (auto operand : src->getOperands()) {
-      // We checked earlier that all operands are ranked tensors.
-      auto operandTy = cast<RankedTensorType>(operand.getType());
-      Type newCvtTy = RankedTensorType::get(
-          srcTy.getShape(), operandTy.getElementType(), cvtTy.getEncoding());
-      newOperands.push_back(
-          rewriter.create<ConvertLayoutOp>(cvt.getLoc(), newCvtTy, operand));
-    }
-    auto newRet = rewriter.clone(*src);
-    for (int i = 0; i < newOperands.size(); i++)
-      newRet->setOperand(i, newOperands[i]);
-    newRet->getResult(0).setType(RankedTensorType::get(
-        srcTy.getShape(), srcTy.getElementType(), cvtTy.getEncoding()));
-
-    rewriter.replaceOp(cvt, newRet->getResults());
-    return success();
-  }
-};
-
 // Rewrite
 //
 //   dot(alloc(trans() #shared1) ->
@@ -702,8 +562,6 @@ class TritonGPUOptimizeDotOperandsPass
     mlir::RewritePatternSet patterns(context);
     patterns.add<MMAV3HoistLayoutConversion>(context);
     patterns.add<SwizzleShmemConvert>(context);
-    if (this->hoistLayoutConversion.getValue())
-      patterns.add<HoistLayoutConversion>(context);
    patterns.add<FuseTransMMAV3Plus>(context);
    patterns.add<MMAV3UseRegOperand>(context);
    patterns.add<InjectTMemCopy>(context);
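Note: with HoistLayoutConversion removed, hoisting converts toward their defining loads is, per the commit title, left to the backward rematerialization logic (backwardsRemat), presumably the backwardRematerialization step of the remove-layout-conversions pass; the hoistLayoutConversion option is no longer consulted in this pass body.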

lib/Dialect/TritonGPU/Transforms/Prefetch.cpp

Lines changed: 12 additions & 1 deletion
@@ -31,6 +31,11 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "tritongpu-prefetch"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

 namespace mlir {
 namespace triton {
@@ -186,17 +191,23 @@ LogicalResult Prefetcher::initialize() {
   bool foundConvertFromShared = false;
   SmallVector<Value> rets;
   rets.push_back(op->getResult(0));
+  LDBG("Prefetch src: " << *op);
   while (op) {
     if (op->getNumOperands() != 1)
       break;
     if (!op->getResult(0).hasOneUse())
       break;
     rets.push_back(op->getOperand(0));
     if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
-      foundConvertFromShared = true;
+      // NYI for other encodings, for example if we have transpose
+      // in the chain
+      if (isa<DotOperandEncodingAttr>(cvt.getType().getEncoding()))
+        foundConvertFromShared = true;
       break;
     }
     op = op->getOperand(0).getDefiningOp();
+    if (op)
+      LDBG("op: " << *op);
   }
   std::reverse(rets.begin(), rets.end());
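With the new LDBG tracing, an assertions-enabled build can print the prefetch chain walk through LLVM's -debug-only=tritongpu-prefetch flag (for example when driving the pass via triton-opt); the exact invocation depends on the local setup.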
