#define DEBUG_TYPE "detensorize-func-boundary"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "Catalyst/IR/CatalystDialect.h"

using namespace llvm;
using namespace mlir;
using namespace catalyst;

namespace {
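// Return true if `type` is a rank-0 ranked tensor, i.e. a "scalar tensor".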
bool isScalarTensor(Type type)
{
    if (auto rankedType = dyn_cast<RankedTensorType>(type)) {
        return rankedType.getRank() == 0;
    }
    return false;
}

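// For a scalar tensor, return its element type; otherwise return the type unchanged.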
Type getScalarOrOriginalType(Type type)
{
    if (isScalarTensor(type)) {
        // The cast cannot fail: isScalarTensor implies a RankedTensorType.
        return cast<RankedTensorType>(type).getElementType();
    }
    return type;
}

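// Return true if any input or result in the function signature is a scalar tensor.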
bool hasScalarTensorSignature(func::FuncOp funcOp)
{
    for (Type type : funcOp.getFunctionType().getInputs()) {
        if (isScalarTensor(type)) {
            return true;
        }
    }
    for (Type type : funcOp.getFunctionType().getResults()) {
        if (isScalarTensor(type)) {
            return true;
        }
    }
    return false;
}

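// Rewrite calls to functions whose signature contains scalar tensors: a clone of
// the callee named "<callee>.detensorized" with a scalarized signature is created
// on demand, and the call site is bridged with tensor.extract /
// tensor.from_elements ops. Illustrative sketch of the rewrite (types abridged):
//
//   %r = call @f(%t) : (tensor<f64>) -> tensor<f64>
//
// becomes
//
//   %s = tensor.extract %t[] : tensor<f64>
//   %c = call @f.detensorized(%s) : (f64) -> f64
//   %r = tensor.from_elements %c : tensor<f64>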
struct DetensorizeCallSitePattern : public OpRewritePattern<func::CallOp> {
    using OpRewritePattern<func::CallOp>::OpRewritePattern;

    LogicalResult matchAndRewrite(func::CallOp callOp, PatternRewriter &rewriter) const override
    {
        auto funcOp =
            SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(callOp, callOp.getCalleeAttr());

        // Skip the main (entry-point) function, identified by its C interface attribute.
        if (!funcOp || funcOp->hasAttr("llvm.emit_c_interface")) {
            return failure();
        }

        if (!hasScalarTensorSignature(funcOp)) {
            return failure();
        }

        // Skip QNodes: some gradient boundaries only work with tensor signatures,
        // not scalar ones.
        if (funcOp->hasAttr("qnode")) {
            return failure();
        }

        // Create the detensorized FuncOp if it does not already exist.
        auto module = callOp->getParentOfType<ModuleOp>();
        std::string newFuncName = funcOp.getName().str() + ".detensorized";
        auto newFuncOp = module.lookupSymbol<func::FuncOp>(newFuncName);

        if (!newFuncOp) {
            OpBuilder::InsertionGuard guard(rewriter);
            rewriter.setInsertionPointToEnd(module.getBody());

            // Build the detensorized signature from the original function type.
            FunctionType funcType = funcOp.getFunctionType();
            SmallVector<Type> newArgTypes, newResultTypes;
            SmallVector<NamedAttribute> newAttrs;
            extractDetensorizedOpSignature(funcType, funcOp, newArgTypes, newResultTypes, newAttrs);

            // Create the new function, passing the collected signature.
            auto newFuncType = FunctionType::get(getContext(), newArgTypes, newResultTypes);
            newFuncOp =
                rewriter.create<func::FuncOp>(funcOp.getLoc(), newFuncName, newFuncType, newAttrs);

            // Clone the function body and rebuild the return operation.
            Block *newEntryBlock = newFuncOp.addEntryBlock();
            IRMapping mapper;
            mapFuncOpBodyAndReturnOp(rewriter, newEntryBlock, funcOp, mapper);
        }

        // Rewrite the original call site to use the new detensorized function.
        replaceCallOp(rewriter, callOp, newFuncOp);
        return success();
    }

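    // Compute scalarized argument and result types for the new function, and copy
    // every attribute from the original except its symbol name and function type
    // (both are set separately when the new FuncOp is created).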
    void extractDetensorizedOpSignature(FunctionType funcType, func::FuncOp funcOp,
                                        SmallVector<Type> &newArgTypes,
                                        SmallVector<Type> &newResultTypes,
                                        SmallVector<NamedAttribute> &newAttrs) const
    {
        for (Type type : funcType.getInputs()) {
            newArgTypes.push_back(getScalarOrOriginalType(type));
        }
        for (Type type : funcType.getResults()) {
            newResultTypes.push_back(getScalarOrOriginalType(type));
        }

        // Copy all attributes from the original function, except the symbol name
        // and function type, which must not be overwritten.
        for (const NamedAttribute &attr : funcOp->getAttrs()) {
            if (attr.getName() == funcOp.getSymNameAttrName() ||
                attr.getName() == funcOp.getFunctionTypeAttrName()) {
                continue;
            }
            newAttrs.push_back(attr);
        }
    }

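    // Populate the new entry block: scalar block arguments are re-wrapped with
    // tensor.from_elements so the cloned body still sees tensor-typed values, and
    // scalar return operands are unwrapped with tensor.extract. Note that only
    // the entry block is cloned, so this assumes a single-block function body.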
    void mapFuncOpBodyAndReturnOp(PatternRewriter &rewriter, Block *newEntryBlock,
                                  func::FuncOp funcOp, IRMapping &mapper) const
    {
        rewriter.setInsertionPointToStart(newEntryBlock);
        for (const auto &it : llvm::enumerate(funcOp.getArguments())) {
            Value oldArg = it.value();
            Value newArg = newEntryBlock->getArgument(it.index());

            if (isScalarTensor(oldArg.getType())) {
                // Re-tensorize the scalar argument so uses inside the body keep
                // their original tensor type.
                auto fromElementsOp = rewriter.create<tensor::FromElementsOp>(
                    newArg.getLoc(), oldArg.getType(), newArg);
                mapper.map(oldArg, fromElementsOp.getResult());
            }
            else {
                mapper.map(oldArg, newArg);
            }
        }

        // Clone the operations from the old function's body, excluding the terminator.
        rewriter.setInsertionPointToEnd(newEntryBlock);
        for (Operation &op : funcOp.front().without_terminator()) {
            rewriter.clone(op, mapper);
        }

        // Create a new return operation with the mapped results.
        auto oldReturnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
        SmallVector<Value> newReturnOperands;
        newReturnOperands.reserve(oldReturnOp.getNumOperands());
        for (Value operand : oldReturnOp.getOperands()) {
            Value newOperand = mapper.lookup(operand);
            if (isScalarTensor(newOperand.getType())) {
                // Extract the scalar so the returned type matches the new signature.
                auto extractOp = rewriter.create<tensor::ExtractOp>(oldReturnOp.getLoc(),
                                                                    newOperand, ValueRange{});
                newReturnOperands.push_back(extractOp.getResult());
            }
            else {
                newReturnOperands.push_back(newOperand);
            }
        }
        rewriter.create<func::ReturnOp>(oldReturnOp.getLoc(), newReturnOperands);
    }

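    // Rewrite the call site itself: unwrap scalar-tensor operands with
    // tensor.extract, call the detensorized function, and re-wrap scalar results
    // with tensor.from_elements so the surrounding IR is left unchanged.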
    void replaceCallOp(PatternRewriter &rewriter, func::CallOp callOp,
                       func::FuncOp newFuncOp) const
    {
        rewriter.setInsertionPoint(callOp);
        SmallVector<Value> newOperands;
        for (Value operand : callOp.getOperands()) {
            if (isScalarTensor(operand.getType())) {
                // Insert an ExtractOp to bridge a scalar-tensor operand into the
                // detensorized function.
                auto extractOp =
                    rewriter.create<tensor::ExtractOp>(callOp.getLoc(), operand, ValueRange{});
                newOperands.push_back(extractOp.getResult());
            }
            else {
                newOperands.push_back(operand);
            }
        }

        auto newCallOp = rewriter.create<func::CallOp>(callOp.getLoc(), newFuncOp, newOperands);

        SmallVector<Value> newResults;
        for (size_t i = 0; i < callOp.getNumResults(); ++i) {
            Value oldResult = callOp.getResult(i);
            Value newResult = newCallOp.getResult(i);
            if (isScalarTensor(oldResult.getType())) {
                // Insert a FromElementsOp to bridge a scalar result back to the
                // tensor type the old call produced.
                auto fromElementsOp = rewriter.create<tensor::FromElementsOp>(
                    callOp.getLoc(), oldResult.getType(), newResult);
                newResults.push_back(fromElementsOp.getResult());
            }
            else {
                newResults.push_back(newResult);
            }
        }

        rewriter.replaceOp(callOp, newResults);
    }
};
} // namespace

namespace catalyst {
#define GEN_PASS_DEF_DETENSORIZEFUNCTIONBOUNDARYPASS
#include "Catalyst/Transforms/Passes.h.inc"

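// Pass shell: greedily applies the call-site pattern across the whole module.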
struct DetensorizeFunctionBoundaryPass
    : public impl::DetensorizeFunctionBoundaryPassBase<DetensorizeFunctionBoundaryPass> {
    using impl::DetensorizeFunctionBoundaryPassBase<
        DetensorizeFunctionBoundaryPass>::DetensorizeFunctionBoundaryPassBase;
    void runOnOperation() override
    {
        MLIRContext *context = &getContext();
        RewritePatternSet patterns(context);

        patterns.add<DetensorizeCallSitePattern>(context);

        GreedyRewriteConfig config;
        if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), config))) {
            signalPassFailure();
        }
    }
};

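// Factory function for creating the pass programmatically (e.g. from pass
// registration code).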
std::unique_ptr<Pass> createDetensorizeFunctionBoundaryPass()
{
    return std::make_unique<DetensorizeFunctionBoundaryPass>();
}

} // namespace catalyst