Skip to content

Commit d92c612

Browse files
committed
Remove duplication of code
1 parent 91f1242 commit d92c612

File tree

1 file changed

+27
-50
lines changed

1 file changed

+27
-50
lines changed

flang/lib/Optimizer/CodeGen/TargetRewrite.cpp

Lines changed: 27 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -726,14 +726,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
726726
if (targetFeaturesAttr)
727727
fn->setAttr("target_features", targetFeaturesAttr);
728728

729-
convertSignature(fn);
729+
convertSignature<mlir::func::ReturnOp>(fn);
730730
}
731731

732732
for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>()) {
733733
for (auto fn : gpuMod.getOps<mlir::func::FuncOp>())
734-
convertSignature(fn);
734+
convertSignature<mlir::func::ReturnOp>(fn);
735735
for (auto fn : gpuMod.getOps<mlir::gpu::GPUFuncOp>())
736-
convertSignature(fn);
736+
convertSignature<mlir::gpu::ReturnOp>(fn);
737737
}
738738

739739
return mlir::success();
@@ -792,8 +792,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
792792

793793
/// Rewrite the signatures and body of the `FuncOp`s in the module for
794794
/// the immediately subsequent target code gen.
795-
template <typename OpTy>
796-
void convertSignature(OpTy func) {
795+
template <typename ReturnOpTy, typename FuncOpTy>
796+
void convertSignature(FuncOpTy func) {
797797
auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
798798
if (hasPortableSignature(funcTy, func) && !hasHostAssociations(func))
799799
return;
@@ -900,7 +900,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
900900
if (!extensionAttrName.empty() &&
901901
isFuncWithCCallingConvention(func))
902902
fixups.emplace_back(FixupTy::Codes::ArgumentType, argNo,
903-
[=](OpTy func) {
903+
[=](FuncOpTy func) {
904904
func.setArgAttr(
905905
argNo, extensionAttrName,
906906
mlir::UnitAttr::get(func.getContext()));
@@ -992,52 +992,29 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
992992
auto newArg =
993993
func.front().insertArgument(fixup.index, fixupType, loc);
994994
offset++;
995-
if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
996-
func.walk([&](mlir::func::ReturnOp ret) {
997-
rewriter->setInsertionPoint(ret);
998-
auto oldOper = ret.getOperand(0);
999-
auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
1000-
auto cast =
1001-
rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
1002-
rewriter->create<fir::StoreOp>(loc, oldOper, cast);
1003-
rewriter->create<mlir::func::ReturnOp>(loc);
1004-
ret.erase();
1005-
});
1006-
if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
1007-
func.walk([&](mlir::gpu::ReturnOp ret) {
1008-
rewriter->setInsertionPoint(ret);
1009-
auto oldOper = ret.getOperand(0);
1010-
auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
1011-
auto cast =
1012-
rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
1013-
rewriter->create<fir::StoreOp>(loc, oldOper, cast);
1014-
rewriter->create<mlir::gpu::ReturnOp>(loc);
1015-
ret.erase();
1016-
});
995+
func.walk([&](ReturnOpTy ret) {
996+
rewriter->setInsertionPoint(ret);
997+
auto oldOper = ret.getOperand(0);
998+
auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
999+
auto cast =
1000+
rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
1001+
rewriter->create<fir::StoreOp>(loc, oldOper, cast);
1002+
rewriter->create<ReturnOpTy>(loc);
1003+
ret.erase();
1004+
});
10171005
} break;
10181006
case FixupTy::Codes::ReturnType: {
10191007
// The function is still returning a value, but its type has likely
10201008
// changed to suit the target ABI convention.
1021-
if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
1022-
func.walk([&](mlir::func::ReturnOp ret) {
1023-
rewriter->setInsertionPoint(ret);
1024-
auto oldOper = ret.getOperand(0);
1025-
mlir::Value bitcast =
1026-
convertValueInMemory(loc, oldOper, newResTys[fixup.index],
1027-
/*inputMayBeBigger=*/false);
1028-
rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
1029-
ret.erase();
1030-
});
1031-
if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
1032-
func.walk([&](mlir::gpu::ReturnOp ret) {
1033-
rewriter->setInsertionPoint(ret);
1034-
auto oldOper = ret.getOperand(0);
1035-
mlir::Value bitcast =
1036-
convertValueInMemory(loc, oldOper, newResTys[fixup.index],
1037-
/*inputMayBeBigger=*/false);
1038-
rewriter->create<mlir::gpu::ReturnOp>(loc, bitcast);
1039-
ret.erase();
1040-
});
1009+
func.walk([&](ReturnOpTy ret) {
1010+
rewriter->setInsertionPoint(ret);
1011+
auto oldOper = ret.getOperand(0);
1012+
mlir::Value bitcast =
1013+
convertValueInMemory(loc, oldOper, newResTys[fixup.index],
1014+
/*inputMayBeBigger=*/false);
1015+
rewriter->create<ReturnOpTy>(loc, bitcast);
1016+
ret.erase();
1017+
});
10411018
} break;
10421019
case FixupTy::Codes::Split: {
10431020
// The FIR argument has been split into a pair of distinct arguments
@@ -1138,10 +1115,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11381115
}
11391116

11401117
for (auto &fixup : fixups) {
1141-
if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
1118+
if constexpr (std::is_same_v<FuncOpTy, mlir::func::FuncOp>)
11421119
if (fixup.finalizer)
11431120
(*fixup.finalizer)(func);
1144-
if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
1121+
if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>)
11451122
if (fixup.gpuFinalizer)
11461123
(*fixup.gpuFinalizer)(func);
11471124
}

0 commit comments

Comments
 (0)