Skip to content

Commit 36fae81

Browse files
committed
[flang] Update target rewrite to support workgroup and private attributions
1 parent 2219119 commit 36fae81

File tree

3 files changed

+74
-4
lines changed

3 files changed

+74
-4
lines changed

flang/lib/Optimizer/CodeGen/TargetRewrite.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
872872
}
873873
}
874874

875+
// Count the number of arguments that have to stay in place at the end of
876+
// the argument list.
877+
unsigned trailingArgs = 0;
878+
if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>) {
879+
trailingArgs =
880+
func.getNumWorkgroupAttributions() + func.getNumPrivateAttributions();
881+
}
882+
875883
// Convert return value(s)
876884
for (auto ty : funcTy.getResults())
877885
llvm::TypeSwitch<mlir::Type>(ty)
@@ -981,6 +989,16 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
981989
}
982990
}
983991

992+
// Add the argument at the end if the number of trailing arguments is 0,
993+
// otherwise insert the argument at the appropriate index.
994+
auto addOrInsertArgument = [&](mlir::Type ty, mlir::Location loc) {
995+
unsigned inputIndex = func.front().getArguments().size() - trailingArgs;
996+
auto newArg = trailingArgs == 0
997+
? func.front().addArgument(ty, loc)
998+
: func.front().insertArgument(inputIndex, ty, loc);
999+
return newArg;
1000+
};
1001+
9841002
if (!func.empty()) {
9851003
// If the function has a body, then apply the fixups to the arguments and
9861004
// return ops as required. These fixups are done in place.
@@ -1117,8 +1135,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11171135
// original arguments. (Boxchar arguments.)
11181136
auto newBufArg =
11191137
func.front().insertArgument(fixup.index, fixupType, loc);
1120-
auto newLenArg =
1121-
func.front().addArgument(trailingTys[fixup.second], loc);
1138+
auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc);
11221139
auto boxTy = oldArgTys[fixup.index - offset];
11231140
rewriter->setInsertionPointToStart(&func.front());
11241141
auto box = fir::EmboxCharOp::create(*rewriter, loc, boxTy, newBufArg,
@@ -1133,8 +1150,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11331150
// appended after all the original arguments.
11341151
auto newProcPointerArg =
11351152
func.front().insertArgument(fixup.index, fixupType, loc);
1136-
auto newLenArg =
1137-
func.front().addArgument(trailingTys[fixup.second], loc);
1153+
auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc);
11381154
auto tupleType = oldArgTys[fixup.index - offset];
11391155
rewriter->setInsertionPointToStart(&func.front());
11401156
fir::FirOpBuilder builder(*rewriter, getModule());

flang/test/Fir/CUDA/cuda-target-rewrite.mlir

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,56 @@ func.func @main(%arg0: complex<f64>) {
5555
// CHECK-SAME: (%arg0: f64, %arg1: f64) kernel {
5656
// CHECK: gpu.return
5757
// CHECK: gpu.launch_func @testmod::@_QPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) : i64 dynamic_shared_memory_size %{{.*}} args(%{{.*}} : f64, %{{.*}} : f64) {cuf.proc_attr = #cuf.cuda_proc<global>}
58+
59+
// -----
60+
61+
module attributes {gpu.container_module, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
62+
gpu.module @testmod {
63+
gpu.func @_QMbarPfoo(%arg0: f32, %arg1: !fir.ref<!fir.array<100xf32>>, %arg2: !fir.boxchar<1>) workgroup(%arg3 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
64+
%c0 = arith.constant 0 : index
65+
memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
66+
gpu.return
67+
}
68+
// CHECK-LABEL: gpu.func @_QMbarPfoo(
69+
// CHECK-SAME: %{{.*}}: f32, %{{.*}}: !fir.ref<!fir.array<100xf32>>, %[[CHAR:.*]]: !fir.ref<!fir.char<1,?>>, %[[LENGTH:.*]]: i64) workgroup(%[[WORKGROUP:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
70+
// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref<!fir.char<1,?>>, i64) -> !fir.boxchar<1>
71+
// CHECK: memref.store %{{.*}}, %[[WORKGROUP]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
72+
73+
gpu.func @_QMbarPfoo2(%arg0: f32, %arg1: !fir.ref<!fir.array<100xf32>>, %arg2: !fir.boxchar<1>) workgroup(%arg3 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}, %arg4 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
74+
%c0 = arith.constant 0 : index
75+
memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
76+
memref.store %arg0, %arg4[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
77+
gpu.return
78+
}
79+
// CHECK-LABEL: gpu.func @_QMbarPfoo2(
80+
// CHECK-SAME: %{{.*}}: f32, %{{.*}}: !fir.ref<!fir.array<100xf32>>, %[[CHAR:.*]]: !fir.ref<!fir.char<1,?>>, %[[LENGTH:.*]]: i64) workgroup(%[[WG1:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}, %[[WG2:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
81+
// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref<!fir.char<1,?>>, i64) -> !fir.boxchar<1>
82+
// CHECK: memref.store %{{.*}}, %[[WG1]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
83+
// CHECK: memref.store %{{.*}}, %[[WG2]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
84+
85+
gpu.func @_QMbarPprivate(%arg0: f32, %arg1: !fir.boxchar<1>) workgroup(%arg2 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) private(%arg3 : memref<1xf32, #gpu.address_space<private>> {llvm.align = 16 : i32}) {
86+
%c0 = arith.constant 0 : index
87+
memref.store %arg0, %arg2[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
88+
memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space<private>>
89+
gpu.return
90+
}
91+
// CHECK-LABEL: gpu.func @_QMbarPprivate(
92+
// CHECK-SAME: %{{.*}}: f32, %[[CHAR:.*]]: !fir.ref<!fir.char<1,?>>, %[[LENGTH:.*]]: i64) workgroup(%[[WG:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) private(%[[PRIVATE:.*]] : memref<1xf32, #gpu.address_space<private>> {llvm.align = 16 : i32}) {
93+
// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref<!fir.char<1,?>>, i64) -> !fir.boxchar<1>
94+
// CHECK: memref.store %{{.*}}, %[[WG]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
95+
// CHECK: memref.store %{{.*}}, %[[PRIVATE]][%{{.*}}] : memref<1xf32, #gpu.address_space<private>>
96+
97+
gpu.func @test_with_char_proc(%arg0: f32, %arg1: tuple<() -> (), i64> {fir.char_proc}) workgroup(%arg2 : memref<1xf32, #gpu.address_space<workgroup>>) {
98+
%c0 = arith.constant 0 : index
99+
memref.store %arg0, %arg2[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
100+
gpu.return
101+
}
102+
// CHECK-LABEL: gpu.func @test_with_char_proc(
103+
// CHECK-SAME: %{{.*}}: f32, %[[CHARPROC:.*]]: () -> () {fir.char_proc}, %[[LENGTH:.*]]: i64) workgroup(%[[WG:.*]] : memref<1xf32, #gpu.address_space<workgroup>>) {
104+
// CHECK: %{{.*}} = fir.undefined tuple<() -> (), i64>
105+
// CHECK: %{{.*}} = fir.insert_value %{{.*}}, %[[CHARPROC]], [0 : index] : (tuple<() -> (), i64>, () -> ()) -> tuple<() -> (), i64>
106+
// CHECK: %{{.*}} = fir.insert_value %{{.*}}, %[[LENGTH]], [1 : index] : (tuple<() -> (), i64>, i64) -> tuple<() -> (), i64>
107+
// CHECK: memref.store %{{.*}}, %[[WG]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
108+
}
109+
}
110+

flang/tools/fir-opt/fir-opt.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ int main(int argc, char **argv) {
5050
#endif
5151
DialectRegistry registry;
5252
fir::support::registerDialects(registry);
53+
registry.insert<mlir::memref::MemRefDialect>();
5354
fir::support::addFIRExtensions(registry);
5455
return failed(MlirOptMain(argc, argv, "FIR modular optimizer driver\n",
5556
registry));

0 commit comments

Comments
 (0)