@@ -131,8 +131,6 @@ class LayoutRematerialization {
131131 void backwardRematerialization (ConvertLayoutOp convertOp);
132132 void hoistConvertOnTopOfExtOrBroadcast ();
133133 void hoistConvertOnTopOfExtOrBroadcast (ConvertLayoutOp convertOp);
134- void hoistConvertIntoConditionals ();
135- void hoistConvertIntoConditionals (ConvertLayoutOp convertOp);
136134 void rewriteSlice (SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
137135 ConvertLayoutOp convertOp, IRMapping &mapping);
138136 void rewriteSlice (SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
@@ -1022,66 +1020,13 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
10221020 }
10231021}
10241022
1025- bool shouldPropagateConversion (ConvertLayoutOp convertOp) {
1026- RankedTensorType targetType = convertOp.getType ();
1027- auto dotEnc = dyn_cast<DotOperandEncodingAttr>(targetType.getEncoding ());
1028- // If the target encoding is not DotOperandEncodingAttr, allow propagation.
1029- if (!dotEnc) {
1030- return true ;
1031- }
1032- // Skip conversions to DotOperandEncodingAttr when the operand index is 0.
1033- // This heuristic is applied to prevent moving the blocked->dot conversion of
1034- // the Q tensor (a loop invariant in Flash Attention) outside the loop. Doing
1035- // so can increase register pressure and cause spilling in some cases.
1036- if (dotEnc.getOpIdx () == 0 ) {
1037- return false ;
1038- }
1039- // Skip conversions to DotOperandEncodingAttr when the operand index is 1 if
1040- // it's not intentionally placed above a load as we have to be a bit more
1041- // careful with the heuristics for both correctness and performance.
1042- // TODO: Fix this logic to avoid propagating conversions backward unless
1043- // it reduces the total number of conversions.
1044- assert (dotEnc.getOpIdx () == 1 );
1045- SetVector<Operation *> slice;
1046- BackwardSliceOptions opt;
1047- opt.omitBlockArguments = true ;
1048- opt.filter = [&](Operation *op) {
1049- return op->getParentRegion () == convertOp->getParentRegion ();
1050- };
1051- getBackwardSlice (convertOp.getOperation (), &slice, opt);
1052-
1053- for (Operation *currOp : slice) {
1054- if (isa<LoadOp>(currOp)) {
1055- return false ;
1056- }
1057- }
1058- // Allow propagation if no LoadOp is found.
1059- return true ;
1060- }
1061-
1062- void LayoutRematerialization::hoistConvertIntoConditionals () {
1063- // Go through each ConvertLayoutOp.
1064- SmallVector<ConvertLayoutOp> convertOps;
1065- funcOp.walk (
1066- [&](ConvertLayoutOp convertOp) { convertOps.push_back (convertOp); });
1067- for (ConvertLayoutOp convertOp : convertOps) {
1068- hoistConvertIntoConditionals (convertOp);
1069- if (!opToDelete.contains (convertOp)) {
1070- // If the conversion didn't get removed, consider it for reuse in future
1071- // backward slices.
1072- addRematValue (convertOp.getSrc (), convertOp.getType ().getEncoding (),
1073- convertOp.getResult ());
1074- }
1075- }
1076- }
1077-
10781023void LayoutRematerialization::backwardRematerialization (
10791024 ConvertLayoutOp convertOp) {
1025+ // we don't handle conversions to DotOperandEncodingAttr
1026+ // this is a heuristic to accommodate fused attention
10801027 RankedTensorType targetType = convertOp.getType ();
1081- if (! shouldPropagateConversion (convertOp)) {
1028+ if (isa<DotOperandEncodingAttr>(targetType. getEncoding ()))
10821029 return ;
1083- }
1084-
10851030 Value oldV = convertOp.getSrc ();
10861031 LDBG (" check backward remat with source " << oldV << " encoding "
10871032 << targetType.getEncoding ());
@@ -1120,10 +1065,11 @@ void LayoutRematerialization::backwardRematerialization(
11201065// of the convert.
11211066void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast (
11221067 ConvertLayoutOp convertOp) {
1068+ // we don't handle conversions to DotOperandEncodingAttr
1069+ // this is a heuristic to accommodate fused attention
11231070 RankedTensorType targetType = convertOp.getType ();
1124- if (! shouldPropagateConversion (convertOp)) {
1071+ if (isa<DotOperandEncodingAttr>(targetType. getEncoding ()))
11251072 return ;
1126- }
11271073
11281074 auto isExtOrBroadcastOp = [](Operation *op) {
11291075 if (isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp, BroadcastOp,
@@ -1205,100 +1151,6 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
12051151 rewriteSlice (slice, layout, convertOp, mapping);
12061152}
12071153
1208- void LayoutRematerialization::hoistConvertIntoConditionals (
1209- ConvertLayoutOp convertOp) {
1210- // Take the backward slice of tensor dependencies, stopping at conditionals.
1211- SetVector<Value> slice;
1212- DenseMap<Value, Attribute> layout;
1213- auto isIfOp = [](Operation *op) { return isa<scf::IfOp>(op); };
1214- if (failed (getRematerializableSlice (convertOp.getSrcMutable (),
1215- convertOp.getType ().getEncoding (), slice,
1216- layout, isIfOp)))
1217- return ;
1218-
1219- // Find conditional edges above which the conversion can be hoisted.
1220- SmallVector<std::pair<Value, OpOperand *>> hoistAbove;
1221- unsigned sliceSize = slice.size ();
1222- // The routine will recurse through backward slices, e.g. to handle loops and
1223- // conditional chains. Thus, we re-query the size of `slice`.
1224- for (unsigned i = 0 ; i < slice.size (); i++) {
1225- Value v = slice[i];
1226- auto ifOp = v.getDefiningOp <scf::IfOp>();
1227- if (!ifOp)
1228- continue ;
1229-
1230- Attribute rootLayout = layout.at (v);
1231- unsigned resIdx = cast<OpResult>(v).getResultNumber ();
1232-
1233- // Take the backward slice along each branch.
1234- auto thenYield =
1235- cast<scf::YieldOp>(ifOp.getThenRegion ().front ().getTerminator ());
1236- auto elseYield =
1237- cast<scf::YieldOp>(ifOp.getElseRegion ().front ().getTerminator ());
1238-
1239- OpOperand &thenRes = thenYield.getResultsMutable ()[resIdx];
1240- OpOperand &elseRes = elseYield.getResultsMutable ()[resIdx];
1241-
1242- SetVector<Value> thenSlice, elseSlice;
1243- DenseMap<Value, Attribute> thenLayout, elseLayout;
1244-
1245- LogicalResult thenResult = getRematerializableSlice (
1246- thenRes, rootLayout, thenSlice, thenLayout, isIfOp);
1247- LogicalResult elseResult = getRematerializableSlice (
1248- elseRes, rootLayout, elseSlice, elseLayout, isIfOp);
1249-
1250- // If propagation across both edges of this conditional succeeded, then we
1251- // don't need to hoist across it.
1252- if (succeeded (thenResult) && succeeded (elseResult)) {
1253- slice.insert (thenSlice.begin (), thenSlice.end ());
1254- slice.insert (elseSlice.begin (), elseSlice.end ());
1255- layout.insert (thenLayout.begin (), thenLayout.end ());
1256- layout.insert (elseLayout.begin (), elseLayout.end ());
1257- continue ;
1258- }
1259-
1260- // If propagation across both edges failed, then there is nothing to do
1261- // for this one.
1262- if (failed (thenResult) && failed (elseResult))
1263- continue ;
1264-
1265- // The layout conversion can be rematerialized along one edge but not the
1266- // other. We can hoist the conversion into the other branch.
1267- if (succeeded (elseResult)) {
1268- std::swap (thenSlice, elseSlice);
1269- std::swap (thenLayout, elseLayout);
1270- hoistAbove.push_back ({v, &thenRes});
1271- } else {
1272- hoistAbove.push_back ({v, &elseRes});
1273- }
1274- slice.insert (thenSlice.begin (), thenSlice.end ());
1275- layout.insert (thenLayout.begin (), thenLayout.end ());
1276- }
1277-
1278- // It's hard to know if duplicating the conversion into separate branches is
1279- // profitable without more analysis. For now, hoist at most one.
1280- if (hoistAbove.size () != 1 )
1281- return ;
1282-
1283- IRMapping mapping;
1284- for (auto [result, edge] : hoistAbove) {
1285- // Hoist the convert into the conditional and rewrite the slice.
1286- OpBuilder b (edge->getOwner ());
1287- Value v = edge->get ();
1288- Attribute encoding = layout.at (result);
1289-
1290- auto tensorType = cast<RankedTensorType>(v.getType ());
1291- auto newType = RankedTensorType::get (tensorType.getShape (),
1292- tensorType.getElementType (), encoding);
1293-
1294- Value newCvt = b.create <ConvertLayoutOp>(convertOp.getLoc (), newType, v);
1295-
1296- mapping.map (v, newCvt);
1297- slice.remove (v);
1298- }
1299- rewriteSlice (slice, layout, convertOp, mapping);
1300- }
1301-
13021154void backwardRematerialization (ModuleOp module ) {
13031155 module .walk ([](FuncOp funcOp) {
13041156 LayoutRematerialization layoutRemat (funcOp);
@@ -1313,10 +1165,6 @@ void hoistConvert(ModuleOp module) {
13131165 LayoutRematerialization layoutRemat (funcOp);
13141166 layoutRemat.hoistConvertOnTopOfExtOrBroadcast ();
13151167 layoutRemat.cleanup ();
1316-
1317- layoutRemat = LayoutRematerialization (funcOp);
1318- layoutRemat.hoistConvertIntoConditionals ();
1319- layoutRemat.cleanup ();
13201168 });
13211169}
13221170} // namespace
0 commit comments