Skip to content

Commit 47ea854

Browse files
authored
[flang] Update target rewrite to support workgroup and private attributions (#164515)
Some operations, such as gpu.func, have arguments that must stay in place when the function signature is rewritten. This is the case for workgroup and private attributions, which sit at the end of the argument list. Update the target rewrite pass to take this into account when adding an argument at the end of the function signature: if any trailing arguments are present, the new argument is inserted just before them.
1 parent 23ead47 commit 47ea854

File tree

3 files changed

+74
-4
lines changed

3 files changed

+74
-4
lines changed

flang/lib/Optimizer/CodeGen/TargetRewrite.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
872872
}
873873
}
874874

875+
// Count the number of arguments that have to stay in place at the end of
876+
// the argument list.
877+
unsigned trailingArgs = 0;
878+
if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>) {
879+
trailingArgs =
880+
func.getNumWorkgroupAttributions() + func.getNumPrivateAttributions();
881+
}
882+
875883
// Convert return value(s)
876884
for (auto ty : funcTy.getResults())
877885
llvm::TypeSwitch<mlir::Type>(ty)
@@ -981,6 +989,16 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
981989
}
982990
}
983991

992+
// Add the argument at the end if the number of trailing arguments is 0,
993+
// otherwise insert the argument at the appropriate index.
994+
auto addOrInsertArgument = [&](mlir::Type ty, mlir::Location loc) {
995+
unsigned inputIndex = func.front().getArguments().size() - trailingArgs;
996+
auto newArg = trailingArgs == 0
997+
? func.front().addArgument(ty, loc)
998+
: func.front().insertArgument(inputIndex, ty, loc);
999+
return newArg;
1000+
};
1001+
9841002
if (!func.empty()) {
9851003
// If the function has a body, then apply the fixups to the arguments and
9861004
// return ops as required. These fixups are done in place.
@@ -1117,8 +1135,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11171135
// original arguments. (Boxchar arguments.)
11181136
auto newBufArg =
11191137
func.front().insertArgument(fixup.index, fixupType, loc);
1120-
auto newLenArg =
1121-
func.front().addArgument(trailingTys[fixup.second], loc);
1138+
auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc);
11221139
auto boxTy = oldArgTys[fixup.index - offset];
11231140
rewriter->setInsertionPointToStart(&func.front());
11241141
auto box = fir::EmboxCharOp::create(*rewriter, loc, boxTy, newBufArg,
@@ -1133,8 +1150,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11331150
// appended after all the original arguments.
11341151
auto newProcPointerArg =
11351152
func.front().insertArgument(fixup.index, fixupType, loc);
1136-
auto newLenArg =
1137-
func.front().addArgument(trailingTys[fixup.second], loc);
1153+
auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc);
11381154
auto tupleType = oldArgTys[fixup.index - offset];
11391155
rewriter->setInsertionPointToStart(&func.front());
11401156
fir::FirOpBuilder builder(*rewriter, getModule());

flang/test/Fir/CUDA/cuda-target-rewrite.mlir

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,56 @@ func.func @main(%arg0: complex<f64>) {
5555
// CHECK-SAME: (%arg0: f64, %arg1: f64) kernel {
5656
// CHECK: gpu.return
5757
// CHECK: gpu.launch_func @testmod::@_QPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) : i64 dynamic_shared_memory_size %{{.*}} args(%{{.*}} : f64, %{{.*}} : f64) {cuf.proc_attr = #cuf.cuda_proc<global>}
58+
59+
// -----
60+
61+
module attributes {gpu.container_module, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
62+
gpu.module @testmod {
63+
gpu.func @_QMbarPfoo(%arg0: f32, %arg1: !fir.ref<!fir.array<100xf32>>, %arg2: !fir.boxchar<1>) workgroup(%arg3 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
64+
%c0 = arith.constant 0 : index
65+
memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
66+
gpu.return
67+
}
68+
// CHECK-LABEL: gpu.func @_QMbarPfoo(
69+
// CHECK-SAME: %{{.*}}: f32, %{{.*}}: !fir.ref<!fir.array<100xf32>>, %[[CHAR:.*]]: !fir.ref<!fir.char<1,?>>, %[[LENGTH:.*]]: i64) workgroup(%[[WORKGROUP:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
70+
// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref<!fir.char<1,?>>, i64) -> !fir.boxchar<1>
71+
// CHECK: memref.store %{{.*}}, %[[WORKGROUP]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
72+
73+
gpu.func @_QMbarPfoo2(%arg0: f32, %arg1: !fir.ref<!fir.array<100xf32>>, %arg2: !fir.boxchar<1>) workgroup(%arg3 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}, %arg4 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
74+
%c0 = arith.constant 0 : index
75+
memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
76+
memref.store %arg0, %arg4[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
77+
gpu.return
78+
}
79+
// CHECK-LABEL: gpu.func @_QMbarPfoo2(
80+
// CHECK-SAME: %{{.*}}: f32, %{{.*}}: !fir.ref<!fir.array<100xf32>>, %[[CHAR:.*]]: !fir.ref<!fir.char<1,?>>, %[[LENGTH:.*]]: i64) workgroup(%[[WG1:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}, %[[WG2:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
81+
// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref<!fir.char<1,?>>, i64) -> !fir.boxchar<1>
82+
// CHECK: memref.store %{{.*}}, %[[WG1]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
83+
// CHECK: memref.store %{{.*}}, %[[WG2]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
84+
85+
gpu.func @_QMbarPprivate(%arg0: f32, %arg1: !fir.boxchar<1>) workgroup(%arg2 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) private(%arg3 : memref<1xf32, #gpu.address_space<private>> {llvm.align = 16 : i32}) {
86+
%c0 = arith.constant 0 : index
87+
memref.store %arg0, %arg2[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
88+
memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space<private>>
89+
gpu.return
90+
}
91+
// CHECK-LABEL: gpu.func @_QMbarPprivate(
92+
// CHECK-SAME: %{{.*}}: f32, %[[CHAR:.*]]: !fir.ref<!fir.char<1,?>>, %[[LENGTH:.*]]: i64) workgroup(%[[WG:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) private(%[[PRIVATE:.*]] : memref<1xf32, #gpu.address_space<private>> {llvm.align = 16 : i32}) {
93+
// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref<!fir.char<1,?>>, i64) -> !fir.boxchar<1>
94+
// CHECK: memref.store %{{.*}}, %[[WG]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
95+
// CHECK: memref.store %{{.*}}, %[[PRIVATE]][%{{.*}}] : memref<1xf32, #gpu.address_space<private>>
96+
97+
gpu.func @test_with_char_proc(%arg0: f32, %arg1: tuple<() -> (), i64> {fir.char_proc}) workgroup(%arg2 : memref<1xf32, #gpu.address_space<workgroup>>) {
98+
%c0 = arith.constant 0 : index
99+
memref.store %arg0, %arg2[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
100+
gpu.return
101+
}
102+
// CHECK-LABEL: gpu.func @test_with_char_proc(
103+
// CHECK-SAME: %{{.*}}: f32, %[[CHARPROC:.*]]: () -> () {fir.char_proc}, %[[LENGTH:.*]]: i64) workgroup(%[[WG:.*]] : memref<1xf32, #gpu.address_space<workgroup>>) {
104+
// CHECK: %{{.*}} = fir.undefined tuple<() -> (), i64>
105+
// CHECK: %{{.*}} = fir.insert_value %{{.*}}, %[[CHARPROC]], [0 : index] : (tuple<() -> (), i64>, () -> ()) -> tuple<() -> (), i64>
106+
// CHECK: %{{.*}} = fir.insert_value %{{.*}}, %[[LENGTH]], [1 : index] : (tuple<() -> (), i64>, i64) -> tuple<() -> (), i64>
107+
// CHECK: memref.store %{{.*}}, %[[WG]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
108+
}
109+
}
110+

flang/tools/fir-opt/fir-opt.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ int main(int argc, char **argv) {
5050
#endif
5151
DialectRegistry registry;
5252
fir::support::registerDialects(registry);
53+
registry.insert<mlir::memref::MemRefDialect>();
5354
fir::support::addFIRExtensions(registry);
5455
return failed(MlirOptMain(argc, argv, "FIR modular optimizer driver\n",
5556
registry));

0 commit comments

Comments
 (0)