Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,30 @@ def ReplaceFuncSignatureOp
}];
}

def DeduplicateFuncArgsOp
: Op<Transform_Dialect, "func.deduplicate_func_args",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let description = [{
This transform takes a module and a function name, and deduplicates
the arguments of the function. The function is expected to be defined in
the module.

This transform will emit a silenceable failure if:
- The function with the given name does not exist in the module.
- The function does not have duplicate arguments.
- The function does not have a single call.
}];

let arguments = (ins TransformHandleTypeInterface:$module,
SymbolRefAttr:$function_name);
let results = (outs TransformHandleTypeInterface:$transformed_module,
TransformHandleTypeInterface:$transformed_function);

let assemblyFormat = [{
$function_name
`at` $module attr-dict `:` functional-type(operands, results)
}];
}

#endif // FUNC_TRANSFORM_OPS
41 changes: 29 additions & 12 deletions mlir/include/mlir/Dialect/Func/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,49 @@

#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/ArrayRef.h"
#include <string>

namespace mlir {

class ModuleOp;

namespace func {

class FuncOp;
class CallOp;

/// Creates a new function operation with the same name as the original
/// function operation, but with the arguments reordered according to
/// the `newArgsOrder` and `newResultsOrder`.
/// function operation, but with the arguments mapped according to
/// the `oldArgToNewArg` and `oldResToNewRes`.
/// The `funcOp` operation must have exactly one block.
/// Returns the new function operation or failure if `funcOp` doesn't
/// have exactly one block.
FailureOr<FuncOp>
replaceFuncWithNewOrder(RewriterBase &rewriter, FuncOp funcOp,
llvm::ArrayRef<unsigned> newArgsOrder,
llvm::ArrayRef<unsigned> newResultsOrder);
/// Note: the method asserts that the `oldArgToNewArg` and `oldResToNewRes`
/// maps the whole function arguments and results.
mlir::FailureOr<mlir::func::FuncOp> replaceFuncWithNewMapping(
mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
ArrayRef<int> oldArgIdxToNewArgIdx, ArrayRef<int> oldResIdxToNewResIdx);
/// Creates a new call operation with the values as the original
/// call operation, but with the arguments reordered according to
/// the `newArgsOrder` and `newResultsOrder`.
CallOp replaceCallOpWithNewOrder(RewriterBase &rewriter, CallOp callOp,
llvm::ArrayRef<unsigned> newArgsOrder,
llvm::ArrayRef<unsigned> newResultsOrder);
/// call operation, but with the arguments mapped according to
/// the `oldArgToNewArg` and `oldResToNewRes`.
/// Note: the method asserts that the `oldArgToNewArg` and `oldResToNewRes`
/// maps the whole call operation arguments and results.
mlir::func::CallOp replaceCallOpWithNewMapping(
mlir::RewriterBase &rewriter, mlir::func::CallOp callOp,
ArrayRef<int> oldArgIdxToNewArgIdx, ArrayRef<int> oldResIdxToNewResIdx);

/// This utility function examines all call operations within the given
/// `moduleOp` that target the specified `funcOp`. It identifies duplicate
/// operands in the call operations, creates mappings to deduplicate them, and
/// then applies the transformation to both the function and its call sites. For
/// now, it only supports one call operation for the function operation. The
/// function returns a pair containing the new funcOp and the new callOp. Note:
/// after the transformation, the original funcOp and callOp will be erased.
mlir::FailureOr<std::pair<mlir::func::FuncOp, mlir::func::CallOp>>
deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
mlir::ModuleOp moduleOp);

} // namespace func
} // namespace mlir

#endif // MLIR_DIALECT_FUNC_UTILS_H
#endif // MLIR_DIALECT_FUNC_UTILS_H
64 changes: 58 additions & 6 deletions mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

Expand Down Expand Up @@ -296,9 +297,16 @@ transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
}
}

FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewOrder(
rewriter, funcOp, argsInterchange.getArrayRef(),
resultsInterchange.getArrayRef());
llvm::SmallVector<int> oldArgToNewArg(argsInterchange.size());
for (auto [newArgIdx, oldArgIdx] : llvm::enumerate(argsInterchange))
oldArgToNewArg[oldArgIdx] = newArgIdx;

llvm::SmallVector<int> oldResToNewRes(resultsInterchange.size());
for (auto [newResIdx, oldResIdx] : llvm::enumerate(resultsInterchange))
oldResToNewRes[oldResIdx] = newResIdx;

FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewMapping(
rewriter, funcOp, oldArgToNewArg, oldResToNewRes);
if (failed(newFuncOpOrFailure))
return emitSilenceableFailure(getLoc())
<< "failed to replace function signature '" << getFunctionName()
Expand All @@ -312,9 +320,8 @@ transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
});

for (func::CallOp callOp : callOps)
func::replaceCallOpWithNewOrder(rewriter, callOp,
argsInterchange.getArrayRef(),
resultsInterchange.getArrayRef());
func::replaceCallOpWithNewMapping(rewriter, callOp, oldArgToNewArg,
oldResToNewRes);
}

results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
Expand All @@ -330,6 +337,51 @@ void transform::ReplaceFuncSignatureOp::getEffects(
transform::modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// DeduplicateFuncArgsOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::DeduplicateFuncArgsOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
auto payloadOps = state.getPayloadOps(getModule());
if (!llvm::hasSingleElement(payloadOps))
return emitDefiniteFailure() << "requires a single module to operate on";

auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
if (!targetModuleOp)
return emitSilenceableFailure(getLoc())
<< "target is expected to be module operation";

func::FuncOp funcOp =
targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
if (!funcOp)
return emitSilenceableFailure(getLoc())
<< "function with name '" << getFunctionName() << "' is not found";

auto transformationResult =
func::deduplicateArgsOfFuncOp(rewriter, funcOp, targetModuleOp);
if (failed(transformationResult))
return emitSilenceableFailure(getLoc())
<< "failed to deduplicate function arguments of function "
<< funcOp.getName();

auto [newFuncOp, newCallOp] = *transformationResult;

results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
results.set(cast<OpResult>(getTransformedFunction()), {newFuncOp});

return DiagnosedSilenceableFailure::success();
}

void transform::DeduplicateFuncArgsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::consumesHandle(getModuleMutable(), effects);
transform::producesHandle(getOperation()->getOpResults(), effects);
transform::modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
Expand Down
Loading
Loading