Skip to content

Commit 91f1242

Browse files
committed
[flang][cuda] Update target rewrite to work on gpu.func
1 parent f9e1150 commit 91f1242

File tree

2 files changed

+112
-65
lines changed

2 files changed

+112
-65
lines changed

flang/lib/Optimizer/CodeGen/TargetRewrite.cpp

Lines changed: 99 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,21 @@ struct FixupTy {
6262
FixupTy(Codes code, std::size_t index,
6363
std::function<void(mlir::func::FuncOp)> &&finalizer)
6464
: code{code}, index{index}, finalizer{finalizer} {}
65+
FixupTy(Codes code, std::size_t index,
66+
std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer)
67+
: code{code}, index{index}, gpuFinalizer{finalizer} {}
6568
FixupTy(Codes code, std::size_t index, std::size_t second,
6669
std::function<void(mlir::func::FuncOp)> &&finalizer)
6770
: code{code}, index{index}, second{second}, finalizer{finalizer} {}
71+
FixupTy(Codes code, std::size_t index, std::size_t second,
72+
std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer)
73+
: code{code}, index{index}, second{second}, gpuFinalizer{finalizer} {}
6874

6975
Codes code;
7076
std::size_t index;
7177
std::size_t second{};
7278
std::optional<std::function<void(mlir::func::FuncOp)>> finalizer{};
79+
std::optional<std::function<void(mlir::gpu::GPUFuncOp)>> gpuFinalizer{};
7380
}; // namespace
7481

7582
/// Target-specific rewriting of the FIR. This is a prerequisite pass to code
@@ -722,9 +729,12 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
722729
convertSignature(fn);
723730
}
724731

725-
for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>())
732+
for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>()) {
726733
for (auto fn : gpuMod.getOps<mlir::func::FuncOp>())
727734
convertSignature(fn);
735+
for (auto fn : gpuMod.getOps<mlir::gpu::GPUFuncOp>())
736+
convertSignature(fn);
737+
}
728738

729739
return mlir::success();
730740
}
@@ -770,17 +780,20 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
770780

771781
/// Determine if the signature has host associations. The host association
772782
/// argument may need special target specific rewriting.
773-
static bool hasHostAssociations(mlir::func::FuncOp func) {
783+
template <typename OpTy>
784+
static bool hasHostAssociations(OpTy func) {
774785
std::size_t end = func.getFunctionType().getInputs().size();
775786
for (std::size_t i = 0; i < end; ++i)
776-
if (func.getArgAttrOfType<mlir::UnitAttr>(i, fir::getHostAssocAttrName()))
787+
if (func.template getArgAttrOfType<mlir::UnitAttr>(
788+
i, fir::getHostAssocAttrName()))
777789
return true;
778790
return false;
779791
}
780792

781793
/// Rewrite the signatures and body of the `FuncOp`s in the module for
782794
/// the immediately subsequent target code gen.
783-
void convertSignature(mlir::func::FuncOp func) {
795+
template <typename OpTy>
796+
void convertSignature(OpTy func) {
784797
auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
785798
if (hasPortableSignature(funcTy, func) && !hasHostAssociations(func))
786799
return;
@@ -805,13 +818,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
805818
// Convert return value(s)
806819
for (auto ty : funcTy.getResults())
807820
llvm::TypeSwitch<mlir::Type>(ty)
808-
.Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
821+
.template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
809822
if (noComplexConversion)
810823
newResTys.push_back(cmplx);
811824
else
812825
doComplexReturn(func, cmplx, newResTys, newInTyAndAttrs, fixups);
813826
})
814-
.Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
827+
.template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
815828
auto m = specifics->integerArgumentType(func.getLoc(), intTy);
816829
assert(m.size() == 1);
817830
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
@@ -825,10 +838,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
825838
rewriter->getUnitAttr()));
826839
newResTys.push_back(retTy);
827840
})
828-
.Case<fir::RecordType>([&](fir::RecordType recTy) {
841+
.template Case<fir::RecordType>([&](fir::RecordType recTy) {
829842
doStructReturn(func, recTy, newResTys, newInTyAndAttrs, fixups);
830843
})
831-
.Default([&](mlir::Type ty) { newResTys.push_back(ty); });
844+
.template Default([&](mlir::Type ty) { newResTys.push_back(ty); });
832845

833846
// Saved potential shift in argument. Handling of result can add arguments
834847
// at the beginning of the function signature.
@@ -840,7 +853,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
840853
auto ty = e.value();
841854
unsigned index = e.index();
842855
llvm::TypeSwitch<mlir::Type>(ty)
843-
.Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
856+
.template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
844857
if (noCharacterConversion) {
845858
newInTyAndAttrs.push_back(
846859
fir::CodeGenSpecifics::getTypeAndAttr(boxTy));
@@ -863,10 +876,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
863876
}
864877
}
865878
})
866-
.Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
879+
.template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
867880
doComplexArg(func, cmplx, newInTyAndAttrs, fixups);
868881
})
869-
.Case<mlir::TupleType>([&](mlir::TupleType tuple) {
882+
.template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
870883
if (fir::isCharacterProcedureTuple(tuple)) {
871884
fixups.emplace_back(FixupTy::Codes::TrailingCharProc,
872885
newInTyAndAttrs.size(), trailingTys.size());
@@ -878,7 +891,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
878891
fir::CodeGenSpecifics::getTypeAndAttr(ty));
879892
}
880893
})
881-
.Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
894+
.template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
882895
auto m = specifics->integerArgumentType(func.getLoc(), intTy);
883896
assert(m.size() == 1);
884897
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
@@ -887,7 +900,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
887900
if (!extensionAttrName.empty() &&
888901
isFuncWithCCallingConvention(func))
889902
fixups.emplace_back(FixupTy::Codes::ArgumentType, argNo,
890-
[=](mlir::func::FuncOp func) {
903+
[=](OpTy func) {
891904
func.setArgAttr(
892905
argNo, extensionAttrName,
893906
mlir::UnitAttr::get(func.getContext()));
@@ -898,13 +911,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
898911
.template Case<fir::RecordType>([&](fir::RecordType recTy) {
899912
doStructArg(func, recTy, newInTyAndAttrs, fixups);
900913
})
901-
.Default([&](mlir::Type ty) {
914+
.template Default([&](mlir::Type ty) {
902915
newInTyAndAttrs.push_back(
903916
fir::CodeGenSpecifics::getTypeAndAttr(ty));
904917
});
905918

906-
if (func.getArgAttrOfType<mlir::UnitAttr>(index,
907-
fir::getHostAssocAttrName())) {
919+
if (func.template getArgAttrOfType<mlir::UnitAttr>(
920+
index, fir::getHostAssocAttrName())) {
908921
extraAttrs.push_back(
909922
{newInTyAndAttrs.size() - 1,
910923
rewriter->getNamedAttr("llvm.nest", rewriter->getUnitAttr())});
@@ -979,29 +992,52 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
979992
auto newArg =
980993
func.front().insertArgument(fixup.index, fixupType, loc);
981994
offset++;
982-
func.walk([&](mlir::func::ReturnOp ret) {
983-
rewriter->setInsertionPoint(ret);
984-
auto oldOper = ret.getOperand(0);
985-
auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
986-
auto cast =
987-
rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
988-
rewriter->create<fir::StoreOp>(loc, oldOper, cast);
989-
rewriter->create<mlir::func::ReturnOp>(loc);
990-
ret.erase();
991-
});
995+
if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
996+
func.walk([&](mlir::func::ReturnOp ret) {
997+
rewriter->setInsertionPoint(ret);
998+
auto oldOper = ret.getOperand(0);
999+
auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
1000+
auto cast =
1001+
rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
1002+
rewriter->create<fir::StoreOp>(loc, oldOper, cast);
1003+
rewriter->create<mlir::func::ReturnOp>(loc);
1004+
ret.erase();
1005+
});
1006+
if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
1007+
func.walk([&](mlir::gpu::ReturnOp ret) {
1008+
rewriter->setInsertionPoint(ret);
1009+
auto oldOper = ret.getOperand(0);
1010+
auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
1011+
auto cast =
1012+
rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
1013+
rewriter->create<fir::StoreOp>(loc, oldOper, cast);
1014+
rewriter->create<mlir::gpu::ReturnOp>(loc);
1015+
ret.erase();
1016+
});
9921017
} break;
9931018
case FixupTy::Codes::ReturnType: {
9941019
// The function is still returning a value, but its type has likely
9951020
// changed to suit the target ABI convention.
996-
func.walk([&](mlir::func::ReturnOp ret) {
997-
rewriter->setInsertionPoint(ret);
998-
auto oldOper = ret.getOperand(0);
999-
mlir::Value bitcast =
1000-
convertValueInMemory(loc, oldOper, newResTys[fixup.index],
1001-
/*inputMayBeBigger=*/false);
1002-
rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
1003-
ret.erase();
1004-
});
1021+
if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
1022+
func.walk([&](mlir::func::ReturnOp ret) {
1023+
rewriter->setInsertionPoint(ret);
1024+
auto oldOper = ret.getOperand(0);
1025+
mlir::Value bitcast =
1026+
convertValueInMemory(loc, oldOper, newResTys[fixup.index],
1027+
/*inputMayBeBigger=*/false);
1028+
rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
1029+
ret.erase();
1030+
});
1031+
if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
1032+
func.walk([&](mlir::gpu::ReturnOp ret) {
1033+
rewriter->setInsertionPoint(ret);
1034+
auto oldOper = ret.getOperand(0);
1035+
mlir::Value bitcast =
1036+
convertValueInMemory(loc, oldOper, newResTys[fixup.index],
1037+
/*inputMayBeBigger=*/false);
1038+
rewriter->create<mlir::gpu::ReturnOp>(loc, bitcast);
1039+
ret.erase();
1040+
});
10051041
} break;
10061042
case FixupTy::Codes::Split: {
10071043
// The FIR argument has been split into a pair of distinct arguments
@@ -1101,13 +1137,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11011137
}
11021138
}
11031139

1104-
for (auto &fixup : fixups)
1105-
if (fixup.finalizer)
1106-
(*fixup.finalizer)(func);
1140+
for (auto &fixup : fixups) {
1141+
if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
1142+
if (fixup.finalizer)
1143+
(*fixup.finalizer)(func);
1144+
if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
1145+
if (fixup.gpuFinalizer)
1146+
(*fixup.gpuFinalizer)(func);
1147+
}
11071148
}
11081149

1109-
template <typename Ty, typename FIXUPS>
1110-
void doReturn(mlir::func::FuncOp func, Ty &newResTys,
1150+
template <typename OpTy, typename Ty, typename FIXUPS>
1151+
void doReturn(OpTy func, Ty &newResTys,
11111152
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
11121153
FIXUPS &fixups, fir::CodeGenSpecifics::Marshalling &m) {
11131154
assert(m.size() == 1 &&
@@ -1119,7 +1160,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11191160
unsigned argNo = newInTyAndAttrs.size();
11201161
if (auto align = attr.getAlignment())
11211162
fixups.emplace_back(
1122-
FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) {
1163+
FixupTy::Codes::ReturnAsStore, argNo, [=](OpTy func) {
11231164
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
11241165
func.getFunctionType().getInput(argNo));
11251166
func.setArgAttr(argNo, "llvm.sret",
@@ -1130,7 +1171,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11301171
});
11311172
else
11321173
fixups.emplace_back(FixupTy::Codes::ReturnAsStore, argNo,
1133-
[=](mlir::func::FuncOp func) {
1174+
[=](OpTy func) {
11341175
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
11351176
func.getFunctionType().getInput(argNo));
11361177
func.setArgAttr(argNo, "llvm.sret",
@@ -1141,8 +1182,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11411182
}
11421183
if (auto align = attr.getAlignment())
11431184
fixups.emplace_back(
1144-
FixupTy::Codes::ReturnType, newResTys.size(),
1145-
[=](mlir::func::FuncOp func) {
1185+
FixupTy::Codes::ReturnType, newResTys.size(), [=](OpTy func) {
11461186
func.setArgAttr(
11471187
newResTys.size(), "llvm.align",
11481188
rewriter->getIntegerAttr(rewriter->getIntegerType(32), align));
@@ -1155,9 +1195,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11551195
/// Convert a complex return value. This can involve converting the return
11561196
/// value to a "hidden" first argument or packing the complex into a wide
11571197
/// GPR.
1158-
template <typename Ty, typename FIXUPS>
1159-
void doComplexReturn(mlir::func::FuncOp func, mlir::ComplexType cmplx,
1160-
Ty &newResTys,
1198+
template <typename OpTy, typename Ty, typename FIXUPS>
1199+
void doComplexReturn(OpTy func, mlir::ComplexType cmplx, Ty &newResTys,
11611200
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
11621201
FIXUPS &fixups) {
11631202
if (noComplexConversion) {
@@ -1169,9 +1208,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11691208
doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
11701209
}
11711210

1172-
template <typename Ty, typename FIXUPS>
1173-
void doStructReturn(mlir::func::FuncOp func, fir::RecordType recTy,
1174-
Ty &newResTys,
1211+
template <typename OpTy, typename Ty, typename FIXUPS>
1212+
void doStructReturn(OpTy func, fir::RecordType recTy, Ty &newResTys,
11751213
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
11761214
FIXUPS &fixups) {
11771215
if (noStructConversion) {
@@ -1182,12 +1220,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11821220
doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
11831221
}
11841222

1185-
template <typename FIXUPS>
1186-
void
1187-
createFuncOpArgFixups(mlir::func::FuncOp func,
1188-
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1189-
fir::CodeGenSpecifics::Marshalling &argsInTys,
1190-
FIXUPS &fixups) {
1223+
template <typename OpTy, typename FIXUPS>
1224+
void createFuncOpArgFixups(
1225+
OpTy func, fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1226+
fir::CodeGenSpecifics::Marshalling &argsInTys, FIXUPS &fixups) {
11911227
const auto fixupCode = argsInTys.size() > 1 ? FixupTy::Codes::Split
11921228
: FixupTy::Codes::ArgumentType;
11931229
for (auto e : llvm::enumerate(argsInTys)) {
@@ -1198,7 +1234,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11981234
if (attr.isByVal()) {
11991235
if (auto align = attr.getAlignment())
12001236
fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, argNo,
1201-
[=](mlir::func::FuncOp func) {
1237+
[=](OpTy func) {
12021238
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
12031239
func.getFunctionType().getInput(argNo));
12041240
func.setArgAttr(argNo, "llvm.byval",
@@ -1210,8 +1246,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
12101246
});
12111247
else
12121248
fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad,
1213-
newInTyAndAttrs.size(),
1214-
[=](mlir::func::FuncOp func) {
1249+
newInTyAndAttrs.size(), [=](OpTy func) {
12151250
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
12161251
func.getFunctionType().getInput(argNo));
12171252
func.setArgAttr(argNo, "llvm.byval",
@@ -1220,7 +1255,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
12201255
} else {
12211256
if (auto align = attr.getAlignment())
12221257
fixups.emplace_back(
1223-
fixupCode, argNo, index, [=](mlir::func::FuncOp func) {
1258+
fixupCode, argNo, index, [=](OpTy func) {
12241259
func.setArgAttr(argNo, "llvm.align",
12251260
rewriter->getIntegerAttr(
12261261
rewriter->getIntegerType(32), align));
@@ -1235,8 +1270,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
12351270
/// Convert a complex argument value. This can involve storing the value to
12361271
/// a temporary memory location or factoring the value into two distinct
12371272
/// arguments.
1238-
template <typename FIXUPS>
1239-
void doComplexArg(mlir::func::FuncOp func, mlir::ComplexType cmplx,
1273+
template <typename OpTy, typename FIXUPS>
1274+
void doComplexArg(OpTy func, mlir::ComplexType cmplx,
12401275
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
12411276
FIXUPS &fixups) {
12421277
if (noComplexConversion) {
@@ -1248,8 +1283,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
12481283
createFuncOpArgFixups(func, newInTyAndAttrs, cplxArgs, fixups);
12491284
}
12501285

1251-
template <typename FIXUPS>
1252-
void doStructArg(mlir::func::FuncOp func, fir::RecordType recTy,
1286+
template <typename OpTy, typename FIXUPS>
1287+
void doStructArg(OpTy func, fir::RecordType recTy,
12531288
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
12541289
FIXUPS &fixups) {
12551290
if (noStructConversion) {

flang/test/Fir/CUDA/cuda-target-rewrite.mlir

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// REQUIRES: x86-registered-target
2-
// RUN: fir-opt --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
2+
// RUN: fir-opt --split-input-file --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
33

44
gpu.module @testmod {
55
gpu.func @_QPvcpowdk(%arg0: !fir.ref<complex<f64>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}) attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
@@ -15,3 +15,15 @@ gpu.module @testmod {
1515
// CHECK-LABEL: gpu.func @_QPvcpowdk
1616
// CHECK: %{{.*}} = fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}, %{{.*}}) : (f64, f64, i64) -> tuple<f64, f64>
1717
// CHECK: func.func private @_FortranAzpowk(f64, f64, i64) -> tuple<f64, f64> attributes {fir.bindc_name = "_FortranAzpowk", fir.runtime}
18+
19+
// -----
20+
21+
gpu.module @testmod {
22+
gpu.func @_QPtest(%arg0: complex<f64>) -> (complex<f64>) {
23+
gpu.return %arg0 : complex<f64>
24+
}
25+
}
26+
27+
// CHECK-LABEL: gpu.func @_QPtest
28+
// CHECK-SAME: (%arg0: f64, %arg1: f64) -> tuple<f64, f64> {
29+
// CHECK: gpu.return %{{.*}} : tuple<f64, f64>

0 commit comments

Comments
 (0)