From dd853ccc572be76d8318b0947efdbf18c7c2170d Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Mon, 6 Oct 2025 16:50:06 +0000 Subject: [PATCH 1/7] Fix Signed-off-by: Ettore Tiotto --- .../RemoveLayoutConversions.cpp | 224 +++++++++++++++--- 1 file changed, 189 insertions(+), 35 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index 278b1c8eae..8aeda5f552 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -3,8 +3,10 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/Location.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" @@ -183,7 +185,9 @@ class LayoutRematerialization { void LayoutRematerialization::addRematValue(Value old, Attribute encoding, Value newV) { - LDBG("addRematValue " << old << " encoding " << encoding << " " << newV); + LDBG("addRematValue for: " << old); + LDBG(" encoding: " << encoding); + LDBG(" new: " << newV); rematMapping[{old, encoding}] = newV; mappedValues[old] = encoding; } @@ -1101,6 +1105,11 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, DenseMap &layout, ConvertLayoutOp convertOp, IRMapping &mapping) { + + llvm::errs() << "slice:\n"; + for (Value v : slice) + llvm::errs() << v << "\n"; + SetVector opsToRewrite; // Keep track of yield operands that need to be duplicated. 
DenseMap> yieldOperandsMap; @@ -1127,8 +1136,9 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, yieldOperandsMap[ifOp.elseYield()].push_back(operandIdx); } if (auto forOp = v.getDefiningOp()) { + llvm::errs() << "at line: " << __LINE__ << "\n"; unsigned operandIdx = cast(v).getResultNumber(); - auto yieldOp = forOp.getBody()->getTerminator(); + Operation *yieldOp = forOp.getBody()->getTerminator(); yieldOperandsMap[yieldOp].push_back(operandIdx); opsToRewrite.insert(yieldOp); } @@ -1147,6 +1157,11 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, slice.set_subtract(valuesWithExistingRemat); opsToRewrite = multiRootTopologicalSort(opsToRewrite); + llvm::errs() << "opsToRewrite:\n"; + for (Operation *op : opsToRewrite) { + llvm::errs().indent(2) << *op << "\n"; + } + // replaceAllUsesWith calls delayed until after initial rewrite. // This is required for slice.count(value) to work mid rewrite. SmallVector> replacements; @@ -1154,18 +1169,53 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, SmallVector deadOps; IRRewriter builder(slice.begin()->getContext()); for (Operation *op : opsToRewrite) { + llvm::errs() << "Processing:\n"; + llvm::errs().indent(2) << *op << "\n"; + + llvm::errs() << "mapping:\n"; + if (mapping.getValueMap().values().empty()) + llvm::errs().indent(2) << "Values: empty" << "\n"; + else { + llvm::errs().indent(2) << "Values:\n"; + for (auto pair : mapping.getValueMap()) { + auto first = pair.first; + auto second = pair.second; + llvm::errs().indent(4); + first.printAsOperand(llvm::errs(), {}); + llvm::errs() << " -> "; + second.printAsOperand(llvm::errs(), {}); + llvm::errs() << "\n"; + } + } + if (mapping.getOperationMap().values().empty()) + llvm::errs().indent(2) << "Operations: empty" << "\n"; + else { + llvm::errs().indent(2) << "Operations:\n"; + for (auto pair : mapping.getOperationMap()) { + auto first = pair.first; + auto second = pair.second; + llvm::errs().indent(4) << *first << "\n"; + 
llvm::errs().indent(4) << " -> " << *second << "\n"; + } + } + // assert(mapping.getBlockMap().values().empty()); + if (auto forOp = dyn_cast(op)) { - // Construct the new initialization argument by adding yielded operands + llvm::errs() << "at line: " << __LINE__ << "\n"; + // Construct the new init argument list by adding yielded operands // that have been remapped. auto yieldOp = cast(forOp.getBody()->getTerminator()); auto yieldOperands = llvm::to_vector(yieldOp.getOperands()); + SmallVector operandsToRewrite = yieldOperandsMap[yieldOp]; std::sort(operandsToRewrite.begin(), operandsToRewrite.end()); SmallVector newOperands; for (int operandIdx : operandsToRewrite) { Value yieldOperand = yieldOp.getOperand(operandIdx); - if (mapping.contains(yieldOperand)) + if (mapping.contains(yieldOperand)) { + llvm::errs() << "at line: " << __LINE__ << "\n"; newOperands.push_back(mapping.lookup(yieldOperand)); + } } // Keep a mapping of the operands index to the new operands index. @@ -1179,21 +1229,41 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, newOperands.push_back(mapping.lookup(initVal.get())); } } + + if (newOperands.empty()) + continue; + // Create a new for loop with the new operands. scf::ForOp newForOp = replaceForOpWithNewSignature( builder, forOp, newOperands, replacements); + llvm::errs() << "at line: " << __LINE__ << "\n"; + llvm::errs() << "newForOp: "; + newForOp->dumpPretty(); + llvm::errs() << "\n"; + // Add rematerializations for loop results in the slice. 
- unsigned oldIdx = 0; - unsigned newIdx = forOp.getNumResults(); - for (auto res : forOp.getResults()) { - if (slice.count(res)) { - mapping.map(forOp.getResult(oldIdx), newForOp.getResult(newIdx)); - addRematValue(forOp.getResult(oldIdx), layout[res], - newForOp.getResult(newIdx)); - ++newIdx; + if (newForOp->getNumResults() > forOp.getNumResults()) { + unsigned oldIdx = 0; + unsigned newIdx = forOp.getNumResults(); + for (auto res : forOp.getResults()) { + if (slice.count(res)) { + llvm::errs() << "oldIdx: " << oldIdx << "\n"; + llvm::errs() << "res: " << res << "\n"; + Value oldRes = forOp.getResult(oldIdx); + + llvm::errs() << "newIdx: " << newIdx << "\n"; + llvm::errs() << "oldRes: " << oldRes << "\n"; + Value newRes = newForOp.getResult(newIdx); + llvm::errs() << "newRes: " << newRes << "\n"; + + mapping.map(forOp.getResult(oldIdx), newForOp.getResult(newIdx)); + addRematValue(forOp.getResult(oldIdx), layout[res], + newForOp.getResult(newIdx)); + ++newIdx; + } + ++oldIdx; } - ++oldIdx; } deadOps.push_back(forOp.getOperation()); @@ -1256,15 +1326,91 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, } builder.setInsertionPoint(op); if (auto yieldOp = dyn_cast(op)) { + llvm::errs() << "at line: " << __LINE__ << "\n"; auto yieldOperands = llvm::to_vector(yieldOp.getOperands()); SmallVector operandsToRewrite = yieldOperandsMap[op]; // Sort so that operands are added in the same order as the new scf // results/arguments. 
std::sort(operandsToRewrite.begin(), operandsToRewrite.end()); for (int operandIdx : operandsToRewrite) { - yieldOperands.push_back(mapping.lookup(yieldOp.getOperand(operandIdx))); + Value yieldOperand = yieldOp.getOperand(operandIdx); + yieldOperands.push_back(mapping.lookup(yieldOperand)); + } + [[maybe_unused]] auto newYieldOp = + builder.create(op->getLoc(), yieldOperands); + llvm::errs() << "newYieldOp:" << newYieldOp << "\n"; + + auto parentOp = newYieldOp->getParentOp(); + llvm::errs() << "parentOp:"; + parentOp->dumpPretty(); + llvm::errs() << "\n"; + + // Fixup the init argument list of the parent loop if necessary. + if (auto forOp = dyn_cast(parentOp)) { + unsigned numIterArgs = forOp.getRegionIterArgs().size(); + unsigned numYieldOperands = newYieldOp->getNumOperands(); + assert(numIterArgs <= numYieldOperands); + + if (numIterArgs < numYieldOperands) { + // We have more yield operands that loop initialization arguments. + // Create new "dummy" initialization arguments for loop. + SmallVector newOperands; + for (unsigned idx = numIterArgs; idx < numYieldOperands; ++idx) { + Value operand = newYieldOp->getOperand(idx); + Type operandTy = operand.getType(); + auto insertPt = builder.saveInsertionPoint(); + builder.setInsertionPoint(forOp->getPrevNode()); + auto constantOp = builder.create( + builder.getUnknownLoc(), operandTy, + builder.getZeroAttr(operandTy)); + builder.restoreInsertionPoint(insertPt); + llvm::errs() << "constantOp: " << *constantOp << "\n"; + newOperands.push_back(constantOp); + ++idx; + } + + scf::ForOp newForOp = replaceForOpWithNewSignature( + builder, forOp, newOperands, replacements); + + llvm::errs() << "newForOp:\n"; + newForOp->dumpPretty(); + llvm::errs() << "\n"; + deadOps.push_back(forOp.getOperation()); + + // Add rematerializations for loop results in the slice. 
+ if (newForOp->getNumResults() > forOp.getNumResults()) { + unsigned oldIdx = 0; + unsigned newIdx = forOp.getNumResults(); + for (auto res : forOp.getResults()) { + llvm::errs() << "at line: " << __LINE__ << "\n"; + llvm::errs() << "res: " << res << "\n"; + + llvm::errs() << "slice:\n"; + for (Value v : slice) + llvm::errs() << v << "\n"; + + if (slice.count(res)) { + llvm::errs() << "oldIdx: " << oldIdx << "\n"; + llvm::errs() << "res: " << res << "\n"; + Value oldRes = forOp.getResult(oldIdx); + + llvm::errs() << "newIdx: " << newIdx << "\n"; + llvm::errs() << "oldRes: " << oldRes << "\n"; + Value newRes = newForOp.getResult(newIdx); + llvm::errs() << "newRes: " << newRes << "\n"; + + mapping.map(forOp.getResult(oldIdx), + newForOp.getResult(newIdx)); + addRematValue(forOp.getResult(oldIdx), layout[res], + newForOp.getResult(newIdx)); + ++newIdx; + } + ++oldIdx; + } + } + } } - builder.create(op->getLoc(), yieldOperands); + op->erase(); continue; } @@ -1279,12 +1425,16 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, cvt.getResult()); continue; } + Operation *newOp = builder.clone(*op, mapping); for (auto [old, newV] : llvm::zip(op->getResults(), newOp->getResults())) { + llvm::errs() << "old: " << old << "\n"; + llvm::errs() << "newV: " << newV << "\n"; auto it = layout.find(old); if (it == layout.end()) continue; Type oldType = old.getType(); + llvm::errs() << "oldType: " << oldType << "\n"; Type newType; if (isTensorPointerType(oldType)) { auto ptrType = cast(oldType); @@ -1296,11 +1446,13 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, newType = cast(old.getType()).cloneWithEncoding(it->second); } + llvm::errs() << "newType: " << newType << "\n"; newV.setType(newType); addRematValue(old, it->second, newV); } } - // Check mapping and see if there are existing convertOps on the old Argument + // Check mapping and see if there are existing convertOps on the old + // Argument 
convertOp.replaceAllUsesWith(mapping.lookup(convertOp.getSrc())); opToDelete.insert(convertOp); @@ -1470,8 +1622,8 @@ void LayoutRematerialization::backwardRematerialization( Value oldV = convertOp.getSrc(); LDBG("check backward remat with source " << oldV << " encoding " << targetType.getEncoding()); - // Check to see if there are existing remat'ed values for the pair of oldValue - // and encoding. Make sure it dominates the current conversion. + // Check to see if there are existing remat'ed values for the pair of + // oldValue and encoding. Make sure it dominates the current conversion. Value newV = getRematValue(oldV, targetType.getEncoding()); if (newV && domInfo.properlyDominates(newV, convertOp)) { // Replace it with the remat'ed value. @@ -1583,9 +1735,10 @@ void LayoutRematerialization::backwardRematerialization( auto reduceOp = dyn_cast(op); ReduceOpHelper helper(reduceOp); if (!helper.isAssociative()) { - // We shouldn't rematerize a no associative reduce op if it has multiple - // use chain. - LDBG(" skipped rematerialization due to non-associative reduce in the " + // We shouldn't rematerize a no associative reduce op if it has + // multiple use chain. 
+ LDBG(" skipped rematerialization due to non-associative reduce in " + "the " "slice"); return; } @@ -1669,14 +1822,14 @@ void LayoutRematerialization::hoistConvertDotOperand( if (!canBePipelined(convertOp)) return; - // We hoist over any operation that can be done without data movement between - // threads We do views and elementwise pure ops for now + // We hoist over any operation that can be done without data movement + // between threads We do views and elementwise pure ops for now auto noDataMovement = [](Operation *op) { return (op->hasTrait() && isMemoryEffectFree(op)) || isa(op) || isView(op); }; - // Stop the slice as soon as we find an operation that cannot be done without - // data movement between threads + // Stop the slice as soon as we find an operation that cannot be done + // without data movement between threads auto stop = std::not_fn(noDataMovement); SetVector slice; @@ -1736,8 +1889,8 @@ void LayoutRematerialization::hoistConvertDotOperand( rewriteSlice(innerSlice, layout, convertOp, mapping); } -// For convert left we try to hoist them above type extension to reduce the cost -// of the convert. +// For convert left we try to hoist them above type extension to reduce the +// cost of the convert. void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( ConvertLayoutOp convertOp) { // DotOperand is hoisted by hoistDotOperand @@ -1839,7 +1992,8 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( void LayoutRematerialization::hoistConvertIntoConditionals( ConvertLayoutOp convertOp) { // Take the backward slice of tensor dependencies rooted at the conversion, - // stopping at conditionals. This subslice is used to initialize the analysis. + // stopping at conditionals. This subslice is used to initialize the + // analysis. 
SetVector slice; DenseMap layout; auto isIfOp = [](Operation *op) { return isa(op); }; @@ -1848,17 +2002,17 @@ void LayoutRematerialization::hoistConvertIntoConditionals( layout, isIfOp))) return; - // These are the conditional edges above which conversions should be hoisted. - // The value represents the `scf.if` op result and the operand represents the - // edge into one of the branches. + // These are the conditional edges above which conversions should be + // hoisted. The value represents the `scf.if` op result and the operand + // represents the edge into one of the branches. SmallVector> hoistAbove; - // The list of `scf.if` op results in the slice that are not rematerializable. - // Hoisting is terminated at these values. + // The list of `scf.if` op results in the slice that are not + // rematerializable. Hoisting is terminated at these values. SmallVector terminals; - // This loop recurses through the subslices of the backwards dependencies, so - // re-query the size of `slice`. + // This loop recurses through the subslices of the backwards dependencies, + // so re-query the size of `slice`. 
for (unsigned i = 0; i != slice.size(); ++i) { Value v = slice[i]; auto ifOp = v.getDefiningOp(); From accb99c2474e0641d5af5bdc85ce064e82d047e2 Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Mon, 6 Oct 2025 19:32:10 +0000 Subject: [PATCH 2/7] Fix Signed-off-by: Ettore Tiotto --- .../RemoveLayoutConversions.cpp | 38 +++++++++---------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index 8aeda5f552..9137faed89 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -1204,18 +1204,15 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, llvm::errs() << "at line: " << __LINE__ << "\n"; // Construct the new init argument list by adding yielded operands // that have been remapped. + SmallVector newOperands; auto yieldOp = cast(forOp.getBody()->getTerminator()); auto yieldOperands = llvm::to_vector(yieldOp.getOperands()); - SmallVector operandsToRewrite = yieldOperandsMap[yieldOp]; std::sort(operandsToRewrite.begin(), operandsToRewrite.end()); - SmallVector newOperands; for (int operandIdx : operandsToRewrite) { Value yieldOperand = yieldOp.getOperand(operandIdx); - if (mapping.contains(yieldOperand)) { - llvm::errs() << "at line: " << __LINE__ << "\n"; + if (mapping.contains(yieldOperand)) newOperands.push_back(mapping.lookup(yieldOperand)); - } } // Keep a mapping of the operands index to the new operands index. 
@@ -1249,12 +1246,10 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, for (auto res : forOp.getResults()) { if (slice.count(res)) { llvm::errs() << "oldIdx: " << oldIdx << "\n"; - llvm::errs() << "res: " << res << "\n"; - Value oldRes = forOp.getResult(oldIdx); - llvm::errs() << "newIdx: " << newIdx << "\n"; - llvm::errs() << "oldRes: " << oldRes << "\n"; + Value oldRes = forOp.getResult(oldIdx); Value newRes = newForOp.getResult(newIdx); + llvm::errs() << "oldRes: " << oldRes << "\n"; llvm::errs() << "newRes: " << newRes << "\n"; mapping.map(forOp.getResult(oldIdx), newForOp.getResult(newIdx)); @@ -1274,9 +1269,13 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, mapping.map(loopBody.getArgument(m.first + numIndVars), loopBody.getArgument(m.second + numIndVars)); LLVM_DEBUG({ - DBGS() << "mapping forOp " - << loopBody.getArgument(m.first + numIndVars) << " to " - << loopBody.getArgument(m.second + numIndVars) << '\n'; + DBGS() << "mapping forOp "; + loopBody.getArgument(m.first + numIndVars) + .printAsOperand(llvm::dbgs(), {}); + llvm::dbgs() << " to "; + loopBody.getArgument(m.second + numIndVars) + .printAsOperand(llvm::dbgs(), {}); + llvm::dbgs() << '\n'; }); // The result is not in the layout/slice, the argument is. 
Value oldArg = loopBody.getArgument(m.first + numIndVars); @@ -1287,6 +1286,7 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, } continue; } + if (auto ifOp = dyn_cast(op)) { SmallVector newTypes; for (auto res : ifOp.getResults()) { @@ -1324,6 +1324,7 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, deadOps.push_back(ifOp.getOperation()); continue; } + builder.setInsertionPoint(op); if (auto yieldOp = dyn_cast(op)) { llvm::errs() << "at line: " << __LINE__ << "\n"; @@ -1369,6 +1370,9 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, ++idx; } + if (newOperands.empty()) + continue; + scf::ForOp newForOp = replaceForOpWithNewSignature( builder, forOp, newOperands, replacements); @@ -1385,18 +1389,12 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, llvm::errs() << "at line: " << __LINE__ << "\n"; llvm::errs() << "res: " << res << "\n"; - llvm::errs() << "slice:\n"; - for (Value v : slice) - llvm::errs() << v << "\n"; - if (slice.count(res)) { llvm::errs() << "oldIdx: " << oldIdx << "\n"; - llvm::errs() << "res: " << res << "\n"; - Value oldRes = forOp.getResult(oldIdx); - llvm::errs() << "newIdx: " << newIdx << "\n"; - llvm::errs() << "oldRes: " << oldRes << "\n"; + Value oldRes = forOp.getResult(oldIdx); Value newRes = newForOp.getResult(newIdx); + llvm::errs() << "oldRes: " << oldRes << "\n"; llvm::errs() << "newRes: " << newRes << "\n"; mapping.map(forOp.getResult(oldIdx), From 247e2a7ca824f3ff79e467f216ca3d9fc0a24ab0 Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Mon, 6 Oct 2025 20:01:25 +0000 Subject: [PATCH 3/7] Fix Signed-off-by: Ettore Tiotto --- .../RemoveLayoutConversions.cpp | 115 +++++++++--------- 1 file changed, 60 insertions(+), 55 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index 9137faed89..d4f8744eb7 100644 --- 
a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -1105,10 +1105,11 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, DenseMap &layout, ConvertLayoutOp convertOp, IRMapping &mapping) { - - llvm::errs() << "slice:\n"; - for (Value v : slice) - llvm::errs() << v << "\n"; + LLVM_DEBUG({ + llvm::errs() << "slice:\n"; + for (Value v : slice) + llvm::errs() << v << "\n"; + }); SetVector opsToRewrite; // Keep track of yield operands that need to be duplicated. @@ -1136,9 +1137,8 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, yieldOperandsMap[ifOp.elseYield()].push_back(operandIdx); } if (auto forOp = v.getDefiningOp()) { - llvm::errs() << "at line: " << __LINE__ << "\n"; unsigned operandIdx = cast(v).getResultNumber(); - Operation *yieldOp = forOp.getBody()->getTerminator(); + auto yieldOp = forOp.getBody()->getTerminator(); yieldOperandsMap[yieldOp].push_back(operandIdx); opsToRewrite.insert(yieldOp); } @@ -1157,10 +1157,12 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, slice.set_subtract(valuesWithExistingRemat); opsToRewrite = multiRootTopologicalSort(opsToRewrite); - llvm::errs() << "opsToRewrite:\n"; - for (Operation *op : opsToRewrite) { - llvm::errs().indent(2) << *op << "\n"; - } + LLVM_DEBUG({ + llvm::errs() << "opsToRewrite:\n"; + for (Operation *op : opsToRewrite) { + llvm::errs().indent(2) << *op << "\n"; + } + }); // replaceAllUsesWith calls delayed until after initial rewrite. // This is required for slice.count(value) to work mid rewrite. 
@@ -1169,39 +1171,40 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, SmallVector deadOps; IRRewriter builder(slice.begin()->getContext()); for (Operation *op : opsToRewrite) { - llvm::errs() << "Processing:\n"; - llvm::errs().indent(2) << *op << "\n"; - - llvm::errs() << "mapping:\n"; - if (mapping.getValueMap().values().empty()) - llvm::errs().indent(2) << "Values: empty" << "\n"; - else { - llvm::errs().indent(2) << "Values:\n"; - for (auto pair : mapping.getValueMap()) { - auto first = pair.first; - auto second = pair.second; - llvm::errs().indent(4); - first.printAsOperand(llvm::errs(), {}); - llvm::errs() << " -> "; - second.printAsOperand(llvm::errs(), {}); - llvm::errs() << "\n"; + LLVM_DEBUG({ + llvm::errs() << "Processing:\n"; + llvm::errs().indent(2) << *op << "\n"; + + llvm::errs() << "mapping:\n"; + if (mapping.getValueMap().values().empty()) + llvm::errs().indent(2) << "Values: empty" << "\n"; + else { + llvm::errs().indent(2) << "Values:\n"; + for (auto pair : mapping.getValueMap()) { + auto first = pair.first; + auto second = pair.second; + llvm::errs().indent(4); + first.printAsOperand(llvm::errs(), {}); + llvm::errs() << " -> "; + second.printAsOperand(llvm::errs(), {}); + llvm::errs() << "\n"; + } } - } - if (mapping.getOperationMap().values().empty()) - llvm::errs().indent(2) << "Operations: empty" << "\n"; - else { - llvm::errs().indent(2) << "Operations:\n"; - for (auto pair : mapping.getOperationMap()) { - auto first = pair.first; - auto second = pair.second; - llvm::errs().indent(4) << *first << "\n"; - llvm::errs().indent(4) << " -> " << *second << "\n"; + if (mapping.getOperationMap().values().empty()) + llvm::errs().indent(2) << "Operations: empty" << "\n"; + else { + llvm::errs().indent(2) << "Operations:\n"; + for (auto pair : mapping.getOperationMap()) { + auto first = pair.first; + auto second = pair.second; + llvm::errs().indent(4) << *first << "\n"; + llvm::errs().indent(4) << " -> " << *second << "\n"; + } } - } - // 
assert(mapping.getBlockMap().values().empty()); + // assert(mapping.getBlockMap().values().empty()); + }); if (auto forOp = dyn_cast(op)) { - llvm::errs() << "at line: " << __LINE__ << "\n"; // Construct the new init argument list by adding yielded operands // that have been remapped. SmallVector newOperands; @@ -1234,10 +1237,12 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, scf::ForOp newForOp = replaceForOpWithNewSignature( builder, forOp, newOperands, replacements); - llvm::errs() << "at line: " << __LINE__ << "\n"; - llvm::errs() << "newForOp: "; - newForOp->dumpPretty(); - llvm::errs() << "\n"; + LLVM_DEBUG({ + llvm::errs() << "at line: " << __LINE__ << "\n"; + llvm::errs() << "newForOp: "; + newForOp->dumpPretty(); + llvm::errs() << "\n"; + }); // Add rematerializations for loop results in the slice. if (newForOp->getNumResults() > forOp.getNumResults()) { @@ -1245,12 +1250,16 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, unsigned newIdx = forOp.getNumResults(); for (auto res : forOp.getResults()) { if (slice.count(res)) { - llvm::errs() << "oldIdx: " << oldIdx << "\n"; - llvm::errs() << "newIdx: " << newIdx << "\n"; + LLVM_DEBUG({ + llvm::errs() << "oldIdx: " << oldIdx << "\n"; + llvm::errs() << "newIdx: " << newIdx << "\n"; + }); Value oldRes = forOp.getResult(oldIdx); Value newRes = newForOp.getResult(newIdx); - llvm::errs() << "oldRes: " << oldRes << "\n"; - llvm::errs() << "newRes: " << newRes << "\n"; + LLVM_DEBUG({ + llvm::errs() << "oldRes: " << oldRes << "\n"; + llvm::errs() << "newRes: " << newRes << "\n"; + }); mapping.map(forOp.getResult(oldIdx), newForOp.getResult(newIdx)); addRematValue(forOp.getResult(oldIdx), layout[res], @@ -1327,18 +1336,17 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, builder.setInsertionPoint(op); if (auto yieldOp = dyn_cast(op)) { - llvm::errs() << "at line: " << __LINE__ << "\n"; auto yieldOperands = llvm::to_vector(yieldOp.getOperands()); SmallVector 
operandsToRewrite = yieldOperandsMap[op]; // Sort so that operands are added in the same order as the new scf // results/arguments. std::sort(operandsToRewrite.begin(), operandsToRewrite.end()); for (int operandIdx : operandsToRewrite) { - Value yieldOperand = yieldOp.getOperand(operandIdx); - yieldOperands.push_back(mapping.lookup(yieldOperand)); + yieldOperands.push_back(mapping.lookup(yieldOp.getOperand(operandIdx))); } [[maybe_unused]] auto newYieldOp = builder.create(op->getLoc(), yieldOperands); +#if 0 llvm::errs() << "newYieldOp:" << newYieldOp << "\n"; auto parentOp = newYieldOp->getParentOp(); @@ -1407,7 +1415,8 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, } } } - } + } +#endif op->erase(); continue; @@ -1426,13 +1435,10 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, Operation *newOp = builder.clone(*op, mapping); for (auto [old, newV] : llvm::zip(op->getResults(), newOp->getResults())) { - llvm::errs() << "old: " << old << "\n"; - llvm::errs() << "newV: " << newV << "\n"; auto it = layout.find(old); if (it == layout.end()) continue; Type oldType = old.getType(); - llvm::errs() << "oldType: " << oldType << "\n"; Type newType; if (isTensorPointerType(oldType)) { auto ptrType = cast(oldType); @@ -1444,7 +1450,6 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, newType = cast(old.getType()).cloneWithEncoding(it->second); } - llvm::errs() << "newType: " << newType << "\n"; newV.setType(newType); addRematValue(old, it->second, newV); } From a55c0fc2768054e31b562b14f6c40f934ffd46bb Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Tue, 7 Oct 2025 13:48:04 +0000 Subject: [PATCH 4/7] Fix Signed-off-by: Ettore Tiotto --- test/TritonIntelGPU/combine.mlir | 52 +++++++++++++ .../RemoveLayoutConversions.cpp | 78 ++++++++++++++----- 2 files changed, 112 insertions(+), 18 deletions(-) diff --git a/test/TritonIntelGPU/combine.mlir b/test/TritonIntelGPU/combine.mlir index 51ccc2c0b2..8e8cb65189 100644 --- 
a/test/TritonIntelGPU/combine.mlir +++ b/test/TritonIntelGPU/combine.mlir @@ -3487,3 +3487,55 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.return %12, %9#1 : tensor<4x1xi64, #blocked>, i32 } } + +// ----- + +// COM: Reproducer for issue 5251. +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK: test_5251 + tt.func public @test_5251(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<2147483647> : tensor<1x4xi32, #blocked> + %cst_0 = arith.constant dense<0x7F800000> : tensor<1x4xf32, #blocked> + %0 = tt.splat %arg3 : i32 -> tensor<1x4xi32, #blocked> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<1x4x!tt.ptr, #blocked> + %2:2 = scf.for %arg4 = %c0_i32 to %arg3 step %c4_i32 iter_args(%arg5 = %cst_0, %arg6 = %cst) -> (tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked>) : i32 { + // CHECK: [[RES:%.*]] = scf.for + // CHECK-NOT: ttg.convert_layout + // CHECK: scf.yield + %12 = tt.splat %arg4 : i32 -> tensor<1x4xi32, #blocked> + %13 = arith.cmpi slt, %12, %0 : tensor<1x4xi32, #blocked> + %14 = ttg.convert_layout %1 : tensor<1x4x!tt.ptr, #blocked> -> tensor<1x4x!tt.ptr, #blocked> + %15 = tt.load %14 : tensor<1x4x!tt.ptr, #blocked> + %16 = arith.cmpi slt, %arg6, %12 : tensor<1x4xi32, #blocked> + %17 = arith.select %16, %arg5, %15 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + 
%18 = arith.select %13, %17, %arg5 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + scf.yield %18, %arg6 : tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked> + } + // CHECK: [[RED:%.*]]:2 = "tt.reduce"([[RES]], %cst) + %3:2 = "tt.reduce"(%2#0, %2#1) <{axis = 1 : i32}> ({ + ^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32): + %12 = arith.cmpf olt, %arg4, %arg6 : f32 + %13 = arith.select %12, %arg4, %arg6 : f32 + %14 = arith.select %12, %arg5, %arg7 : i32 + tt.reduce.return %13, %14 : f32, i32 + }) : (tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked>) -> (tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>>) + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.expand_dims [[RED]]#1 + %4 = ttg.convert_layout %3#1 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xi32, #blocked1> + %5 = ttg.convert_layout %4 : tensor<1xi32, #blocked1> -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %6 = tt.expand_dims %5 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<1x1xi32, #blocked2> + %7 = ttg.convert_layout %6 : tensor<1x1xi32, #blocked2> -> tensor<1x1xi32, #blocked3> + %8 = tt.splat %arg2 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked3> + %9 = arith.extsi %7 : tensor<1x1xi32, #blocked3> to tensor<1x1xi64, #blocked3> + %10 = ttg.convert_layout %8 : tensor<1x1x!tt.ptr, #blocked3> -> tensor<1x1x!tt.ptr, #blocked3> + %11 = ttg.convert_layout %9 : tensor<1x1xi64, #blocked3> -> tensor<1x1xi64, #blocked3> + tt.store %10, %11 : tensor<1x1x!tt.ptr, #blocked3> + tt.return + } +} diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index d4f8744eb7..048cc0e846 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -27,6 +27,7 @@ 
#include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include @@ -1140,19 +1141,36 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, unsigned operandIdx = cast(v).getResultNumber(); auto yieldOp = forOp.getBody()->getTerminator(); yieldOperandsMap[yieldOp].push_back(operandIdx); + LLVM_DEBUG({ + llvm::errs() << "1. pushing " << operandIdx + << " in yieldOperandMap\n "; + }); opsToRewrite.insert(yieldOp); } - } else { + } +#if 1 + else { + // These are already pushed above. BlockArgument blockArg = cast(v); Operation *parentOp = blockArg.getOwner()->getParentOp(); if (auto loopOp = cast(parentOp)) { opsToRewrite.insert(loopOp.getOperation()); OpOperand *operand = loopOp.getTiedLoopYieldedValue(blockArg); + unsigned opIdx = operand->getOperandNumber(); auto yieldOp = blockArg.getOwner()->getTerminator(); - yieldOperandsMap[yieldOp].push_back(operand->getOperandNumber()); + if (!yieldOperandsMap.contains(yieldOp)) + yieldOperandsMap[yieldOp].push_back(opIdx); + else if (llvm::none_of(yieldOperandsMap[yieldOp], + [&](unsigned idx) { return idx == opIdx; })) + yieldOperandsMap[yieldOp].push_back(opIdx); + LLVM_DEBUG({ + llvm::errs() << "2. 
pushing " << operand->getOperandNumber() + << " in yieldOperandMap\n"; + }); opsToRewrite.insert(yieldOp); } } +#endif } slice.set_subtract(valuesWithExistingRemat); opsToRewrite = multiRootTopologicalSort(opsToRewrite); @@ -1162,6 +1180,12 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, for (Operation *op : opsToRewrite) { llvm::errs().indent(2) << *op << "\n"; } + llvm::errs() << "yieldOperandsMap:\n"; + for (auto entry : yieldOperandsMap) { + llvm::errs() << *entry.first << " -> \n"; + for (int opx : entry.second) + llvm::errs().indent(2) << opx << "\n"; + } }); // replaceAllUsesWith calls delayed until after initial rewrite. @@ -1214,8 +1238,11 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, std::sort(operandsToRewrite.begin(), operandsToRewrite.end()); for (int operandIdx : operandsToRewrite) { Value yieldOperand = yieldOp.getOperand(operandIdx); - if (mapping.contains(yieldOperand)) + if (mapping.contains(yieldOperand)) { newOperands.push_back(mapping.lookup(yieldOperand)); + // llvm::errs() << "YieldOperand: " << yieldOperand + // << " is mapped, adding new init to for loop\n"; + } } // Keep a mapping of the operands index to the new operands index. 
@@ -1227,6 +1254,9 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, forOp.getTiedLoopResult(&initVal).getResultNumber(), forOp.getInitArgs().size() + newOperands.size())); newOperands.push_back(mapping.lookup(initVal.get())); + // llvm::errs() << "initVal: " << initVal.get() + // << " is mapped, adding new init to for + // loop\n"; } } @@ -1347,12 +1377,14 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, [[maybe_unused]] auto newYieldOp = builder.create(op->getLoc(), yieldOperands); #if 0 - llvm::errs() << "newYieldOp:" << newYieldOp << "\n"; auto parentOp = newYieldOp->getParentOp(); - llvm::errs() << "parentOp:"; - parentOp->dumpPretty(); - llvm::errs() << "\n"; + LLVM_DEBUG({ + llvm::errs() << "newYieldOp:" << newYieldOp << "\n"; + llvm::errs() << "parentOp:"; + parentOp->dumpPretty(); + llvm::errs() << "\n"; + }); // Fixup the init argument list of the parent loop if necessary. if (auto forOp = dyn_cast(parentOp)) { @@ -1373,7 +1405,7 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, builder.getUnknownLoc(), operandTy, builder.getZeroAttr(operandTy)); builder.restoreInsertionPoint(insertPt); - llvm::errs() << "constantOp: " << *constantOp << "\n"; + // llvm::errs() << "constantOp: " << *constantOp << "\n"; newOperands.push_back(constantOp); ++idx; } @@ -1384,9 +1416,12 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, scf::ForOp newForOp = replaceForOpWithNewSignature( builder, forOp, newOperands, replacements); - llvm::errs() << "newForOp:\n"; - newForOp->dumpPretty(); - llvm::errs() << "\n"; + LLVM_DEBUG({ + llvm::errs() << "newForOp:\n"; + newForOp->dumpPretty(); + llvm::errs() << "\n"; + }); + deadOps.push_back(forOp.getOperation()); // Add rematerializations for loop results in the slice. 
@@ -1394,16 +1429,23 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, unsigned oldIdx = 0; unsigned newIdx = forOp.getNumResults(); for (auto res : forOp.getResults()) { - llvm::errs() << "at line: " << __LINE__ << "\n"; - llvm::errs() << "res: " << res << "\n"; + LLVM_DEBUG({ + llvm::errs() << "at line: " << __LINE__ << "\n"; + llvm::errs() << "res: " << res << "\n"; + }); if (slice.count(res)) { - llvm::errs() << "oldIdx: " << oldIdx << "\n"; - llvm::errs() << "newIdx: " << newIdx << "\n"; + LLVM_DEBUG({ + llvm::errs() << "oldIdx: " << oldIdx << "\n"; + llvm::errs() << "newIdx: " << newIdx << "\n"; + }); Value oldRes = forOp.getResult(oldIdx); Value newRes = newForOp.getResult(newIdx); - llvm::errs() << "oldRes: " << oldRes << "\n"; - llvm::errs() << "newRes: " << newRes << "\n"; + + LLVM_DEBUG({ + llvm::errs() << "oldRes: " << oldRes << "\n"; + llvm::errs() << "newRes: " << newRes << "\n"; + }); mapping.map(forOp.getResult(oldIdx), newForOp.getResult(newIdx)); @@ -1415,7 +1457,7 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, } } } - } + } #endif op->erase(); From c6f762a6b590d3fd2f2fe72a2d1eeaca200c355f Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Tue, 7 Oct 2025 17:43:05 +0000 Subject: [PATCH 5/7] Fix Signed-off-by: Ettore Tiotto --- .../RemoveLayoutConversions.cpp | 49 ++++++++++++++----- 1 file changed, 38 insertions(+), 11 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index 048cc0e846..96533838d5 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -1,6 +1,7 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" 
#include "mlir/IR/Location.h" @@ -1106,8 +1107,9 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, DenseMap &layout, ConvertLayoutOp convertOp, IRMapping &mapping) { + auto mod = convertOp->getParentOfType(); LLVM_DEBUG({ - llvm::errs() << "slice:\n"; + llvm::errs() << "rewriteSlice:\n"; for (Value v : slice) llvm::errs() << v << "\n"; }); @@ -1178,7 +1180,8 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, LLVM_DEBUG({ llvm::errs() << "opsToRewrite:\n"; for (Operation *op : opsToRewrite) { - llvm::errs().indent(2) << *op << "\n"; + llvm::errs().indent(2) << "(" << op << "): "; + op->dumpPretty(); } llvm::errs() << "yieldOperandsMap:\n"; for (auto entry : yieldOperandsMap) { @@ -1197,7 +1200,9 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, for (Operation *op : opsToRewrite) { LLVM_DEBUG({ llvm::errs() << "Processing:\n"; - llvm::errs().indent(2) << *op << "\n"; + llvm::errs().indent(2) << "(" << op << "): "; + op->dumpPretty(); + llvm::errs() << "\n"; llvm::errs() << "mapping:\n"; if (mapping.getValueMap().values().empty()) @@ -1240,8 +1245,10 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, Value yieldOperand = yieldOp.getOperand(operandIdx); if (mapping.contains(yieldOperand)) { newOperands.push_back(mapping.lookup(yieldOperand)); - // llvm::errs() << "YieldOperand: " << yieldOperand - // << " is mapped, adding new init to for loop\n"; + LLVM_DEBUG({ + llvm::errs() << "YieldOperand: " << yieldOperand + << " is mapped, adding new init to for loop\n"; + }); } } @@ -1254,9 +1261,10 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, forOp.getTiedLoopResult(&initVal).getResultNumber(), forOp.getInitArgs().size() + newOperands.size())); newOperands.push_back(mapping.lookup(initVal.get())); - // llvm::errs() << "initVal: " << initVal.get() - // << " is mapped, adding new init to for - // loop\n"; + LLVM_DEBUG({ + llvm::errs() << "initVal: " << initVal.get() + << " is mapped, adding new init to for 
loop\n "; + }); } } @@ -1269,7 +1277,7 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, LLVM_DEBUG({ llvm::errs() << "at line: " << __LINE__ << "\n"; - llvm::errs() << "newForOp: "; + llvm::errs() << "newForOp (" << &newForOp << "): "; newForOp->dumpPretty(); llvm::errs() << "\n"; }); @@ -1279,6 +1287,9 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, unsigned oldIdx = 0; unsigned newIdx = forOp.getNumResults(); for (auto res : forOp.getResults()) { + if (newIdx >= newForOp.getNumResults()) + break; + if (slice.count(res)) { LLVM_DEBUG({ llvm::errs() << "oldIdx: " << oldIdx << "\n"; @@ -1376,16 +1387,20 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, } [[maybe_unused]] auto newYieldOp = builder.create(op->getLoc(), yieldOperands); -#if 0 auto parentOp = newYieldOp->getParentOp(); LLVM_DEBUG({ + unsigned numYieldArgsAdded = + newYieldOp.getNumOperands() - yieldOp.getNumOperands(); + llvm::errs() << "Added " << numYieldArgsAdded + << " operands to the loop yield\n"; llvm::errs() << "newYieldOp:" << newYieldOp << "\n"; llvm::errs() << "parentOp:"; parentOp->dumpPretty(); llvm::errs() << "\n"; }); +#if 1 // Fixup the init argument list of the parent loop if necessary. 
if (auto forOp = dyn_cast(parentOp)) { unsigned numIterArgs = forOp.getRegionIterArgs().size(); @@ -1417,7 +1432,11 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, builder, forOp, newOperands, replacements); LLVM_DEBUG({ - llvm::errs() << "newForOp:\n"; + unsigned numArgsAdded = + newForOp.getNumResults() - forOp.getNumResults(); + llvm::errs() << "Added " << numArgsAdded + << " arguments to the loop\n"; + llvm::errs() << "newForOp (" << &newForOp << "): "; newForOp->dumpPretty(); llvm::errs() << "\n"; }); @@ -1429,6 +1448,9 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, unsigned oldIdx = 0; unsigned newIdx = forOp.getNumResults(); for (auto res : forOp.getResults()) { + if (newIdx >= newForOp.getNumResults()) + break; + LLVM_DEBUG({ llvm::errs() << "at line: " << __LINE__ << "\n"; llvm::errs() << "res: " << res << "\n"; @@ -1508,6 +1530,11 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, for (Operation *op : deadOps) opToDelete.insert(op); + + LLVM_DEBUG({ + llvm::errs() << "rewriteSlice DONE:\n"; + mod->dump(); + }); } void LayoutRematerialization::rewriteSlice(SetVector &slice, From 075085e1f6edda81f99584aa99a5a99706e63423 Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Fri, 17 Oct 2025 17:48:55 +0000 Subject: [PATCH 6/7] Add test cases Signed-off-by: Ettore Tiotto --- test/TritonIntelGPU/combine.mlir | 272 ++++++++++++++++++++++++++++++- 1 file changed, 269 insertions(+), 3 deletions(-) diff --git a/test/TritonIntelGPU/combine.mlir b/test/TritonIntelGPU/combine.mlir index d5e34ff3ed..91795281f5 100644 --- a/test/TritonIntelGPU/combine.mlir +++ b/test/TritonIntelGPU/combine.mlir @@ -3491,14 +3491,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- -// COM: Reproducer for issue 5251. 
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> #blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> #blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - // CHECK: test_5251 - tt.func public @test_5251(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { + // CHECK: test_5251_1 + tt.func public @test_5251_1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { %c4_i32 = arith.constant 4 : i32 %c0_i32 = arith.constant 0 : i32 %cst = arith.constant dense<2147483647> : tensor<1x4xi32, #blocked> @@ -3521,6 +3520,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr // CHECK: [[RED:%.*]]:2 = "tt.reduce"([[RES]], %cst) %3:2 = "tt.reduce"(%2#0, %2#1) <{axis = 1 : i32}> ({ ^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32): + // CHECK-NOT: ttg.convert_layout + // CHECK tt.reduce.return %12 = arith.cmpf olt, %arg4, %arg6 : f32 %13 = arith.select %12, %arg4, %arg6 : f32 %14 = arith.select %12, %arg5, %arg7 : i32 @@ -3540,3 +3541,268 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], 
warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 32 : i32} { + // CHECK: test_5251_2 + tt.func public @test_5251_2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32) { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<2147483647> : tensor<1x4xi32, #blocked> + %cst_0 = arith.constant dense<0x7F800000> : tensor<1x4xf32, #blocked> + %0 = tt.splat %arg2 : i32 -> tensor<1x4xi32, #blocked> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<1x4x!tt.ptr, #blocked> + %2:2 = scf.for %arg3 = %c0_i32 to %arg2 step %c4_i32 iter_args(%arg4 = %cst_0, %arg5 = %cst) -> (tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked>) : i32 { + // CHECK: [[RES:%.*]] = scf.for + // CHECK-NOT: ttg.convert_layout + // CHECK: scf.yield + %14 = tt.splat %arg3 : i32 -> tensor<1x4xi32, #blocked> + %15 = arith.cmpi slt, %14, %0 : tensor<1x4xi32, #blocked> + %16 = ttg.convert_layout %1 : tensor<1x4x!tt.ptr, #blocked> -> tensor<1x4x!tt.ptr, #blocked> + %17 = ttg.convert_layout %16 : tensor<1x4x!tt.ptr, #blocked> -> tensor<1x4x!tt.ptr, #blocked> + %18 = tt.load %17 : tensor<1x4x!tt.ptr, #blocked> + %19 = arith.cmpi slt, %arg5, %14 : tensor<1x4xi32, #blocked> + %20 = arith.select %19, %arg4, %18 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %21 = arith.select %15, %20, %arg4 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + scf.yield %21, %arg5 : tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked> + } + // CHECK: [[RED:%.*]]:2 = "tt.reduce"([[RES]], %cst) + %3:2 = "tt.reduce"(%2#0, %2#1) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32, %arg4: i32, %arg5: f32, %arg6: i32): + // CHECK-NOT: ttg.convert_layout + // CHECK tt.reduce.return + %14 = arith.cmpf olt, %arg3, %arg5 : f32 + %15 = arith.select %14, %arg3, %arg5 : f32 + %16 = arith.select %14, %arg4, %arg6 : i32 + tt.reduce.return %15, %16 : f32, 
i32 + }) : (tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked>) -> (tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>>) + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.expand_dims [[RED]]#1 + %4 = ttg.convert_layout %3#1 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xi32, #blocked1> + %5 = ttg.convert_layout %4 : tensor<1xi32, #blocked1> -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %6 = tt.expand_dims %5 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<1x1xi32, #blocked2> + %7 = ttg.convert_layout %6 : tensor<1x1xi32, #blocked2> -> tensor<1x1xi32, #blocked3> + %8 = tt.splat %arg1 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked3> + %9 = arith.extsi %7 : tensor<1x1xi32, #blocked3> to tensor<1x1xi64, #blocked3> + %10 = ttg.convert_layout %8 : tensor<1x1x!tt.ptr, #blocked3> -> tensor<1x1x!tt.ptr, #blocked3> + %11 = ttg.convert_layout %9 : tensor<1x1xi64, #blocked3> -> tensor<1x1xi64, #blocked3> + %12 = ttg.convert_layout %10 : tensor<1x1x!tt.ptr, #blocked3> -> tensor<1x1x!tt.ptr, #blocked3> + %13 = ttg.convert_layout %11 : tensor<1x1xi64, #blocked3> -> tensor<1x1xi64, #blocked3> + tt.store %12, %13 : tensor<1x1x!tt.ptr, #blocked3> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, 
ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block} { + // CHECK: test_5251_3 + tt.func public @test_5251_3(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: i32) attributes {noinline = false} { + // CHECK-NOT: ttg.convert_layout + %true = arith.constant true + %cst = arith.constant dense : tensor<1x4xi1, #blocked> + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<2147483647> : tensor<1x4xi32, #blocked> + %cst_1 = arith.constant dense<0x7F800000> : tensor<1x4xf32, #blocked> + %cst_2 = arith.constant dense<0xFF800000> : tensor<1x4xf32, #blocked> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<1x4xf32, #blocked> + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked1> + %1 = ttg.convert_layout %0 : tensor<4xi32, #blocked1> -> tensor<4xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x4xi32, #blocked2> + %3 = ttg.convert_layout %2 : tensor<1x4xi32, #blocked2> -> tensor<1x4xi32, #blocked> + %4 = tt.splat %arg6 : i32 -> tensor<1x4xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1x4x!tt.ptr, #blocked> + %6:7 = scf.for %arg7 = %c0_i32 to %arg6 step %c4_i32 iter_args(%arg8 = %cst_3, %arg9 = %cst_2, %arg10 = %cst_1, %arg11 = %cst_2, %arg12 = %cst_0, %arg13 = %cst_1, %arg14 = %cst_0) -> (tensor<1x4xf32, #blocked>, tensor<1x4xf32, #blocked>, tensor<1x4xf32, #blocked>, tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked>, tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked>) : i32 { + // CHECK: [[RES:%.*]]:7 = scf.for + // CHECK-NOT: ttg.convert_layout + // CHECK: scf.yield + %49 = tt.splat %arg7 : i32 
-> tensor<1x4xi32, #blocked> + %50 = arith.addi %49, %3 : tensor<1x4xi32, #blocked> + %51 = arith.cmpi slt, %50, %4 : tensor<1x4xi32, #blocked> + %52 = tt.addptr %5, %50 : tensor<1x4x!tt.ptr, #blocked>, tensor<1x4xi32, #blocked> + %53 = ttg.convert_layout %52 : tensor<1x4x!tt.ptr, #blocked> -> tensor<1x4x!tt.ptr, #blocked> + %54 = ttg.convert_layout %51 : tensor<1x4xi1, #blocked> -> tensor<1x4xi1, #blocked> + %55 = ttg.convert_layout %cst_3 : tensor<1x4xf32, #blocked> -> tensor<1x4xf32, #blocked> + %56 = tt.load %53, %54, %55 evictionPolicy = evict_first : tensor<1x4x!tt.ptr, #blocked> + %57 = arith.addf %arg8, %56 : tensor<1x4xf32, #blocked> + %58 = arith.select %51, %57, %arg8 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %59 = arith.cmpf ogt, %arg9, %56 : tensor<1x4xf32, #blocked> + %60 = arith.cmpf une, %arg9, %arg9 : tensor<1x4xf32, #blocked> + %61 = arith.ori %59, %60 : tensor<1x4xi1, #blocked> + %62 = arith.select %61, %arg9, %56 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %63 = arith.select %51, %62, %arg9 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %64 = arith.cmpf olt, %arg10, %56 : tensor<1x4xf32, #blocked> + %65 = arith.cmpf une, %arg10, %arg10 : tensor<1x4xf32, #blocked> + %66 = arith.ori %64, %65 : tensor<1x4xi1, #blocked> + %67 = arith.select %66, %arg10, %56 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %68 = arith.select %51, %67, %arg10 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %69 = arith.cmpf ogt, %arg11, %56 : tensor<1x4xf32, #blocked> + %70 = arith.cmpf oeq, %arg11, %56 : tensor<1x4xf32, #blocked> + %71 = arith.cmpf une, %arg11, %arg11 : tensor<1x4xf32, #blocked> + %72 = arith.cmpf une, %56, %56 : tensor<1x4xf32, #blocked> + %73 = arith.xori %72, %cst : tensor<1x4xi1, #blocked> + %74 = arith.andi %71, %73 : tensor<1x4xi1, #blocked> + %75 = arith.ori %69, %74 : tensor<1x4xi1, #blocked> + %76 = arith.andi %71, %72 : tensor<1x4xi1, #blocked> + %77 = arith.ori %70, %76 : tensor<1x4xi1, 
#blocked> + %78 = arith.cmpi slt, %arg12, %50 : tensor<1x4xi32, #blocked> + %79 = arith.andi %77, %78 : tensor<1x4xi1, #blocked> + %80 = arith.ori %75, %79 : tensor<1x4xi1, #blocked> + %81 = arith.select %80, %arg11, %56 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %82 = arith.select %80, %arg12, %50 : tensor<1x4xi1, #blocked>, tensor<1x4xi32, #blocked> + %83 = arith.select %51, %81, %arg11 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %84 = arith.select %51, %82, %arg12 : tensor<1x4xi1, #blocked>, tensor<1x4xi32, #blocked> + %85 = arith.cmpf olt, %arg13, %56 : tensor<1x4xf32, #blocked> + %86 = arith.cmpf oeq, %arg13, %56 : tensor<1x4xf32, #blocked> + %87 = arith.cmpf une, %arg13, %arg13 : tensor<1x4xf32, #blocked> + %88 = arith.andi %87, %73 : tensor<1x4xi1, #blocked> + %89 = arith.ori %85, %88 : tensor<1x4xi1, #blocked> + %90 = arith.andi %87, %72 : tensor<1x4xi1, #blocked> + %91 = arith.ori %86, %90 : tensor<1x4xi1, #blocked> + %92 = arith.cmpi slt, %arg14, %50 : tensor<1x4xi32, #blocked> + %93 = arith.andi %91, %92 : tensor<1x4xi1, #blocked> + %94 = arith.ori %89, %93 : tensor<1x4xi1, #blocked> + %95 = arith.select %94, %arg13, %56 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %96 = arith.select %94, %arg14, %50 : tensor<1x4xi1, #blocked>, tensor<1x4xi32, #blocked> + %97 = arith.select %51, %95, %arg13 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %98 = arith.select %51, %96, %arg14 : tensor<1x4xi1, #blocked>, tensor<1x4xi32, #blocked> + scf.yield %58, %63, %68, %83, %84, %97, %98 : tensor<1x4xf32, #blocked>, tensor<1x4xf32, #blocked>, tensor<1x4xf32, #blocked>, tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked>, tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked> + } + // CHECK: [[RED0:%.*]] = "tt.reduce"([[RES]]#0) + %7 = "tt.reduce"(%6#0) <{axis = 1 : i32}> ({ + ^bb0(%arg7: f32, %arg8: f32): + // CHECK-NOT: ttg.convert_layout + // CHECK tt.reduce.return + %49 = arith.addf %arg7, %arg8 : f32 + tt.reduce.return %49 : 
f32 + }) : (tensor<1x4xf32, #blocked>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.expand_dims [[RED0]] + %8 = ttg.convert_layout %7 : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xf32, #blocked1> + %9 = ttg.convert_layout %8 : tensor<1xf32, #blocked1> -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<1x1xf32, #blocked3> + %11 = ttg.convert_layout %10 : tensor<1x1xf32, #blocked3> -> tensor<1x1xf32, #blocked4> + // CHECK: [[RED1:%.*]] = "tt.reduce"([[RES]]#1) + %12 = "tt.reduce"(%6#1) <{axis = 1 : i32}> ({ + ^bb0(%arg7: f32, %arg8: f32): + // CHECK-NOT: ttg.convert_layout + // CHECK tt.reduce.return + %49 = arith.cmpf ogt, %arg7, %arg8 : f32 + %50 = arith.cmpf une, %arg7, %arg7 : f32 + %51 = arith.ori %49, %50 : i1 + %52 = arith.select %51, %arg7, %arg8 : f32 + tt.reduce.return %52 : f32 + }) : (tensor<1x4xf32, #blocked>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.expand_dims [[RED1]] + %13 = ttg.convert_layout %12 : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xf32, #blocked1> + %14 = ttg.convert_layout %13 : tensor<1xf32, #blocked1> -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %15 = tt.expand_dims %14 {axis = 1 : i32} : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<1x1xf32, #blocked3> + %16 = ttg.convert_layout %15 : tensor<1x1xf32, #blocked3> -> tensor<1x1xf32, #blocked4> + // CHECK: [[RED2:%.*]] = "tt.reduce"([[RES]]#2) + %17 = "tt.reduce"(%6#2) <{axis = 1 : i32}> ({ + ^bb0(%arg7: f32, %arg8: f32): + // CHECK-NOT: ttg.convert_layout + // CHECK tt.reduce.return + %49 = arith.cmpf olt, %arg7, %arg8 : f32 + %50 = arith.cmpf une, %arg7, %arg7 : f32 + %51 = arith.ori %49, %50 : i1 + %52 = arith.select %51, %arg7, %arg8 : f32 + 
tt.reduce.return %52 : f32 + }) : (tensor<1x4xf32, #blocked>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.expand_dims [[RED2]] + %18 = ttg.convert_layout %17 : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xf32, #blocked1> + %19 = ttg.convert_layout %18 : tensor<1xf32, #blocked1> -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %20 = tt.expand_dims %19 {axis = 1 : i32} : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<1x1xf32, #blocked3> + %21 = ttg.convert_layout %20 : tensor<1x1xf32, #blocked3> -> tensor<1x1xf32, #blocked4> + // CHECK: [[RED3:%.*]]:2 = "tt.reduce"([[RES]]#3, [[RES]]#4) + %22:2 = "tt.reduce"(%6#3, %6#4) <{axis = 1 : i32}> ({ + ^bb0(%arg7: f32, %arg8: i32, %arg9: f32, %arg10: i32): + // CHECK-NOT: ttg.convert_layout + // CHECK tt.reduce.return + %49 = arith.cmpf ogt, %arg7, %arg9 : f32 + %50 = arith.cmpf oeq, %arg7, %arg9 : f32 + %51 = arith.cmpf une, %arg7, %arg7 : f32 + %52 = arith.cmpf une, %arg9, %arg9 : f32 + %53 = arith.xori %52, %true : i1 + %54 = arith.andi %51, %53 : i1 + %55 = arith.ori %49, %54 : i1 + %56 = arith.andi %51, %52 : i1 + %57 = arith.ori %50, %56 : i1 + %58 = arith.cmpi slt, %arg8, %arg10 : i32 + %59 = arith.andi %57, %58 : i1 + %60 = arith.ori %55, %59 : i1 + %61 = arith.select %60, %arg7, %arg9 : f32 + %62 = arith.select %60, %arg8, %arg10 : i32 + tt.reduce.return %61, %62 : f32, i32 + }) : (tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked>) -> (tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>>) + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.expand_dims [[RED3]]#1 + %23 = ttg.convert_layout %22#1 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xi32, #blocked1> + %24 = ttg.convert_layout %23 : tensor<1xi32, #blocked1> -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %25 = tt.expand_dims %24 {axis = 1 : i32} 
: tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<1x1xi32, #blocked3> + %26 = ttg.convert_layout %25 : tensor<1x1xi32, #blocked3> -> tensor<1x1xi32, #blocked4> + // CHECK: [[RED4:%.*]]:2 = "tt.reduce"([[RES]]#5, [[RES]]#6) + %27:2 = "tt.reduce"(%6#5, %6#6) <{axis = 1 : i32}> ({ + ^bb0(%arg7: f32, %arg8: i32, %arg9: f32, %arg10: i32): + // CHECK-NOT: ttg.convert_layout + // CHECK tt.reduce.return + %49 = arith.cmpf olt, %arg7, %arg9 : f32 + %50 = arith.cmpf oeq, %arg7, %arg9 : f32 + %51 = arith.cmpf une, %arg7, %arg7 : f32 + %52 = arith.cmpf une, %arg9, %arg9 : f32 + %53 = arith.xori %52, %true : i1 + %54 = arith.andi %51, %53 : i1 + %55 = arith.ori %49, %54 : i1 + %56 = arith.andi %51, %52 : i1 + %57 = arith.ori %50, %56 : i1 + %58 = arith.cmpi slt, %arg8, %arg10 : i32 + %59 = arith.andi %57, %58 : i1 + %60 = arith.ori %55, %59 : i1 + %61 = arith.select %60, %arg7, %arg9 : f32 + %62 = arith.select %60, %arg8, %arg10 : i32 + tt.reduce.return %61, %62 : f32, i32 + }) : (tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked>) -> (tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>>) + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.expand_dims [[RED4]]#1 + %28 = ttg.convert_layout %27#1 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xi32, #blocked1> + %29 = ttg.convert_layout %28 : tensor<1xi32, #blocked1> -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %30 = tt.expand_dims %29 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<1x1xi32, #blocked3> + %31 = ttg.convert_layout %30 : tensor<1x1xi32, #blocked3> -> tensor<1x1xi32, #blocked4> + %32 = tt.splat %arg1 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked4> + %33 = ttg.convert_layout %32 : tensor<1x1x!tt.ptr, #blocked4> -> tensor<1x1x!tt.ptr, #blocked4> + %34 = ttg.convert_layout %11 : tensor<1x1xf32, #blocked4> -> tensor<1x1xf32, #blocked4> + tt.store %33, %34 : 
tensor<1x1x!tt.ptr, #blocked4> + %35 = tt.splat %arg2 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked4> + %36 = ttg.convert_layout %35 : tensor<1x1x!tt.ptr, #blocked4> -> tensor<1x1x!tt.ptr, #blocked4> + %37 = ttg.convert_layout %16 : tensor<1x1xf32, #blocked4> -> tensor<1x1xf32, #blocked4> + tt.store %36, %37 : tensor<1x1x!tt.ptr, #blocked4> + %38 = tt.splat %arg3 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked4> + %39 = ttg.convert_layout %38 : tensor<1x1x!tt.ptr, #blocked4> -> tensor<1x1x!tt.ptr, #blocked4> + %40 = ttg.convert_layout %21 : tensor<1x1xf32, #blocked4> -> tensor<1x1xf32, #blocked4> + tt.store %39, %40 : tensor<1x1x!tt.ptr, #blocked4> + %41 = tt.splat %arg4 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked4> + %42 = arith.extsi %26 : tensor<1x1xi32, #blocked4> to tensor<1x1xi64, #blocked4> + %43 = ttg.convert_layout %41 : tensor<1x1x!tt.ptr, #blocked4> -> tensor<1x1x!tt.ptr, #blocked4> + %44 = ttg.convert_layout %42 : tensor<1x1xi64, #blocked4> -> tensor<1x1xi64, #blocked4> + tt.store %43, %44 : tensor<1x1x!tt.ptr, #blocked4> + %45 = tt.splat %arg5 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked4> + %46 = arith.extsi %31 : tensor<1x1xi32, #blocked4> to tensor<1x1xi64, #blocked4> + %47 = ttg.convert_layout %45 : tensor<1x1x!tt.ptr, #blocked4> -> tensor<1x1x!tt.ptr, #blocked4> + %48 = ttg.convert_layout %46 : tensor<1x1xi64, #blocked4> -> tensor<1x1xi64, #blocked4> + tt.store %47, %48 : tensor<1x1x!tt.ptr, #blocked4> + tt.return + } +} From 09bda3da0982b71114f98d98c147899f64ca30e5 Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Fri, 17 Oct 2025 18:04:26 +0000 Subject: [PATCH 7/7] Reduce diff Signed-off-by: Ettore Tiotto --- .../RemoveLayoutConversions.cpp | 95 ++++--------------- 1 file changed, 21 insertions(+), 74 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index f4467cd17c..bb7a7afb33 100644 --- 
a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -1,13 +1,10 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" -#include "mlir/IR/Location.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" @@ -29,7 +26,6 @@ #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Tools/Sys/GetEnv.hpp" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include @@ -1213,43 +1209,13 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, llvm::errs().indent(2) << "(" << op << "): "; op->dumpPretty(); llvm::errs() << "\n"; - -#if 0 - llvm::errs() << "mapping:\n"; - if (mapping.getValueMap().values().empty()) - llvm::errs().indent(2) << "Values: empty" << "\n"; - else { - llvm::errs().indent(2) << "Values:\n"; - for (auto pair : mapping.getValueMap()) { - auto first = pair.first; - auto second = pair.second; - llvm::errs().indent(4); - first.printAsOperand(llvm::errs(), {}); - llvm::errs() << " -> "; - second.printAsOperand(llvm::errs(), {}); - llvm::errs() << "\n"; - } - } - if (mapping.getOperationMap().values().empty()) - llvm::errs().indent(2) << "Operations: empty" << "\n"; - else { - llvm::errs().indent(2) << "Operations:\n"; - for (auto pair : mapping.getOperationMap()) { - auto first = pair.first; - auto second = pair.second; - llvm::errs().indent(4) << *first << "\n"; - llvm::errs().indent(4) << " -> " << *second << "\n"; - } - } - // assert(mapping.getBlockMap().values().empty()); -#endif }); if (auto forOp = dyn_cast(op)) { 
SmallVector newOperands; if (enableForLoopSupport) { - // Construct the new initialization argument by adding yielded - // operands that have been remapped. + // Construct the new initialization argument by adding yielded operands + // that have been remapped. auto yieldOp = cast(forOp.getBody()->getTerminator()); auto yieldOperands = llvm::to_vector(yieldOp.getOperands()); SmallVector operandsToRewrite = yieldOperandsMap[yieldOp]; @@ -1350,7 +1316,6 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, } continue; } - if (auto ifOp = dyn_cast(op)) { SmallVector newTypes; for (auto res : ifOp.getResults()) { @@ -1388,7 +1353,6 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, deadOps.push_back(ifOp.getOperation()); continue; } - builder.setInsertionPoint(op); if (auto yieldOp = dyn_cast(op)) { auto yieldOperands = llvm::to_vector(yieldOp.getOperands()); @@ -1532,8 +1496,7 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, addRematValue(old, it->second, newV); } } - // Check mapping and see if there are existing convertOps on the old - // Argument + // Check mapping and see if there are existing convertOps on the old Argument convertOp.replaceAllUsesWith(mapping.lookup(convertOp.getSrc())); opToDelete.insert(convertOp); @@ -1626,21 +1589,6 @@ void LayoutRematerialization::backwardRematerialization() { addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), convertOp.getResult()); } - -#if 0 - for (Operation *op : llvm::reverse(opToDelete)) { - if (op == convertOp) - continue; - if (op->getUsers().empty()) - op->erase(); - } - - auto mod = convertOp->getParentOfType(); - llvm::errs() << "mod:\n"; - mod->dumpPretty(); - llvm::errs() << "\n"; - assert(succeeded(verify(mod)) && "Module verification failed"); -#endif } reduceLoopCarriedValues(); @@ -1724,8 +1672,8 @@ void LayoutRematerialization::backwardRematerialization( Value oldV = convertOp.getSrc(); LDBG("check backward remat with source " << oldV << " encoding " << 
targetType.getEncoding()); - // Check to see if there are existing remat'ed values for the pair of - // oldValue and encoding. Make sure it dominates the current conversion. + // Check to see if there are existing remat'ed values for the pair of oldValue + // and encoding. Make sure it dominates the current conversion. Value newV = getRematValue(oldV, targetType.getEncoding()); if (newV && domInfo.properlyDominates(newV, convertOp)) { // Replace it with the remat'ed value. @@ -1837,8 +1785,8 @@ void LayoutRematerialization::backwardRematerialization( auto reduceOp = dyn_cast(op); ReduceOpHelper helper(reduceOp); if (!helper.isAssociative()) { - // We shouldn't rematerize a no associative reduce op if it has - // multiple use chain. + // We shouldn't rematerialize a non-associative reduce op if it has a + // multiple-use chain. LDBG(" skipped rematerialization due to non-associative reduce in " "the " "slice"); @@ -1924,14 +1872,14 @@ void LayoutRematerialization::hoistConvertDotOperand( if (!canBePipelined(convertOp)) return; - // We hoist over any operation that can be done without data movement - // between threads We do views and elementwise pure ops for now + // We hoist over any operation that can be done without data movement between + // threads. We do views and elementwise pure ops for now auto noDataMovement = [](Operation *op) { return (op->hasTrait() && isMemoryEffectFree(op)) || isa(op) || isView(op); }; - // Stop the slice as soon as we find an operation that cannot be done - // without data movement between threads + // Stop the slice as soon as we find an operation that cannot be done without + // data movement between threads auto stop = std::not_fn(noDataMovement); SetVector slice; @@ -1991,8 +1939,8 @@ void LayoutRematerialization::hoistConvertDotOperand( rewriteSlice(innerSlice, layout, convertOp, mapping); } -// For convert left we try to hoist them above type extension to reduce the -// cost of the convert.
+// For convert left we try to hoist them above type extension to reduce the cost +// of the convert. void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( ConvertLayoutOp convertOp) { // DotOperand is hoisted by hoistDotOperand @@ -2094,8 +2042,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( void LayoutRematerialization::hoistConvertIntoConditionals( ConvertLayoutOp convertOp) { // Take the backward slice of tensor dependencies rooted at the conversion, - // stopping at conditionals. This subslice is used to initialize the - // analysis. + // stopping at conditionals. This subslice is used to initialize the analysis. SetVector slice; DenseMap layout; auto isIfOp = [](Operation *op) { return isa(op); }; @@ -2104,17 +2051,17 @@ void LayoutRematerialization::hoistConvertIntoConditionals( layout, isIfOp))) return; - // These are the conditional edges above which conversions should be - // hoisted. The value represents the `scf.if` op result and the operand - // represents the edge into one of the branches. + // These are the conditional edges above which conversions should be hoisted. + // The value represents the `scf.if` op result and the operand represents the + // edge into one of the branches. SmallVector> hoistAbove; - // The list of `scf.if` op results in the slice that are not - // rematerializable. Hoisting is terminated at these values. + // The list of `scf.if` op results in the slice that are not rematerializable. + // Hoisting is terminated at these values. SmallVector terminals; - // This loop recurses through the subslices of the backwards dependencies, - // so re-query the size of `slice`. + // This loop recurses through the subslices of the backwards dependencies, so + // re-query the size of `slice`. for (unsigned i = 0; i != slice.size(); ++i) { Value v = slice[i]; auto ifOp = v.getDefiningOp();