2525#include " triton/Dialect/TritonGPU/IR/Dialect.h"
2626#include " triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
2727#include " triton/Dialect/TritonGPU/Transforms/Utility.h"
28+ #include " triton/Tools/Sys/GetEnv.hpp"
2829#include " llvm/ADT/TypeSwitch.h"
2930#include < deque>
3031
@@ -1101,6 +1102,10 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11011102 DenseMap<Value, Attribute> &layout,
11021103 ConvertLayoutOp convertOp,
11031104 IRMapping &mapping) {
1105+ std::optional<bool > enableForLoopSupport =
1106+ mlir::triton::tools::isEnvValueBool (mlir::triton::tools::getStrEnv (
1107+ " TRITON_INTEL_REMOVELAYOUTCONVERSION_SUPPORT_FOR_LOOP" ));
1108+
11041109 SetVector<Operation *> opsToRewrite;
11051110 // Keep track of yield operands that need to be duplicated.
11061111 DenseMap<Operation *, SmallVector<int >> yieldOperandsMap;
@@ -1126,12 +1131,13 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11261131 opsToRewrite.insert (ifOp.elseYield ().getOperation ());
11271132 yieldOperandsMap[ifOp.elseYield ()].push_back (operandIdx);
11281133 }
1129- if (auto forOp = v.getDefiningOp <scf::ForOp>()) {
1130- unsigned operandIdx = cast<OpResult>(v).getResultNumber ();
1131- auto yieldOp = forOp.getBody ()->getTerminator ();
1132- yieldOperandsMap[yieldOp].push_back (operandIdx);
1133- opsToRewrite.insert (yieldOp);
1134- }
1134+ if (enableForLoopSupport)
1135+ if (auto forOp = v.getDefiningOp <scf::ForOp>()) {
1136+ unsigned operandIdx = cast<OpResult>(v).getResultNumber ();
1137+ auto yieldOp = forOp.getBody ()->getTerminator ();
1138+ yieldOperandsMap[yieldOp].push_back (operandIdx);
1139+ opsToRewrite.insert (yieldOp);
1140+ }
11351141 } else {
11361142 BlockArgument blockArg = cast<BlockArgument>(v);
11371143 Operation *parentOp = blockArg.getOwner ()->getParentOp ();
@@ -1155,17 +1161,19 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11551161 IRRewriter builder (slice.begin ()->getContext ());
11561162 for (Operation *op : opsToRewrite) {
11571163 if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1158- // Construct the new initialization argument by adding yielded operands
1159- // that have been remapped.
1160- auto yieldOp = cast<scf::YieldOp>(forOp.getBody ()->getTerminator ());
1161- auto yieldOperands = llvm::to_vector (yieldOp.getOperands ());
1162- SmallVector<int > operandsToRewrite = yieldOperandsMap[yieldOp];
1163- std::sort (operandsToRewrite.begin (), operandsToRewrite.end ());
11641164 SmallVector<Value> newOperands;
1165- for (int operandIdx : operandsToRewrite) {
1166- Value yieldOperand = yieldOp.getOperand (operandIdx);
1167- if (mapping.contains (yieldOperand))
1168- newOperands.push_back (mapping.lookup (yieldOperand));
1165+ if (enableForLoopSupport) {
1166+ // Construct the new initialization argument by adding yielded operands
1167+ // that have been remapped.
1168+ auto yieldOp = cast<scf::YieldOp>(forOp.getBody ()->getTerminator ());
1169+ auto yieldOperands = llvm::to_vector (yieldOp.getOperands ());
1170+ SmallVector<int > operandsToRewrite = yieldOperandsMap[yieldOp];
1171+ std::sort (operandsToRewrite.begin (), operandsToRewrite.end ());
1172+ for (int operandIdx : operandsToRewrite) {
1173+ Value yieldOperand = yieldOp.getOperand (operandIdx);
1174+ if (mapping.contains (yieldOperand))
1175+ newOperands.push_back (mapping.lookup (yieldOperand));
1176+ }
11691177 }
11701178
11711179 // Keep a mapping of the operands index to the new operands index.
@@ -1183,17 +1191,19 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11831191 scf::ForOp newForOp = replaceForOpWithNewSignature (
11841192 builder, forOp, newOperands, replacements);
11851193
1186- // Add rematerializations for loop results in the slice.
1187- unsigned oldIdx = 0 ;
1188- unsigned newIdx = forOp.getNumResults ();
1189- for (auto res : forOp.getResults ()) {
1190- if (slice.count (res)) {
1191- mapping.map (forOp.getResult (oldIdx), newForOp.getResult (newIdx));
1192- addRematValue (forOp.getResult (oldIdx), layout[res],
1193- newForOp.getResult (newIdx));
1194- ++newIdx;
1194+ if (enableForLoopSupport) {
1195+ // Add rematerializations for loop results in the slice.
1196+ unsigned oldIdx = 0 ;
1197+ unsigned newIdx = forOp.getNumResults ();
1198+ for (auto res : forOp.getResults ()) {
1199+ if (slice.count (res)) {
1200+ mapping.map (forOp.getResult (oldIdx), newForOp.getResult (newIdx));
1201+ addRematValue (forOp.getResult (oldIdx), layout[res],
1202+ newForOp.getResult (newIdx));
1203+ ++newIdx;
1204+ }
1205+ ++oldIdx;
11951206 }
1196- ++oldIdx;
11971207 }
11981208
11991209 deadOps.push_back (forOp.getOperation ());
0 commit comments