Skip to content

Commit 77e571e

Browse files
committed
add firstprivate/private to acc kernel
1 parent f0e1254 commit 77e571e

File tree

4 files changed

+85
-11
lines changed

4 files changed

+85
-11
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3024,11 +3024,10 @@ static Op createComputeOp(
30243024
}
30253025
addOperand(operands, operandSegments, ifCond);
30263026
addOperand(operands, operandSegments, selfCond);
3027-
if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
3027+
if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>)
30283028
addOperands(operands, operandSegments, reductionOperands);
3029-
addOperands(operands, operandSegments, privateOperands);
3030-
addOperands(operands, operandSegments, firstprivateOperands);
3031-
}
3029+
addOperands(operands, operandSegments, privateOperands);
3030+
addOperands(operands, operandSegments, firstprivateOperands);
30323031
addOperands(operands, operandSegments, dataClauseOperands);
30333032

30343033
Op computeOp;

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2002,8 +2002,7 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
20022002
corresponding `device_type` attributes must be modified as well.
20032003
}];
20042004

2005-
let arguments = (ins
2006-
Variadic<IntOrIndex>:$asyncOperands,
2005+
let arguments = (ins Variadic<IntOrIndex>:$asyncOperands,
20072006
OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
20082007
OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
20092008
Variadic<IntOrIndex>:$waitOperands,
@@ -2018,12 +2017,11 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
20182017
OptionalAttr<DeviceTypeArrayAttr>:$numWorkersDeviceType,
20192018
Variadic<IntOrIndex>:$vectorLength,
20202019
OptionalAttr<DeviceTypeArrayAttr>:$vectorLengthDeviceType,
2021-
Optional<I1>:$ifCond,
2022-
Optional<I1>:$selfCond,
2023-
UnitAttr:$selfAttr,
2020+
Optional<I1>:$ifCond, Optional<I1>:$selfCond, UnitAttr:$selfAttr,
2021+
Variadic<OpenACC_AnyPointerOrMappableType>:$privateOperands,
2022+
Variadic<OpenACC_AnyPointerOrMappableType>:$firstprivateOperands,
20242023
Variadic<OpenACC_AnyPointerOrMappableType>:$dataClauseOperands,
2025-
OptionalAttr<DefaultValueAttr>:$defaultAttr,
2026-
UnitAttr:$combined);
2024+
OptionalAttr<DefaultValueAttr>:$defaultAttr, UnitAttr:$combined);
20272025

20282026
let regions = (region AnyRegion:$region);
20292027

@@ -2111,6 +2109,14 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
21112109
/// types.
21122110
void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange,
21132111
llvm::ArrayRef<DeviceType>);
2112+
2113+
/// Adds a private clause variable to this operation, including its recipe.
2114+
void addPrivatization(MLIRContext *, mlir::acc::PrivateOp op,
2115+
mlir::acc::PrivateRecipeOp recipe);
2116+
/// Adds a firstprivate clause variable to this operation, including its
2117+
/// recipe.
2118+
void addFirstPrivatization(MLIRContext *, mlir::acc::FirstprivateOp op,
2119+
mlir::acc::FirstprivateRecipeOp recipe);
21142120
}];
21152121

21162122
let assemblyFormat = [{
@@ -2119,10 +2125,12 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
21192125
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
21202126
| `async` `` custom<DeviceTypeOperandsWithKeywordOnly>($asyncOperands,
21212127
type($asyncOperands), $asyncOperandsDeviceType, $asyncOnly)
2128+
| `firstprivate` `(` $firstprivateOperands `:` type($firstprivateOperands) `)`
21222129
| `num_gangs` `(` custom<NumGangs>($numGangs,
21232130
type($numGangs), $numGangsDeviceType, $numGangsSegments) `)`
21242131
| `num_workers` `(` custom<DeviceTypeOperands>($numWorkers,
21252132
type($numWorkers), $numWorkersDeviceType) `)`
2133+
| `private` `(` $privateOperands `:` type($privateOperands) `)`
21262134
| `vector_length` `(` custom<DeviceTypeOperands>($vectorLength,
21272135
type($vectorLength), $vectorLengthDeviceType) `)`
21282136
| `wait` `` custom<WaitClause>($waitOperands, type($waitOperands),

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2675,6 +2675,20 @@ LogicalResult acc::KernelsOp::verify() {
26752675
return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
26762676
}
26772677

2678+
void acc::KernelsOp::addPrivatization(MLIRContext *context,
2679+
mlir::acc::PrivateOp op,
2680+
mlir::acc::PrivateRecipeOp recipe) {
2681+
op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2682+
getPrivateOperandsMutable().append(op.getResult());
2683+
}
2684+
2685+
void acc::KernelsOp::addFirstPrivatization(
2686+
MLIRContext *context, mlir::acc::FirstprivateOp op,
2687+
mlir::acc::FirstprivateRecipeOp recipe) {
2688+
op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2689+
getFirstprivateOperandsMutable().append(op.getResult());
2690+
}
2691+
26782692
void acc::KernelsOp::addNumWorkersOperand(
26792693
MLIRContext *context, mlir::Value newValue,
26802694
llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {

mlir/test/Dialect/OpenACC/ops.mlir

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,59 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10
731731

732732
// -----
733733

734+
// Test acc.kernels with private and firstprivate operands, similar to acc.serial.
735+
736+
acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
737+
^bb0(%arg0: memref<10xf32>):
738+
%0 = memref.alloc() : memref<10xf32>
739+
acc.yield %0 : memref<10xf32>
740+
} destroy {
741+
^bb0(%arg0: memref<10xf32>):
742+
memref.dealloc %arg0 : memref<10xf32>
743+
acc.terminator
744+
}
745+
746+
acc.private.recipe @privatization_memref_10_10_f32 : memref<10x10xf32> init {
747+
^bb0(%arg0: memref<10x10xf32>):
748+
%1 = memref.alloc() : memref<10x10xf32>
749+
acc.yield %1 : memref<10x10xf32>
750+
} destroy {
751+
^bb0(%arg0: memref<10x10xf32>):
752+
memref.dealloc %arg0 : memref<10x10xf32>
753+
acc.terminator
754+
}
755+
756+
acc.firstprivate.recipe @firstprivatization_memref_10xf32 : memref<10xf32> init {
757+
^bb0(%arg0: memref<10xf32>):
758+
%2 = memref.alloca() : memref<10xf32>
759+
acc.yield %2 : memref<10xf32>
760+
} copy {
761+
^bb0(%arg0: memref<10xf32>, %arg1: memref<10xf32>):
762+
memref.copy %arg0, %arg1 : memref<10xf32> to memref<10xf32>
763+
acc.terminator
764+
} destroy {
765+
^bb0(%arg0: memref<10xf32>):
766+
acc.terminator
767+
}
768+
769+
func.func @testkernelspriv(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
770+
%priv_a = acc.private varPtr(%a : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32>
771+
%priv_c = acc.private varPtr(%c : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32>
772+
%firstp = acc.firstprivate varPtr(%b : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32>
773+
acc.kernels firstprivate(%firstp : memref<10xf32>) private(%priv_a, %priv_c : memref<10xf32>, memref<10x10xf32>) {
774+
}
775+
return
776+
}
777+
778+
// CHECK-LABEL: func.func @testkernelspriv(
779+
// CHECK: %[[PRIV_A:.*]] = acc.private varPtr(%{{.*}} : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32>
780+
// CHECK: %[[PRIV_C:.*]] = acc.private varPtr(%{{.*}} : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32>
781+
// CHECK: %[[FIRSTP:.*]] = acc.firstprivate varPtr(%{{.*}} : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32>
782+
// CHECK: acc.kernels firstprivate(%[[FIRSTP]] : memref<10xf32>) private(%[[PRIV_A]], %[[PRIV_C]] : memref<10xf32>, memref<10x10xf32>) {
783+
// CHECK-NEXT: }
784+
785+
// -----
786+
734787
func.func @testdataop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () {
735788
%ifCond = arith.constant true
736789

0 commit comments

Comments
 (0)