@@ -176,6 +176,7 @@ class LayoutRematerialization {
176176 SetVector<Operation *> opToDelete;
177177 FuncOp funcOp;
178178 DominanceInfo domInfo;
179+ PostDominanceInfo postDomInfo;
179180};
180181
181182void LayoutRematerialization::addRematValue (Value old, Attribute encoding,
@@ -525,8 +526,7 @@ Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
525526 return rewrittenValue;
526527 OpBuilder rewriter (value.getContext ());
527528 rewriter.setInsertionPointAfterValue (rewrittenValue);
528- auto tmpType = RankedTensorType::get (tensorType.getShape (),
529- tensorType.getElementType (), encoding);
529+ auto tmpType = tensorType.cloneWithEncoding (encoding);
530530 Value converted = rewriter.create <ConvertLayoutOp>(value.getLoc (), tmpType,
531531 rewrittenValue);
532532 // TODO: we could cache the conversion.
@@ -567,8 +567,7 @@ Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter,
567567 auto origType = dyn_cast<RankedTensorType>(op->getResult (i).getType ());
568568 if (!origType)
569569 continue ;
570- auto newType = RankedTensorType::get (origType.getShape (),
571- origType.getElementType (), encoding);
570+ auto newType = origType.cloneWithEncoding (encoding);
572571 newOp->getResult (i).setType (newType);
573572 }
574573 return newOp;
@@ -631,9 +630,7 @@ Operation *LayoutPropagation::rewriteWhileOp(scf::WhileOp whileOp) {
631630 continue ;
632631 }
633632 auto origType = dyn_cast<RankedTensorType>(ret.getType ());
634- auto newType =
635- RankedTensorType::get (origType.getShape (), origType.getElementType (),
636- it->second .encodings [0 ]);
633+ auto newType = origType.cloneWithEncoding (it->second .encodings [0 ]);
637634 returnTypes.push_back (newType);
638635 }
639636
@@ -683,8 +680,7 @@ Operation *LayoutPropagation::rewriteIfOp(scf::IfOp ifOp) {
683680 continue ;
684681 auto origType = cast<RankedTensorType>(ifOp->getResult (i).getType ());
685682 Attribute encoding = *(it->second .encodings .begin ());
686- newResultTypes[i] = RankedTensorType::get (
687- origType.getShape (), origType.getElementType (), encoding);
683+ newResultTypes[i] = origType.cloneWithEncoding (encoding);
688684 }
689685 auto newIfOp = rewriter.create <scf::IfOp>(ifOp.getLoc (), newResultTypes,
690686 ifOp.getCondition (), true , true );
@@ -940,17 +936,15 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
940936 srcEncoding = *(it->second .encodings .begin ());
941937 Value src = getValueAs (convertOp.getSrc (), srcEncoding);
942938 auto tensorType = cast<RankedTensorType>(op->getResult (0 ).getType ());
943- auto newType = RankedTensorType::get (tensorType.getShape (),
944- tensorType.getElementType (), encoding);
939+ auto newType = tensorType.cloneWithEncoding (encoding);
945940 auto cvt = rewriter.create <ConvertLayoutOp>(op->getLoc (), newType, src);
946941 map (op->getResult (0 ), cvt.getResult ());
947942 return cvt.getOperation ();
948943 }
949944 if (canFoldIntoConversion (op, encoding)) {
950945 Operation *newOp = rewriter.clone (*op);
951946 auto tensorType = cast<RankedTensorType>(op->getResult (0 ).getType ());
952- auto newType = RankedTensorType::get (tensorType.getShape (),
953- tensorType.getElementType (), encoding);
947+ auto newType = tensorType.cloneWithEncoding (encoding);
954948 auto cvt = rewriter.create <ConvertLayoutOp>(op->getLoc (), newType,
955949 newOp->getResult (0 ));
956950 map (op->getResult (0 ), cvt.getResult ());
@@ -1111,14 +1105,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11111105
11121106 Type resType = res.getType ();
11131107 if (auto oldType = dyn_cast<RankedTensorType>(resType)) {
1114- auto newType = RankedTensorType::get (
1115- oldType.getShape (), oldType.getElementType (), it->second );
1108+ Type newType = oldType.cloneWithEncoding (it->second );
11161109 newTypes.push_back (newType);
11171110 } else if (auto ptrType = dyn_cast<PointerType>(resType)) {
11181111 auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType ());
1119- auto newType = triton::PointerType::get (
1120- RankedTensorType::get (tensorType.getShape (),
1121- tensorType.getElementType (), it->second ),
1112+ Type newType = triton::PointerType::get (
1113+ tensorType.cloneWithEncoding (it->second ),
11221114 ptrType.getAddressSpace ());
11231115 newTypes.push_back (newType);
11241116 }
@@ -1158,9 +1150,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11581150 if (isa<arith::ConstantOp>(op)) {
11591151 Operation *newOp = builder.clone (*op);
11601152 auto tensorType = cast<RankedTensorType>(op->getResult (0 ).getType ());
1161- auto newType = RankedTensorType::get (tensorType.getShape (),
1162- tensorType.getElementType (),
1163- layout[op->getResult (0 )]);
1153+ auto newType = tensorType.cloneWithEncoding (layout[op->getResult (0 )]);
11641154 auto cvt = builder.create <ConvertLayoutOp>(op->getLoc (), newType,
11651155 newOp->getResult (0 ));
11661156 mapping.map (op->getResult (0 ), cvt.getResult ());
@@ -1178,14 +1168,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11781168 if (isTensorPointerType (oldType)) {
11791169 auto ptrType = cast<PointerType>(oldType);
11801170 auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType ());
1181- newType = triton::PointerType::get (
1182- RankedTensorType::get (tensorType.getShape (),
1183- tensorType.getElementType (), it->second ),
1184- ptrType.getAddressSpace ());
1171+ newType =
1172+ triton::PointerType::get (tensorType.cloneWithEncoding (it->second ),
1173+ ptrType.getAddressSpace ());
11851174 } else {
1186- newType = RankedTensorType::get (
1187- cast<RankedTensorType>(old.getType ()).getShape (),
1188- cast<RankedTensorType>(old.getType ()).getElementType (), it->second );
1175+ newType =
1176+ cast<RankedTensorType>(old.getType ()).cloneWithEncoding (it->second );
11891177 }
11901178 newV.setType (newType);
11911179 addRematValue (old, it->second , newV);
@@ -1522,13 +1510,40 @@ void LayoutRematerialization::hoistConvertDotOperand() {
15221510void LayoutRematerialization::hoistConvertDotOperand (
15231511 ConvertLayoutOp convertOp) {
15241512 auto targetType = convertOp.getType ();
1525- // The pass is targeted to Nvidia mma/wgmma dot operands
1513+ // The pass is targeted to MMA dot operands
1514+
1515+ auto canBePipelined = [&](ConvertLayoutOp convertOp) {
1516+ // FIXME: Only checks that a parent op exists; should verify the parent is an scf::ForOp
1517+ auto parent = convertOp->getParentOp ();
1518+ if (!parent)
1519+ return false ;
1520+
1521+ // Find all the dot-like ops in the parent region that have an MMA dot
1522+ // operand encoding on the lhs, and check whether any of them
1523+ // post-dominates this convert op
1524+ SmallVector<Operation *> dotLikeOps;
1525+ parent->walk ([&](Operation *op) {
1526+ if (!isa<mlir::triton::DotOpInterface>(op))
1527+ return ;
1528+ auto opType = dyn_cast<RankedTensorType>(op->getOperand (0 ).getType ());
1529+ if (!opType)
1530+ return ;
1531+ auto dotEnc = dyn_cast<DotOperandEncodingAttr>(opType.getEncoding ());
1532+ if (!dotEnc)
1533+ return ;
1534+ if (isa<MmaEncodingTrait>(dotEnc.getParent ()))
1535+ dotLikeOps.push_back (op);
1536+ });
1537+ if (dotLikeOps.empty ())
1538+ return false ;
1539+ return llvm::any_of (dotLikeOps, [&](Operation *dot) {
1540+ return postDomInfo.postDominates (dot, convertOp);
1541+ });
1542+ };
1543+
15261544 // We move convert #dot_operand next to their loads. This is done
15271545 // so that it's then easy to pipeline these loads
1528- // TODO: Perhaps we should do this whenever convertOp is within a loop
1529-
1530- auto dotEnc = dyn_cast<DotOperandEncodingAttr>(targetType.getEncoding ());
1531- if (!(dotEnc && isa<NvidiaMmaEncodingAttr>(dotEnc.getParent ())))
1546+ if (!canBePipelined (convertOp))
15321547 return ;
15331548
15341549 // We hoist over any operation that can be done without data movement between
@@ -1579,8 +1594,7 @@ void LayoutRematerialization::hoistConvertDotOperand(
15791594 auto type = dyn_cast<RankedTensorType>(loadOp->getResult (0 ).getType ());
15801595 if (!type)
15811596 continue ;
1582- auto newType = RankedTensorType::get (type.getShape (), type.getElementType (),
1583- layout[loadOp->getResult (0 )]);
1597+ auto newType = type.cloneWithEncoding (layout[loadOp->getResult (0 )]);
15841598 auto newConvertOp = builder.create <ConvertLayoutOp>(
15851599 convertOp.getLoc (), newType, loadOp->getResult (0 ));
15861600 mapping.map (loadOp->getResult (0 ), newConvertOp.getResult ());
@@ -1682,18 +1696,16 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
16821696 OpBuilder builder (extOrBroadcastOp);
16831697 auto tensorType =
16841698 cast<RankedTensorType>(extOrBroadcastOp->getOperand (0 ).getType ());
1685- auto newType = RankedTensorType::get (
1686- tensorType.getShape (), tensorType.getElementType (), srcEncoding);
1699+ auto newType = tensorType.cloneWithEncoding (srcEncoding);
16871700 auto newConvertOp = builder.create <ConvertLayoutOp>(
16881701 convertOp.getLoc (), newType, extOrBroadcastOp->getOperand (0 ));
16891702 Operation *newExtOrBroadcast = builder.clone (*extOrBroadcastOp);
16901703 newExtOrBroadcast->setOperand (0 , newConvertOp.getResult ());
16911704 auto oldExtOrBroadcastType =
16921705 cast<RankedTensorType>(extOrBroadcastOp->getResult (0 ).getType ());
1693- Type newExtOrBroadcasrType = RankedTensorType::get (
1694- oldExtOrBroadcastType.getShape (), oldExtOrBroadcastType.getElementType (),
1695- dstEncoding);
1696- newExtOrBroadcast->getResult (0 ).setType (newExtOrBroadcasrType);
1706+ Type newExtOrBroadcastType =
1707+ oldExtOrBroadcastType.cloneWithEncoding (dstEncoding);
1708+ newExtOrBroadcast->getResult (0 ).setType (newExtOrBroadcastType);
16971709 IRMapping mapping;
16981710 mapping.map (extOrBroadcastOp->getResult (0 ), newExtOrBroadcast->getResult (0 ));
16991711 slice.remove (extOrBroadcastOp->getResult (0 ));
@@ -1798,8 +1810,7 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
17981810 IRMapping mapping;
17991811 auto hoistRemat = [&](OpBuilder &b, Value v, Attribute encoding) {
18001812 auto tensorType = cast<RankedTensorType>(v.getType ());
1801- auto newType = RankedTensorType::get (tensorType.getShape (),
1802- tensorType.getElementType (), encoding);
1813+ auto newType = tensorType.cloneWithEncoding (encoding);
18031814 Value newCvt = b.create <ConvertLayoutOp>(convertOp.getLoc (), newType, v);
18041815
18051816 mapping.map (v, newCvt);