|
1 | | -//===- FuncTransformOps.cpp - Implementation of CF transform ops ---===// |
| 1 | +//===- FuncTransformOps.cpp - Implementation of CF transform ops ----------===// |
2 | 2 | // |
3 | 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | 4 | // See https://llvm.org/LICENSE.txt for license information. |
|
11 | 11 | #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" |
12 | 12 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
13 | 13 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 14 | +#include "mlir/Dialect/Func/Utils/Utils.h" |
14 | 15 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
15 | 16 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
16 | 17 | #include "mlir/Dialect/Transform/IR/TransformOps.h" |
17 | 18 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
| 19 | +#include "mlir/IR/PatternMatch.h" |
18 | 20 | #include "mlir/Transforms/DialectConversion.h" |
19 | 21 |
|
20 | 22 | using namespace mlir; |
@@ -226,6 +228,109 @@ void transform::CastAndCallOp::getEffects( |
226 | 228 | transform::modifiesPayload(effects); |
227 | 229 | } |
228 | 230 |
|
| 231 | +//===----------------------------------------------------------------------===// |
| 232 | +// ReplaceFuncSignatureOp |
| 233 | +//===----------------------------------------------------------------------===// |
| 234 | + |
| 235 | +DiagnosedSilenceableFailure |
| 236 | +transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter, |
| 237 | + transform::TransformResults &results, |
| 238 | + transform::TransformState &state) { |
| 239 | + auto payloadOps = state.getPayloadOps(getModule()); |
| 240 | + if (!llvm::hasSingleElement(payloadOps)) |
| 241 | + return emitDefiniteFailure() << "requires a single module to operate on"; |
| 242 | + |
| 243 | + auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin()); |
| 244 | + if (!targetModuleOp) |
| 245 | + return emitSilenceableFailure(getLoc()) |
| 246 | + << "target is expected to be module operation"; |
| 247 | + |
| 248 | + func::FuncOp funcOp = |
| 249 | + targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName()); |
| 250 | + if (!funcOp) |
| 251 | + return emitSilenceableFailure(getLoc()) |
| 252 | + << "function with name '" << getFunctionName() << "' not found"; |
| 253 | + |
| 254 | + unsigned numArgs = funcOp.getNumArguments(); |
| 255 | + unsigned numResults = funcOp.getNumResults(); |
| 256 | + // Check that the number of arguments and results matches the |
| 257 | + // interchange sizes. |
| 258 | + if (numArgs != getArgsInterchange().size()) |
| 259 | + return emitSilenceableFailure(getLoc()) |
| 260 | + << "function with name '" << getFunctionName() << "' has " << numArgs |
| 261 | + << " arguments, but " << getArgsInterchange().size() |
| 262 | + << " args interchange were given"; |
| 263 | + |
| 264 | + if (numResults != getResultsInterchange().size()) |
| 265 | + return emitSilenceableFailure(getLoc()) |
| 266 | + << "function with name '" << getFunctionName() << "' has " |
| 267 | + << numResults << " results, but " << getResultsInterchange().size() |
| 268 | + << " results interchange were given"; |
| 269 | + |
| 270 | + // Check that the args and results interchanges are unique. |
| 271 | + SetVector<unsigned> argsInterchange, resultsInterchange; |
| 272 | + argsInterchange.insert_range(getArgsInterchange()); |
| 273 | + resultsInterchange.insert_range(getResultsInterchange()); |
| 274 | + if (argsInterchange.size() != getArgsInterchange().size()) |
| 275 | + return emitSilenceableFailure(getLoc()) |
| 276 | + << "args interchange must be unique"; |
| 277 | + |
| 278 | + if (resultsInterchange.size() != getResultsInterchange().size()) |
| 279 | + return emitSilenceableFailure(getLoc()) |
| 280 | + << "results interchange must be unique"; |
| 281 | + |
| 282 | + // Check that the args and results interchange indices are in bounds. |
| 283 | + for (unsigned index : argsInterchange) { |
| 284 | + if (index >= numArgs) { |
| 285 | + return emitSilenceableFailure(getLoc()) |
| 286 | + << "args interchange index " << index |
| 287 | + << " is out of bounds for function with name '" |
| 288 | + << getFunctionName() << "' with " << numArgs << " arguments"; |
| 289 | + } |
| 290 | + } |
| 291 | + for (unsigned index : resultsInterchange) { |
| 292 | + if (index >= numResults) { |
| 293 | + return emitSilenceableFailure(getLoc()) |
| 294 | + << "results interchange index " << index |
| 295 | + << " is out of bounds for function with name '" |
| 296 | + << getFunctionName() << "' with " << numResults << " results"; |
| 297 | + } |
| 298 | + } |
| 299 | + |
| 300 | + FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewOrder( |
| 301 | + rewriter, funcOp, argsInterchange.getArrayRef(), |
| 302 | + resultsInterchange.getArrayRef()); |
| 303 | + if (failed(newFuncOpOrFailure)) |
| 304 | + return emitSilenceableFailure(getLoc()) |
| 305 | + << "failed to replace function signature '" << getFunctionName() |
| 306 | + << "' with new order"; |
| 307 | + |
| 308 | + if (getAdjustFuncCalls()) { |
| 309 | + SmallVector<func::CallOp> callOps; |
| 310 | + targetModuleOp.walk([&](func::CallOp callOp) { |
| 311 | + if (callOp.getCallee() == getFunctionName().getRootReference().getValue()) |
| 312 | + callOps.push_back(callOp); |
| 313 | + }); |
| 314 | + |
| 315 | + for (func::CallOp callOp : callOps) |
| 316 | + func::replaceCallOpWithNewOrder(rewriter, callOp, |
| 317 | + argsInterchange.getArrayRef(), |
| 318 | + resultsInterchange.getArrayRef()); |
| 319 | + } |
| 320 | + |
| 321 | + results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp}); |
| 322 | + results.set(cast<OpResult>(getTransformedFunction()), {*newFuncOpOrFailure}); |
| 323 | + |
| 324 | + return DiagnosedSilenceableFailure::success(); |
| 325 | +} |
| 326 | + |
| 327 | +void transform::ReplaceFuncSignatureOp::getEffects( |
| 328 | + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 329 | + transform::consumesHandle(getModuleMutable(), effects); |
| 330 | + transform::producesHandle(getOperation()->getOpResults(), effects); |
| 331 | + transform::modifiesPayload(effects); |
| 332 | +} |
| 333 | + |
229 | 334 | //===----------------------------------------------------------------------===// |
230 | 335 | // Transform op registration |
231 | 336 | //===----------------------------------------------------------------------===// |
|
0 commit comments