Skip to content

Commit 93b2030

Browse files
committed
add reduction
1 parent 77e571e commit 93b2030

File tree

4 files changed

+26
-2
lines changed

4 files changed

+26
-2
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3024,8 +3024,7 @@ static Op createComputeOp(
30243024
}
30253025
addOperand(operands, operandSegments, ifCond);
30263026
addOperand(operands, operandSegments, selfCond);
3027-
if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>)
3028-
addOperands(operands, operandSegments, reductionOperands);
3027+
addOperands(operands, operandSegments, reductionOperands);
30293028
addOperands(operands, operandSegments, privateOperands);
30303029
addOperands(operands, operandSegments, firstprivateOperands);
30313030
addOperands(operands, operandSegments, dataClauseOperands);

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2018,6 +2018,7 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
20182018
Variadic<IntOrIndex>:$vectorLength,
20192019
OptionalAttr<DeviceTypeArrayAttr>:$vectorLengthDeviceType,
20202020
Optional<I1>:$ifCond, Optional<I1>:$selfCond, UnitAttr:$selfAttr,
2021+
Variadic<OpenACC_AnyPointerOrMappableType>:$reductionOperands,
20212022
Variadic<OpenACC_AnyPointerOrMappableType>:$privateOperands,
20222023
Variadic<OpenACC_AnyPointerOrMappableType>:$firstprivateOperands,
20232024
Variadic<OpenACC_AnyPointerOrMappableType>:$dataClauseOperands,
@@ -2117,6 +2118,10 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
21172118
/// recipe.
21182119
void addFirstPrivatization(MLIRContext *, mlir::acc::FirstprivateOp op,
21192120
mlir::acc::FirstprivateRecipeOp recipe);
2121+
/// Adds a reduction clause variable to this operation, including its
2122+
/// recipe.
2123+
void addReduction(MLIRContext *, mlir::acc::ReductionOp op,
2124+
mlir::acc::ReductionRecipeOp recipe);
21202125
}];
21212126

21222127
let assemblyFormat = [{
@@ -2138,6 +2143,7 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
21382143
$waitOnly)
21392144
| `self` `(` $selfCond `)`
21402145
| `if` `(` $ifCond `)`
2146+
| `reduction` `(` $reductionOperands `:` type($reductionOperands) `)`
21412147
)
21422148
$region attr-dict-with-keyword
21432149
}];

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2689,6 +2689,13 @@ void acc::KernelsOp::addFirstPrivatization(
26892689
getFirstprivateOperandsMutable().append(op.getResult());
26902690
}
26912691

2692+
void acc::KernelsOp::addReduction(MLIRContext *context,
2693+
mlir::acc::ReductionOp op,
2694+
mlir::acc::ReductionRecipeOp recipe) {
2695+
op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2696+
getReductionOperandsMutable().append(op.getResult());
2697+
}
2698+
26922699
void acc::KernelsOp::addNumWorkersOperand(
26932700
MLIRContext *context, mlir::Value newValue,
26942701
llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {

mlir/test/Dialect/OpenACC/ops.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1653,6 +1653,18 @@ func.func @acc_reduc_test(%a : memref<i64>) -> () {
16531653
// CHECK: %[[REDUCTION_A:.*]] = acc.reduction varPtr(%[[ARG0]] : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64>
16541654
// CHECK-NEXT: acc.serial reduction(%[[REDUCTION_A]] : memref<i64>)
16551655

1656+
func.func @acc_kernels_reduc_test(%a : memref<i64>) -> () {
1657+
%reduction_a = acc.reduction varPtr(%a : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64>
1658+
acc.kernels reduction(%reduction_a : memref<i64>) {
1659+
}
1660+
return
1661+
}
1662+
1663+
// CHECK-LABEL: func.func @acc_kernels_reduc_test(
1664+
// CHECK-SAME: %[[ARG0:.*]]: memref<i64>)
1665+
// CHECK: %[[REDUCTION_A:.*]] = acc.reduction varPtr(%[[ARG0]] : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64>
1666+
// CHECK-NEXT: acc.kernels reduction(%[[REDUCTION_A]] : memref<i64>)
1667+
16561668
// -----
16571669

16581670
func.func @testdeclareop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () {

0 commit comments

Comments
 (0)