@@ -1022,43 +1022,6 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
   }
 }
 
-bool shouldPropagateConversion(ConvertLayoutOp convertOp) {
-  RankedTensorType targetType = convertOp.getType();
-  auto dotEnc = dyn_cast<DotOperandEncodingAttr>(targetType.getEncoding());
-  // If the target encoding is not DotOperandEncodingAttr, allow propagation.
-  if (!dotEnc) {
-    return true;
-  }
-  // Skip conversions to DotOperandEncodingAttr when the operand index is 0.
-  // This heuristic is applied to prevent moving the blocked->dot conversion of
-  // the Q tensor (a loop invariant in Flash Attention) outside the loop. Doing
-  // so can increase register pressure and cause spilling in some cases.
-  if (dotEnc.getOpIdx() == 0) {
-    return false;
-  }
-  // Skip conversions to DotOperandEncodingAttr when the operand index is 1 if
-  // it's not intentionally placed above a load as we have to be a bit more
-  // careful with the heuristics for both correctness and performance.
-  // TODO: Fix this logic to avoid propagating conversions backward unless
-  // it reduces the total number of conversions.
-  assert(dotEnc.getOpIdx() == 1);
-  SetVector<Operation *> slice;
-  BackwardSliceOptions opt;
-  opt.omitBlockArguments = true;
-  opt.filter = [&](Operation *op) {
-    return op->getParentRegion() == convertOp->getParentRegion();
-  };
-  getBackwardSlice(convertOp.getOperation(), &slice, opt);
-
-  for (Operation *currOp : slice) {
-    if (isa<LoadOp>(currOp)) {
-      return false;
-    }
-  }
-  // Allow propagation if no LoadOp is found.
-  return true;
-}
-
 void LayoutRematerialization::hoistConvertIntoConditionals() {
   // Go through each ConvertLayoutOp.
   SmallVector<ConvertLayoutOp> convertOps;
@@ -1077,11 +1040,11 @@ void LayoutRematerialization::hoistConvertIntoConditionals() {
 
 void LayoutRematerialization::backwardRematerialization(
     ConvertLayoutOp convertOp) {
+  // we don't handle conversions to DotOperandEncodingAttr
+  // this is a heuristic to accommodate fused attention
   RankedTensorType targetType = convertOp.getType();
-  if (!shouldPropagateConversion(convertOp)) {
+  if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
     return;
-  }
-
   Value oldV = convertOp.getSrc();
   LDBG("check backward remat with source " << oldV << " encoding "
                                            << targetType.getEncoding());
@@ -1120,10 +1083,11 @@ void LayoutRematerialization::backwardRematerialization(
 // of the convert.
 void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
     ConvertLayoutOp convertOp) {
+  // we don't handle conversions to DotOperandEncodingAttr
+  // this is a heuristic to accommodate fused attention
   RankedTensorType targetType = convertOp.getType();
-  if (!shouldPropagateConversion(convertOp)) {
+  if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
     return;
-  }
 
   auto isExtOrBroadcastOp = [](Operation *op) {
     if (isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp, BroadcastOp,
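For context on what this commit removes: the backward-slice walk inside the deleted shouldPropagateConversion() can be read on its own. Below is a minimal standalone sketch of that check, assuming MLIR's SliceAnalysis API (BackwardSliceOptions, getBackwardSlice) and Triton's LoadOp; the helper name dependsOnLoadInSameRegion is illustrative and not part of the codebase.

#include "mlir/Analysis/SliceAnalysis.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"

// Sketch only: returns true if any op in the backward slice of `op`,
// restricted to ops in the same region, is a triton::LoadOp. The removed
// shouldPropagateConversion() used this walk to keep opIdx == 1 dot-operand
// conversions from being rematerialized backward across a load.
static bool dependsOnLoadInSameRegion(mlir::Operation *op) {
  llvm::SetVector<mlir::Operation *> slice;
  mlir::BackwardSliceOptions opt;
  opt.omitBlockArguments = true; // follow SSA operands only
  opt.filter = [&](mlir::Operation *other) {
    // Stay within the region containing `op` (e.g. the loop body).
    return other->getParentRegion() == op->getParentRegion();
  };
  mlir::getBackwardSlice(op, &slice, opt);
  return llvm::any_of(slice, [](mlir::Operation *curr) {
    return llvm::isa<mlir::triton::LoadOp>(curr);
  });
}

The commit drops this slice-based gate and reverts both rematerialization entry points to the simpler blanket check that skips any conversion whose target encoding is DotOperandEncodingAttr.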