diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 5355ca60181b0..69c3300ba4390 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -3024,11 +3024,9 @@ static Op createComputeOp( } addOperand(operands, operandSegments, ifCond); addOperand(operands, operandSegments, selfCond); - if constexpr (!std::is_same_v) { - addOperands(operands, operandSegments, reductionOperands); - addOperands(operands, operandSegments, privateOperands); - addOperands(operands, operandSegments, firstprivateOperands); - } + addOperands(operands, operandSegments, reductionOperands); + addOperands(operands, operandSegments, privateOperands); + addOperands(operands, operandSegments, firstprivateOperands); addOperands(operands, operandSegments, dataClauseOperands); Op computeOp; diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 77d1a6f8d53b5..fcfe959709f09 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -2002,8 +2002,7 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", corresponding `device_type` attributes must be modified as well. }]; - let arguments = (ins - Variadic:$asyncOperands, + let arguments = (ins Variadic:$asyncOperands, OptionalAttr:$asyncOperandsDeviceType, OptionalAttr:$asyncOnly, Variadic:$waitOperands, @@ -2018,12 +2017,12 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", OptionalAttr:$numWorkersDeviceType, Variadic:$vectorLength, OptionalAttr:$vectorLengthDeviceType, - Optional:$ifCond, - Optional:$selfCond, - UnitAttr:$selfAttr, + Optional:$ifCond, Optional:$selfCond, UnitAttr:$selfAttr, + Variadic:$reductionOperands, + Variadic:$privateOperands, + Variadic:$firstprivateOperands, Variadic:$dataClauseOperands, - OptionalAttr:$defaultAttr, - UnitAttr:$combined); + OptionalAttr:$defaultAttr, UnitAttr:$combined); let regions = (region AnyRegion:$region); @@ -2111,6 +2110,18 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", /// types. void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange, llvm::ArrayRef); + + /// Adds a private clause variable to this operation, including its recipe. + void addPrivatization(MLIRContext *, mlir::acc::PrivateOp op, + mlir::acc::PrivateRecipeOp recipe); + /// Adds a firstprivate clause variable to this operation, including its + /// recipe. + void addFirstPrivatization(MLIRContext *, mlir::acc::FirstprivateOp op, + mlir::acc::FirstprivateRecipeOp recipe); + /// Adds a reduction clause variable to this operation, including its + /// recipe. + void addReduction(MLIRContext *, mlir::acc::ReductionOp op, + mlir::acc::ReductionRecipeOp recipe); }]; let assemblyFormat = [{ @@ -2119,10 +2130,12 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` | `async` `` custom($asyncOperands, type($asyncOperands), $asyncOperandsDeviceType, $asyncOnly) + | `firstprivate` `(` $firstprivateOperands `:` type($firstprivateOperands) `)` | `num_gangs` `(` custom($numGangs, type($numGangs), $numGangsDeviceType, $numGangsSegments) `)` | `num_workers` `(` custom($numWorkers, type($numWorkers), $numWorkersDeviceType) `)` + | `private` `(` $privateOperands `:` type($privateOperands) `)` | `vector_length` `(` custom($vectorLength, type($vectorLength), $vectorLengthDeviceType) `)` | `wait` `` custom($waitOperands, type($waitOperands), @@ -2130,6 +2143,7 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", $waitOnly) | `self` `(` $selfCond `)` | `if` `(` $ifCond `)` + | `reduction` `(` $reductionOperands `:` type($reductionOperands) `)` ) $region attr-dict-with-keyword }]; diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 7039bbe1d11ec..9235f89b7969a 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -2675,6 +2675,27 @@ LogicalResult acc::KernelsOp::verify() { return checkDataOperands(*this, getDataClauseOperands()); } +void acc::KernelsOp::addPrivatization(MLIRContext *context, + mlir::acc::PrivateOp op, + mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); + getPrivateOperandsMutable().append(op.getResult()); +} + +void acc::KernelsOp::addFirstPrivatization( + MLIRContext *context, mlir::acc::FirstprivateOp op, + mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); + getFirstprivateOperandsMutable().append(op.getResult()); +} + +void acc::KernelsOp::addReduction(MLIRContext *context, + mlir::acc::ReductionOp op, + mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); + getReductionOperandsMutable().append(op.getResult()); +} + void acc::KernelsOp::addNumWorkersOperand( MLIRContext *context, mlir::Value newValue, llvm::ArrayRef effectiveDeviceTypes) { diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir index e004a88261c78..5a1c20bcf5a24 100644 --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -731,6 +731,59 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10 // ----- +// Test acc.kernels with private and firstprivate operands, similar to acc.serial. + +acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init { +^bb0(%arg0: memref<10xf32>): + %0 = memref.alloc() : memref<10xf32> + acc.yield %0 : memref<10xf32> +} destroy { +^bb0(%arg0: memref<10xf32>): + memref.dealloc %arg0 : memref<10xf32> + acc.terminator +} + +acc.private.recipe @privatization_memref_10_10_f32 : memref<10x10xf32> init { +^bb0(%arg0: memref<10x10xf32>): + %1 = memref.alloc() : memref<10x10xf32> + acc.yield %1 : memref<10x10xf32> +} destroy { +^bb0(%arg0: memref<10x10xf32>): + memref.dealloc %arg0 : memref<10x10xf32> + acc.terminator +} + +acc.firstprivate.recipe @firstprivatization_memref_10xf32 : memref<10xf32> init { +^bb0(%arg0: memref<10xf32>): + %2 = memref.alloca() : memref<10xf32> + acc.yield %2 : memref<10xf32> +} copy { +^bb0(%arg0: memref<10xf32>, %arg1: memref<10xf32>): + memref.copy %arg0, %arg1 : memref<10xf32> to memref<10xf32> + acc.terminator +} destroy { +^bb0(%arg0: memref<10xf32>): + acc.terminator +} + +func.func @testkernelspriv(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () { + %priv_a = acc.private varPtr(%a : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> + %priv_c = acc.private varPtr(%c : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> + %firstp = acc.firstprivate varPtr(%b : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> + acc.kernels firstprivate(%firstp : memref<10xf32>) private(%priv_a, %priv_c : memref<10xf32>, memref<10x10xf32>) { + } + return +} + +// CHECK-LABEL: func.func @testkernelspriv( +// CHECK: %[[PRIV_A:.*]] = acc.private varPtr(%{{.*}} : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> +// CHECK: %[[PRIV_C:.*]] = acc.private varPtr(%{{.*}} : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> +// CHECK: %[[FIRSTP:.*]] = acc.firstprivate varPtr(%{{.*}} : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> +// CHECK: acc.kernels firstprivate(%[[FIRSTP]] : memref<10xf32>) private(%[[PRIV_A]], %[[PRIV_C]] : memref<10xf32>, memref<10x10xf32>) { +// CHECK-NEXT: } + +// ----- + func.func @testdataop(%a: memref, %b: memref, %c: memref) -> () { %ifCond = arith.constant true @@ -1602,6 +1655,35 @@ func.func @acc_reduc_test(%a : memref) -> () { // ----- +acc.reduction.recipe @reduction_add_memref_i64 : memref reduction_operator init { +^bb0(%arg0: memref): + %c0_i64 = arith.constant 0 : i64 + %alloca = memref.alloca() : memref + memref.store %c0_i64, %alloca[] : memref + acc.yield %alloca : memref +} combiner { +^bb0(%arg0: memref, %arg1: memref): + %0 = memref.load %arg0[] : memref + %1 = memref.load %arg1[] : memref + %2 = arith.addi %0, %1 : i64 + memref.store %2, %arg0[] : memref + acc.yield %arg0 : memref +} + +func.func @acc_kernels_reduc_test(%a : memref) -> () { + %reduction_a = acc.reduction varPtr(%a : memref) recipe(@reduction_add_memref_i64) -> memref + acc.kernels reduction(%reduction_a : memref) { + } + return +} + +// CHECK-LABEL: func.func @acc_kernels_reduc_test( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK: %[[REDUCTION_A:.*]] = acc.reduction varPtr(%[[ARG0]] : memref) recipe(@reduction_add_memref_i64) -> memref +// CHECK-NEXT: acc.kernels reduction(%[[REDUCTION_A]] : memref) + +// ----- + func.func @testdeclareop(%a: memref, %b: memref, %c: memref) -> () { %0 = acc.copyin varPtr(%a : memref) -> memref // copyin(zero)