@@ -176,6 +176,7 @@ class LayoutRematerialization {
176
176
SetVector<Operation *> opToDelete;
177
177
FuncOp funcOp;
178
178
DominanceInfo domInfo;
179
+ PostDominanceInfo postDomInfo;
179
180
};
180
181
181
182
void LayoutRematerialization::addRematValue (Value old, Attribute encoding,
@@ -525,8 +526,7 @@ Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
525
526
return rewrittenValue;
526
527
OpBuilder rewriter (value.getContext ());
527
528
rewriter.setInsertionPointAfterValue (rewrittenValue);
528
- auto tmpType = RankedTensorType::get (tensorType.getShape (),
529
- tensorType.getElementType (), encoding);
529
+ auto tmpType = tensorType.cloneWithEncoding (encoding);
530
530
Value converted = rewriter.create <ConvertLayoutOp>(value.getLoc (), tmpType,
531
531
rewrittenValue);
532
532
// TODO: we could cache the conversion.
@@ -567,8 +567,7 @@ Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter,
567
567
auto origType = dyn_cast<RankedTensorType>(op->getResult (i).getType ());
568
568
if (!origType)
569
569
continue ;
570
- auto newType = RankedTensorType::get (origType.getShape (),
571
- origType.getElementType (), encoding);
570
+ auto newType = origType.cloneWithEncoding (encoding);
572
571
newOp->getResult (i).setType (newType);
573
572
}
574
573
return newOp;
@@ -631,9 +630,7 @@ Operation *LayoutPropagation::rewriteWhileOp(scf::WhileOp whileOp) {
631
630
continue ;
632
631
}
633
632
auto origType = dyn_cast<RankedTensorType>(ret.getType ());
634
- auto newType =
635
- RankedTensorType::get (origType.getShape (), origType.getElementType (),
636
- it->second .encodings [0 ]);
633
+ auto newType = origType.cloneWithEncoding (it->second .encodings [0 ]);
637
634
returnTypes.push_back (newType);
638
635
}
639
636
@@ -683,8 +680,7 @@ Operation *LayoutPropagation::rewriteIfOp(scf::IfOp ifOp) {
683
680
continue ;
684
681
auto origType = cast<RankedTensorType>(ifOp->getResult (i).getType ());
685
682
Attribute encoding = *(it->second .encodings .begin ());
686
- newResultTypes[i] = RankedTensorType::get (
687
- origType.getShape (), origType.getElementType (), encoding);
683
+ newResultTypes[i] = origType.cloneWithEncoding (encoding);
688
684
}
689
685
auto newIfOp = rewriter.create <scf::IfOp>(ifOp.getLoc (), newResultTypes,
690
686
ifOp.getCondition (), true , true );
@@ -940,17 +936,15 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
940
936
srcEncoding = *(it->second .encodings .begin ());
941
937
Value src = getValueAs (convertOp.getSrc (), srcEncoding);
942
938
auto tensorType = cast<RankedTensorType>(op->getResult (0 ).getType ());
943
- auto newType = RankedTensorType::get (tensorType.getShape (),
944
- tensorType.getElementType (), encoding);
939
+ auto newType = tensorType.cloneWithEncoding (encoding);
945
940
auto cvt = rewriter.create <ConvertLayoutOp>(op->getLoc (), newType, src);
946
941
map (op->getResult (0 ), cvt.getResult ());
947
942
return cvt.getOperation ();
948
943
}
949
944
if (canFoldIntoConversion (op, encoding)) {
950
945
Operation *newOp = rewriter.clone (*op);
951
946
auto tensorType = cast<RankedTensorType>(op->getResult (0 ).getType ());
952
- auto newType = RankedTensorType::get (tensorType.getShape (),
953
- tensorType.getElementType (), encoding);
947
+ auto newType = tensorType.cloneWithEncoding (encoding);
954
948
auto cvt = rewriter.create <ConvertLayoutOp>(op->getLoc (), newType,
955
949
newOp->getResult (0 ));
956
950
map (op->getResult (0 ), cvt.getResult ());
@@ -1111,14 +1105,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
1111
1105
1112
1106
Type resType = res.getType ();
1113
1107
if (auto oldType = dyn_cast<RankedTensorType>(resType)) {
1114
- auto newType = RankedTensorType::get (
1115
- oldType.getShape (), oldType.getElementType (), it->second );
1108
+ Type newType = oldType.cloneWithEncoding (it->second );
1116
1109
newTypes.push_back (newType);
1117
1110
} else if (auto ptrType = dyn_cast<PointerType>(resType)) {
1118
1111
auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType ());
1119
- auto newType = triton::PointerType::get (
1120
- RankedTensorType::get (tensorType.getShape (),
1121
- tensorType.getElementType (), it->second ),
1112
+ Type newType = triton::PointerType::get (
1113
+ tensorType.cloneWithEncoding (it->second ),
1122
1114
ptrType.getAddressSpace ());
1123
1115
newTypes.push_back (newType);
1124
1116
}
@@ -1158,9 +1150,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
1158
1150
if (isa<arith::ConstantOp>(op)) {
1159
1151
Operation *newOp = builder.clone (*op);
1160
1152
auto tensorType = cast<RankedTensorType>(op->getResult (0 ).getType ());
1161
- auto newType = RankedTensorType::get (tensorType.getShape (),
1162
- tensorType.getElementType (),
1163
- layout[op->getResult (0 )]);
1153
+ auto newType = tensorType.cloneWithEncoding (layout[op->getResult (0 )]);
1164
1154
auto cvt = builder.create <ConvertLayoutOp>(op->getLoc (), newType,
1165
1155
newOp->getResult (0 ));
1166
1156
mapping.map (op->getResult (0 ), cvt.getResult ());
@@ -1178,14 +1168,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
1178
1168
if (isTensorPointerType (oldType)) {
1179
1169
auto ptrType = cast<PointerType>(oldType);
1180
1170
auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType ());
1181
- newType = triton::PointerType::get (
1182
- RankedTensorType::get (tensorType.getShape (),
1183
- tensorType.getElementType (), it->second ),
1184
- ptrType.getAddressSpace ());
1171
+ newType =
1172
+ triton::PointerType::get (tensorType.cloneWithEncoding (it->second ),
1173
+ ptrType.getAddressSpace ());
1185
1174
} else {
1186
- newType = RankedTensorType::get (
1187
- cast<RankedTensorType>(old.getType ()).getShape (),
1188
- cast<RankedTensorType>(old.getType ()).getElementType (), it->second );
1175
+ newType =
1176
+ cast<RankedTensorType>(old.getType ()).cloneWithEncoding (it->second );
1189
1177
}
1190
1178
newV.setType (newType);
1191
1179
addRematValue (old, it->second , newV);
@@ -1522,13 +1510,40 @@ void LayoutRematerialization::hoistConvertDotOperand() {
1522
1510
void LayoutRematerialization::hoistConvertDotOperand (
1523
1511
ConvertLayoutOp convertOp) {
1524
1512
auto targetType = convertOp.getType ();
1525
- // The pass is targeted to Nvidia mma/wgmma dot operands
1513
+ // The pass is targeted to MMA dot operands
1514
+
1515
+ auto canBePipelined = [&](ConvertLayoutOp convertOp) {
1516
+ // FIXME: Check that the parent is a for loop
1517
+ auto parent = convertOp->getParentOp ();
1518
+ if (!parent)
1519
+ return false ;
1520
+
1521
+ // Find all the dot-like ops in the for loop that have a dot operand
1522
+ // encoding on the lhs and check if any of them post-dominates the load +
1523
+ // cvt
1524
+ SmallVector<Operation *> dotLikeOps;
1525
+ parent->walk ([&](Operation *op) {
1526
+ if (!isa<mlir::triton::DotOpInterface>(op))
1527
+ return ;
1528
+ auto opType = dyn_cast<RankedTensorType>(op->getOperand (0 ).getType ());
1529
+ if (!opType)
1530
+ return ;
1531
+ auto dotEnc = dyn_cast<DotOperandEncodingAttr>(opType.getEncoding ());
1532
+ if (!dotEnc)
1533
+ return ;
1534
+ if (isa<MmaEncodingTrait>(dotEnc.getParent ()))
1535
+ dotLikeOps.push_back (op);
1536
+ });
1537
+ if (dotLikeOps.empty ())
1538
+ return false ;
1539
+ return llvm::any_of (dotLikeOps, [&](Operation *dot) {
1540
+ return postDomInfo.postDominates (dot, convertOp);
1541
+ });
1542
+ };
1543
+
1526
1544
// We move convert #dot_operand next to their loads. This is done
1527
1545
// so that it's then easy to pipeline these loads
1528
- // TODO: Perhaps we should do this whenever convertOp is within a loop
1529
-
1530
- auto dotEnc = dyn_cast<DotOperandEncodingAttr>(targetType.getEncoding ());
1531
- if (!(dotEnc && isa<NvidiaMmaEncodingAttr>(dotEnc.getParent ())))
1546
+ if (!canBePipelined (convertOp))
1532
1547
return ;
1533
1548
1534
1549
// We hoist over any operation that can be done without data movement between
@@ -1579,8 +1594,7 @@ void LayoutRematerialization::hoistConvertDotOperand(
1579
1594
auto type = dyn_cast<RankedTensorType>(loadOp->getResult (0 ).getType ());
1580
1595
if (!type)
1581
1596
continue ;
1582
- auto newType = RankedTensorType::get (type.getShape (), type.getElementType (),
1583
- layout[loadOp->getResult (0 )]);
1597
+ auto newType = type.cloneWithEncoding (layout[loadOp->getResult (0 )]);
1584
1598
auto newConvertOp = builder.create <ConvertLayoutOp>(
1585
1599
convertOp.getLoc (), newType, loadOp->getResult (0 ));
1586
1600
mapping.map (loadOp->getResult (0 ), newConvertOp.getResult ());
@@ -1682,18 +1696,16 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
1682
1696
OpBuilder builder (extOrBroadcastOp);
1683
1697
auto tensorType =
1684
1698
cast<RankedTensorType>(extOrBroadcastOp->getOperand (0 ).getType ());
1685
- auto newType = RankedTensorType::get (
1686
- tensorType.getShape (), tensorType.getElementType (), srcEncoding);
1699
+ auto newType = tensorType.cloneWithEncoding (srcEncoding);
1687
1700
auto newConvertOp = builder.create <ConvertLayoutOp>(
1688
1701
convertOp.getLoc (), newType, extOrBroadcastOp->getOperand (0 ));
1689
1702
Operation *newExtOrBroadcast = builder.clone (*extOrBroadcastOp);
1690
1703
newExtOrBroadcast->setOperand (0 , newConvertOp.getResult ());
1691
1704
auto oldExtOrBroadcastType =
1692
1705
cast<RankedTensorType>(extOrBroadcastOp->getResult (0 ).getType ());
1693
- Type newExtOrBroadcasrType = RankedTensorType::get (
1694
- oldExtOrBroadcastType.getShape (), oldExtOrBroadcastType.getElementType (),
1695
- dstEncoding);
1696
- newExtOrBroadcast->getResult (0 ).setType (newExtOrBroadcasrType);
1706
+ Type newExtOrBroadcastType =
1707
+ oldExtOrBroadcastType.cloneWithEncoding (dstEncoding);
1708
+ newExtOrBroadcast->getResult (0 ).setType (newExtOrBroadcastType);
1697
1709
IRMapping mapping;
1698
1710
mapping.map (extOrBroadcastOp->getResult (0 ), newExtOrBroadcast->getResult (0 ));
1699
1711
slice.remove (extOrBroadcastOp->getResult (0 ));
@@ -1798,8 +1810,7 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
1798
1810
IRMapping mapping;
1799
1811
auto hoistRemat = [&](OpBuilder &b, Value v, Attribute encoding) {
1800
1812
auto tensorType = cast<RankedTensorType>(v.getType ());
1801
- auto newType = RankedTensorType::get (tensorType.getShape (),
1802
- tensorType.getElementType (), encoding);
1813
+ auto newType = tensorType.cloneWithEncoding (encoding);
1803
1814
Value newCvt = b.create <ConvertLayoutOp>(convertOp.getLoc (), newType, v);
1804
1815
1805
1816
mapping.map (v, newCvt);
0 commit comments