@@ -24,36 +24,6 @@ namespace {
 // Roughly, whether op is elementwise and thus threads don't need
 // to exchange elements. But some ops are not currently supported even though
 // they meet that criterion.
-bool canHoistDotOpEncV2(Operation *op, DotOperandEncodingAttr &dotOpEnc) {
-  // Only consider custom conversions or arith ops.
-  // TODO(jlebar): Is this too restrictive?
-  if (!isa<FpToFpOp, BitcastOp>(op) && !isPureUnaryInlineAsm(op) &&
-      !isa<arith::ArithDialect>(op->getDialect()))
-    return false;
-
-  // Quick fix for loading issues: when the original-bitwidth computation
-  // fails to realize that there is a mixed-precision dot (hence kWidth = 1),
-  // don't hoist through the type conversion.
-  if (isa<arith::ExtFOp>(op) && dotOpEnc.getKWidth() == 1)
-    return false;
-
-  // Currently, these instructions are not supported during lowering of
-  // shared -> dot_operand layout. Not all types and type conversions are
-  // supported.
-  if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(op))
-    return false;
-
-  // Don't hoist through u1 -> fp casts as they aren't supported in
-  // ElementwiseOpToLLVM::reorderValues().
-  if (isa<arith::UIToFPOp>(op)) {
-    Type opType = getElementTypeOrSelf(op->getOperand(0));
-    if (opType.isInteger(1))
-      return false;
-  }
-
-  return true;
-}
-
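An illustrative sketch of how the removed predicate classifies typical ops, in TTGIR-style IR (the #blocked encoding, tensor shapes, and SSA names are invented for illustration, not taken from the patch):

    // Accepted: pure elementwise arith op; each thread keeps its own elements.
    %y = arith.mulf %x, %x : tensor<128x64xf32, #blocked>

    // Rejected: trunc ops aren't supported when lowering shared -> dot_operand.
    %t = arith.truncf %y : tensor<128x64xf32, #blocked> to tensor<128x64xf16, #blocked>

    // Rejected: i1 -> float casts break ElementwiseOpToLLVM::reorderValues().
    %u = arith.uitofp %c : tensor<128x64xi1, #blocked> to tensor<128x64xf32, #blocked>
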
 // Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A
 // is in registers).
 bool canHoistDotOpEncV3(Operation *op) {
@@ -198,116 +168,6 @@ class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
   }
 };
 
-// Move convert-to-dot-operand "up" past elementwise ops:
-//
-//   convert(elementwise(x)) #dot_operand ->
-//   elementwise(convert(x, #dot_operand)).
-//
-// The goal is to put the convert right next to the originating load. If we can
-// accomplish this, then we can save a shmem round-trip:
-//
-// Before:
-//
-//  - Load from global into shmem using an async copy.
-//  - Load from shmem into a #blocked layout.
-//  - Do elementwise ops over #blocked layout.
-//  - Convert to #dot_operand (round-trip through shmem).
-//  - Do dot.
-//
-// After:
-//
-//  - Load from global into shmem using an async copy (same as before).
-//  - Load from shmem into a #dot_operand layout.
-//  - Do elementwise ops over #dot_operand layout.
-//  - Do dot.
-//
-// This can also be propagated when the chain starts at a constant instead of
-// a load.
-//
-// Eliminating the shmem round-trip is such a big win that we're willing to do
-// it even if it duplicates work, because some of the elementwise ops may have
-// uses that don't flow into the dot. On the other hand, we only want to do
-// this if we can in fact reduce shmem round-trips: simply moving a convert up
-// above e.g. an `add` means we now have *two* converts. That's worse, unless
-// we can continue moving the converts upwards and eventually merge them. So
-// we try to check that the rewrite will be beneficial before making any
-// changes.
-class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ConvertLayoutOp cvt,
-                                PatternRewriter &rewriter) const override {
-    // Only consider conversions to dot operand.
-    auto cvtTy = cast<RankedTensorType>(cvt.getType());
-    auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
-    if (!dotOpEnc)
-      return failure();
-
-    auto src = cvt.getSrc().getDefiningOp();
-    if (!src || src->getNumOperands() == 0 || src->getNumResults() != 1)
-      return failure();
-
-    auto srcTy = dyn_cast<RankedTensorType>(src->getResult(0).getType());
-    if (!srcTy)
-      return failure();
-
-    if (!all_of(src->getOperandTypes(),
-                [](Type ty) { return isa<RankedTensorType>(ty); }))
-      return failure();
-
-    if (!canHoistDotOpEncV2(src, dotOpEnc))
-      return failure();
-
-    // Check that the conversion is transitively dependent on a load or a
-    // constant, and that all operations between it and the convert are
-    // layout preserving.
-    //
-    // TODO(jlebar): This is accidentally quadratic; we iterate over the whole
-    // slice, but then at the end we only modify one op!
-    SetVector<Operation *> slice;
-    BackwardSliceOptions opt;
-    opt.omitBlockArguments = true;
-    getBackwardSlice(cvt.getOperation(), &slice, opt);
-
-    // TODO(jlebar): This is too conservative when there are multiple loads in
-    // the chain. If one of the loads has a non-layout-preserving op and the
-    // other does not, then we may or may not accept the chain, depending on
-    // which load gets hit first by getBackwardSlice. For example:
-    //   cvt(broadcast(load(x)) + load(y)) // accepted & load(y) will benefit.
-    //   cvt(load(y) + broadcast(load(x))) // rejected & load(y) will not benefit.
-    bool foundInitializer = false;
-    // Walk the slice in reverse so that we start directly above the convert
-    // and check that every op allows hoisting, until we find a load or a
-    // constant.
-    for (Operation *currOp : llvm::reverse(slice)) {
-      if (isa<LoadOp>(currOp) || isa<arith::ConstantOp>(currOp)) {
-        foundInitializer = true;
-        break;
-      }
-      if (!canHoistDotOpEncV2(currOp, dotOpEnc))
-        return failure();
-    }
-    if (!foundInitializer)
-      return failure();
-
-    SmallVector<ConvertLayoutOp> newOperands;
-    for (auto operand : src->getOperands()) {
-      // We checked earlier that all operands are ranked tensors.
-      auto operandTy = cast<RankedTensorType>(operand.getType());
-      Type newCvtTy = RankedTensorType::get(
-          srcTy.getShape(), operandTy.getElementType(), cvtTy.getEncoding());
-      newOperands.push_back(
-          rewriter.create<ConvertLayoutOp>(cvt.getLoc(), newCvtTy, operand));
-    }
-    auto newRet = rewriter.clone(*src);
-    for (int i = 0; i < newOperands.size(); i++)
-      newRet->setOperand(i, newOperands[i]);
-    newRet->getResult(0).setType(RankedTensorType::get(
-        srcTy.getShape(), srcTy.getElementType(), cvtTy.getEncoding()));
-
-    rewriter.replaceOp(cvt, newRet->getResults());
-    return success();
-  }
-};
-
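To make the removed rewrite concrete, a minimal before/after sketch in TTGIR-style IR (shapes, pointer types, and the #blocked/#dot_op encoding names are invented; assume kWidth != 1 so the arith.extf check in canHoistDotOpEncV2 does not bail out):

    // Before: the convert sits between the elementwise op and the dot, so
    // the operand round-trips through shmem in #blocked form.
    %a = tt.load %ptr : tensor<128x64x!tt.ptr<f16>, #blocked>
    %b = arith.extf %a : tensor<128x64xf16, #blocked> to tensor<128x64xf32, #blocked>
    %c = ttg.convert_layout %b : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #dot_op>

    // After: the convert sits directly on the load, and the elementwise op
    // runs in the #dot_op layout, eliminating the extra shmem round-trip.
    %a  = tt.load %ptr : tensor<128x64x!tt.ptr<f16>, #blocked>
    %a2 = ttg.convert_layout %a : tensor<128x64xf16, #blocked> -> tensor<128x64xf16, #dot_op>
    %c  = arith.extf %a2 : tensor<128x64xf16, #dot_op> to tensor<128x64xf32, #dot_op>
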
 // Rewrite
 //
 //   dot(alloc(trans() #shared1) ->
@@ -702,8 +562,6 @@ class TritonGPUOptimizeDotOperandsPass
     mlir::RewritePatternSet patterns(context);
     patterns.add<MMAV3HoistLayoutConversion>(context);
     patterns.add<SwizzleShmemConvert>(context);
-    if (this->hoistLayoutConversion.getValue())
-      patterns.add<HoistLayoutConversion>(context);
     patterns.add<FuseTransMMAV3Plus>(context);
     patterns.add<MMAV3UseRegOperand>(context);
     patterns.add<InjectTMemCopy>(context);