diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h index 9ab7baaa71..c8b0749565 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h @@ -50,7 +50,8 @@ LogicalResult getConvertBackwardSlice( DenseMap &layout, std::function stopPropagation = nullptr, std::function getExistingConversion = - nullptr); + nullptr, + bool includeForOp = false); LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef paramTypes, diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index f734bdf88d..ea0c450159 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -149,19 +149,22 @@ class LayoutRematerialization { getConvertBackwardSlice(OpOperand &root, Attribute rootEncoding, SetVector &slice, DenseMap &layout, - std::function stopPropagation); + std::function stopPropagation, + bool includeForOp = false); LogicalResult getRematerializableSlice( OpOperand &root, Attribute rootEncoding, SetVector &slice, DenseMap &layout, - std::function stopPropagation = nullptr); + std::function stopPropagation = nullptr, + bool includeForOp = false); private: void updateRematMapping(SmallVector> &values); // Existing tuples of (value, layout) that needs to be updated when recreating // scf ops. This prevents keeping track of Values that have been delete when - // rewriting slices. - DenseMap mappedValues; + // rewriting slices. A Value may be mapped to more than one encoding attribute + // while layout conversions are removed, so every encoding is recorded. + DenseMap> mappedValues; // map of the values remat based on encoding. 
DenseMap, Value> rematMapping; // DenseMap, Operation*> @@ -174,7 +177,11 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding, Value newV) { LDBG("addRematValue " << old << " encoding " << encoding << " " << newV); rematMapping[{old, encoding}] = newV; - mappedValues[old] = encoding; + if (mappedValues.contains(old)) { + mappedValues[old].push_back(encoding); + } else { + mappedValues[old] = {encoding}; + } } // Remove unneeded values now that we are done with the rematMapping. @@ -955,22 +962,28 @@ void LayoutRematerialization::updateRematMapping( for (auto [old, newV] : values) { auto it = mappedValues.find(old); if (it != mappedValues.end()) { - Attribute encoding = it->second; - auto rematIt = rematMapping.find({old, it->second}); - assert(rematIt != rematMapping.end()); - Value replacedValue = rematIt->second; - rematMapping.erase(rematIt); - mappedValues.erase(it); - // Loop through the replacement value to find the new version of remat - // value. This should be okay as the number of values should be small. - for (auto [before, after] : values) { - if (before == replacedValue) { - replacedValue = after; - break; + SmallVector encodings = it->second; + for (auto encoding : encodings) { + auto rematIt = rematMapping.find({old, encoding}); + assert(rematIt != rematMapping.end()); + Value replacedValue = rematIt->second; + rematMapping.erase(rematIt); + // Loop through the replacement value to find the new version of remat + // value. This should be okay as the number of values should be small. 
+ for (auto [before, after] : values) { + if (before == replacedValue) { + replacedValue = after; + break; + } } + rematMapping[{newV, encoding}] = replacedValue; + } + mappedValues.erase(it); + if (mappedValues.contains(newV)) { + mappedValues[newV].append(encodings); + } else { + mappedValues[newV] = std::move(encodings); } - rematMapping[{newV, encoding}] = replacedValue; - mappedValues[newV] = encoding; } } } @@ -1045,6 +1058,7 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, deadOps.push_back(forOp.getOperation()); Block &loopBody = *newForOp.getBody(); for (auto m : argMapping) { + mapping.map(newForOp.getResult(m.first), newForOp.getResult(m.second)); mapping.map(forOp.getResult(m.first), newForOp.getResult(m.second)); int numIndVars = newForOp.getNumInductionVars(); mapping.map(loopBody.getArgument(m.first + numIndVars), @@ -1161,8 +1175,12 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv)); } - for (Operation *op : deadOps) - opToDelete.insert(op); + for (Operation *op : deadOps) { + if (!isa(op)) + opToDelete.insert(op); + else + op->erase(); + } } void LayoutRematerialization::rewriteSlice(SetVector &slice, @@ -1175,7 +1193,7 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, LogicalResult LayoutRematerialization::getConvertBackwardSlice( OpOperand &root, Attribute rootEncoding, SetVector &slice, DenseMap &layout, - std::function stopPropagation) { + std::function stopPropagation, bool includeForOp) { // Allow re-using existing conversions for a value. Check dominance of any // reusable materializations against the root value. This is sufficient // because the conversions are processed in post-order. 
@@ -1204,15 +1222,18 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice( }; return ttgi::getConvertBackwardSlice(root, slice, rootEncoding, layout, - stopPropagation, getExistingConversion); + stopPropagation, getExistingConversion, + includeForOp); } LogicalResult LayoutRematerialization::getRematerializableSlice( OpOperand &root, Attribute rootEncoding, SetVector &slice, DenseMap &layout, - std::function stopPropagation) { - LogicalResult result = getConvertBackwardSlice( - root, rootEncoding, slice, layout, std::move(stopPropagation)); + std::function stopPropagation, bool includeForOp) { + + LogicalResult result = + getConvertBackwardSlice(root, rootEncoding, slice, layout, + std::move(stopPropagation), includeForOp); if (result.failed() || slice.empty()) return failure(); @@ -1301,8 +1322,9 @@ void LayoutRematerialization::backwardRematerialization( // rematerialized. SetVector slice; DenseMap layout; - LogicalResult result = getRematerializableSlice( - convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout); + LogicalResult result = getRematerializableSlice(convertOp.getSrcMutable(), + targetType.getEncoding(), + slice, layout, nullptr, true); if (result.failed()) { LDBG(" getRematerializableSlice failed"); return; diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp index 785920770a..d9e263549d 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp @@ -182,7 +182,8 @@ LogicalResult getConvertBackwardSlice( OpOperand &root, SetVector &slice, Attribute rootEncoding, DenseMap &layout, std::function stopPropagation, - std::function getExistingConversion) { + std::function getExistingConversion, + bool includeForOp) { DenseSet> seen; SmallVector> queue; @@ -197,6 +198,12 @@ LogicalResult getConvertBackwardSlice( auto updateLayout = [&](Value value, Attribute encoding) { 
assert(isTensorOrTensorPointerType(value.getType())); + auto tensorType = getRankedTensorType(value.getType()); + auto originEncoding = tensorType.getEncoding(); + if (originEncoding == encoding) { + return success(); + } + slice.insert(value); Attribute &existing = layout[value]; if (existing && existing != encoding) @@ -213,8 +220,38 @@ LogicalResult getConvertBackwardSlice( continue; // Skip propagating through for op results for now. // TODO: enable this based on needs. - if (currentValue.getDefiningOp()) + if (auto forOp = currentValue.getDefiningOp()) { + if (!includeForOp) + return failure(); + if (stopPropagation && stopPropagation(forOp)) + continue; + unsigned argIdx = mlir::cast(currentValue).getResultNumber(); + int numIndVars = forOp.getNumInductionVars(); + Block &loopBody = *forOp.getBody(); + auto blockArg = loopBody.getArgument(argIdx + numIndVars); + + Value existing; + if (getExistingConversion && + (existing = getExistingConversion(*currentValueUse, encoding))) { + if (failed(updateLayout(currentValue, encoding))) + return failure(); + + continue; + } return failure(); + + OpOperand *initOperand = forOp.getTiedLoopInit(blockArg); + OpOperand &yieldOperand = loopBody.getTerminator()->getOpOperand(argIdx); + llvm::outs() << "johnlu getBackward slice check scf.for initOperand: " + << initOperand->get() << "\n"; + llvm::outs() << "johnlu getBackward slice check scf.for yieldOperand: " + << yieldOperand.get() << "\n"; + if (failed(updateLayout(blockArg, encoding))) + return failure(); + enqueue(*initOperand, encoding); + enqueue(yieldOperand, encoding); + continue; + } if (failed(updateLayout(currentValue, encoding))) return failure();