Skip to content

Commit 78ef032

Browse files
authored
[mlir][flang][openacc] Add device_type support for update op (#78764)
Add support for device_type information on the acc.update operation and update lowering from Flang.
1 parent 82d335e commit 78ef032

File tree

6 files changed

+289
-97
lines changed

6 files changed

+289
-97
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 43 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3001,27 +3001,42 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter,
30013001
}
30023002
}
30033003

3004+
/// Build an mlir::ArrayAttr from `attributes` using the builder `b`.
/// When `attributes` is empty this returns a null attribute (nullptr) rather
/// than an empty array attribute.
/// NOTE(review): presumably the null result leaves the matching optional
/// attribute unset on the created op -- confirm against the op builder.
static inline mlir::ArrayAttr
3005+
getArrayAttr(fir::FirOpBuilder &b,
3006+
llvm::SmallVector<mlir::Attribute> &attributes) {
3007+
return attributes.empty() ? nullptr : b.getArrayAttr(attributes);
3008+
}
3009+
3010+
/// Build an mlir::DenseI32ArrayAttr from `values` using `builder`.
/// When `values` is empty this returns a null attribute (nullptr) rather
/// than an empty dense array attribute.
/// NOTE(review): presumably the null result leaves the matching optional
/// attribute unset on the created op -- confirm against the op builder.
static inline mlir::DenseI32ArrayAttr
3011+
getDenseI32ArrayAttr(fir::FirOpBuilder &builder,
3012+
llvm::SmallVector<int32_t> &values) {
3013+
return values.empty() ? nullptr : builder.getDenseI32ArrayAttr(values);
3014+
}
3015+
30043016
static void
30053017
genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
30063018
mlir::Location currentLocation,
30073019
Fortran::semantics::SemanticsContext &semanticsContext,
30083020
Fortran::lower::StatementContext &stmtCtx,
30093021
const Fortran::parser::AccClauseList &accClauseList) {
3010-
mlir::Value ifCond, async, waitDevnum;
3022+
mlir::Value ifCond, waitDevnum;
30113023
llvm::SmallVector<mlir::Value> dataClauseOperands, updateHostOperands,
3012-
waitOperands, deviceTypeOperands;
3013-
llvm::SmallVector<mlir::Attribute> deviceTypes;
3014-
3015-
// Async and wait clause have optional values but can be present with
3016-
// no value as well. When there is no value, the op has an attribute to
3017-
// represent the clause.
3018-
bool addAsyncAttr = false;
3019-
bool addWaitAttr = false;
3020-
bool addIfPresentAttr = false;
3024+
waitOperands, deviceTypeOperands, asyncOperands;
3025+
llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
3026+
asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
3027+
llvm::SmallVector<int32_t> waitOperandsSegments;
30213028

30223029
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
30233030

3024-
// Lower clauses values mapped to operands.
3031+
// device_type attribute is set to `none` until a device_type clause is
3032+
// encountered.
3033+
llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
3034+
crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
3035+
builder.getContext(), mlir::acc::DeviceType::None));
3036+
3037+
bool ifPresent = false;
3038+
3039+
// Lower clauses values mapped to operands and array attributes.
30253040
// Keep track of each group of operands separately as clauses can appear
30263041
// more than once.
30273042
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
@@ -3031,15 +3046,19 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
30313046
genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
30323047
} else if (const auto *asyncClause =
30333048
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
3034-
genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
3049+
genAsyncClause(converter, asyncClause, asyncOperands,
3050+
asyncOperandsDeviceTypes, asyncOnlyDeviceTypes,
3051+
crtDeviceTypes, stmtCtx);
30353052
} else if (const auto *waitClause =
30363053
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
3037-
genWaitClause(converter, waitClause, waitOperands, waitDevnum,
3038-
addWaitAttr, stmtCtx);
3054+
genWaitClause(converter, waitClause, waitOperands,
3055+
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
3056+
waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
30393057
} else if (const auto *deviceTypeClause =
30403058
std::get_if<Fortran::parser::AccClause::DeviceType>(
30413059
&clause.u)) {
3042-
gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes);
3060+
crtDeviceTypes.clear();
3061+
gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
30433062
} else if (const auto *hostClause =
30443063
std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
30453064
genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
@@ -3053,7 +3072,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
30533072
dataClauseOperands, mlir::acc::DataClause::acc_update_device, false,
30543073
/*implicit=*/false);
30553074
} else if (std::get_if<Fortran::parser::AccClause::IfPresent>(&clause.u)) {
3056-
addIfPresentAttr = true;
3075+
ifPresent = true;
30573076
} else if (const auto *selfClause =
30583077
std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) {
30593078
const std::optional<Fortran::parser::AccSelfClause> &accSelfClause =
@@ -3070,30 +3089,17 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
30703089

30713090
dataClauseOperands.append(updateHostOperands);
30723091

3073-
// Prepare the operand segment size attribute and the operands value range.
3074-
llvm::SmallVector<mlir::Value> operands;
3075-
llvm::SmallVector<int32_t> operandSegments;
3076-
addOperand(operands, operandSegments, ifCond);
3077-
addOperand(operands, operandSegments, async);
3078-
addOperand(operands, operandSegments, waitDevnum);
3079-
addOperands(operands, operandSegments, waitOperands);
3080-
addOperands(operands, operandSegments, dataClauseOperands);
3081-
3082-
mlir::acc::UpdateOp updateOp = createSimpleOp<mlir::acc::UpdateOp>(
3083-
builder, currentLocation, operands, operandSegments);
3084-
if (!deviceTypes.empty())
3085-
updateOp.setDeviceTypesAttr(
3086-
mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
3092+
builder.create<mlir::acc::UpdateOp>(
3093+
currentLocation, ifCond, asyncOperands,
3094+
getArrayAttr(builder, asyncOperandsDeviceTypes),
3095+
getArrayAttr(builder, asyncOnlyDeviceTypes), waitDevnum, waitOperands,
3096+
getDenseI32ArrayAttr(builder, waitOperandsSegments),
3097+
getArrayAttr(builder, waitOperandsDeviceTypes),
3098+
getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands,
3099+
ifPresent);
30873100

30883101
genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::UpdateHostOp>(
30893102
builder, updateHostOperands, /*structured=*/false);
3090-
3091-
if (addAsyncAttr)
3092-
updateOp.setAsyncAttr(builder.getUnitAttr());
3093-
if (addWaitAttr)
3094-
updateOp.setWaitAttr(builder.getUnitAttr());
3095-
if (addIfPresentAttr)
3096-
updateOp.setIfPresentAttr(builder.getUnitAttr());
30973103
}
30983104

30993105
static void

flang/test/Lower/OpenACC/acc-update.f90

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,17 @@ subroutine acc_update
6161

6262
!$acc update host(a) async
6363
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
64-
! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async}
64+
! CHECK: acc.update async() dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
6565
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
6666

6767
!$acc update host(a) wait
6868
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
69-
! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {wait}
69+
! CHECK: acc.update wait dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
7070
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
7171

7272
!$acc update host(a) async wait
7373
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
74-
! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async, wait}
74+
! CHECK: acc.update async() wait dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
7575
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
7676

7777
!$acc update host(a) async(1)
@@ -89,32 +89,27 @@ subroutine acc_update
8989
!$acc update host(a) wait(1)
9090
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
9191
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
92-
! CHECK: acc.update wait([[WAIT1]] : i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
92+
! CHECK: acc.update wait({[[WAIT1]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
9393
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
9494

9595
!$acc update host(a) wait(queues: 1, 2)
9696
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
9797
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
9898
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
99-
! CHECK: acc.update wait([[WAIT2]], [[WAIT3]] : i32, i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
99+
! CHECK: acc.update wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
100100
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
101101

102102
!$acc update host(a) wait(devnum: 1: queues: 1, 2)
103103
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
104104
! CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32
105105
! CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32
106106
! CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32
107-
! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait([[WAIT4]], [[WAIT5]] : i32, i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
107+
! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
108108
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
109109

110-
!$acc update host(a) device_type(default, host)
110+
!$acc update host(a) device_type(host, nvidia) async
111111
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
112-
! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<default>, #acc.device_type<host>]}
113-
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
114-
115-
!$acc update host(a) device_type(*)
116-
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
117-
! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<star>]}
112+
! CHECK: acc.update async([#acc.device_type<host>, #acc.device_type<nvidia>]) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
118113
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
119114

120115
end subroutine acc_update

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2196,29 +2196,58 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
21962196
}];
21972197

21982198
let arguments = (ins Optional<I1>:$ifCond,
2199-
Optional<IntOrIndex>:$asyncOperand,
2200-
Optional<IntOrIndex>:$waitDevnum,
2201-
Variadic<IntOrIndex>:$waitOperands,
2202-
UnitAttr:$async,
2203-
UnitAttr:$wait,
2204-
OptionalAttr<TypedArrayAttrBase<OpenACC_DeviceTypeAttr, "Device type attributes">>:$device_types,
2205-
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
2206-
UnitAttr:$ifPresent);
2199+
Variadic<IntOrIndex>:$asyncOperands,
2200+
OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
2201+
OptionalAttr<DeviceTypeArrayAttr>:$async,
2202+
Optional<IntOrIndex>:$waitDevnum,
2203+
Variadic<IntOrIndex>:$waitOperands,
2204+
OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
2205+
OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
2206+
OptionalAttr<DeviceTypeArrayAttr>:$wait,
2207+
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
2208+
UnitAttr:$ifPresent);
22072209

22082210
let extraClassDeclaration = [{
22092211
/// The number of data operands.
22102212
unsigned getNumDataOperands();
22112213

22122214
/// The i-th data operand passed.
22132215
Value getDataOperand(unsigned i);
2216+
2217+
/// Return true if the op has the async attribute for the
2218+
/// mlir::acc::DeviceType::None device_type.
2219+
bool hasAsyncOnly();
2220+
/// Return true if the op has the async attribute for the given device_type.
2221+
bool hasAsyncOnly(mlir::acc::DeviceType deviceType);
2222+
/// Return the value of the async clause if present.
2223+
mlir::Value getAsyncValue();
2224+
/// Return the value of the async clause for the given device_type if
2225+
/// present.
2226+
mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);
2227+
2228+
/// Return true if the op has the wait attribute for the
2229+
/// mlir::acc::DeviceType::None device_type.
2230+
bool hasWaitOnly();
2231+
/// Return true if the op has the wait attribute for the given device_type.
2232+
bool hasWaitOnly(mlir::acc::DeviceType deviceType);
2233+
/// Return the values of the wait clause if present.
2234+
mlir::Operation::operand_range getWaitValues();
2235+
/// Return the values of the wait clause for the given device_type if
2236+
/// present.
2237+
mlir::Operation::operand_range
2238+
getWaitValues(mlir::acc::DeviceType deviceType);
22142239
}];
22152240

22162241
let assemblyFormat = [{
22172242
oilist(
22182243
`if` `(` $ifCond `)`
2219-
| `async` `(` $asyncOperand `:` type($asyncOperand) `)`
2244+
| `async` `` custom<DeviceTypeOperandsWithKeywordOnly>(
2245+
$asyncOperands, type($asyncOperands),
2246+
$asyncOperandsDeviceType, $async)
22202247
| `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
2221-
| `wait` `(` $waitOperands `:` type($waitOperands) `)`
2248+
| `wait` `` custom<WaitClause>($waitOperands,
2249+
type($waitOperands), $waitOperandsDeviceType,
2250+
$waitOperandsSegments, $wait)
22222251
| `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
22232252
)
22242253
attr-dict-with-keyword

0 commit comments

Comments
 (0)