diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index d9f38259c0ace..e305e2fbde5b1 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -1114,9 +1114,9 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", UnitAttr:$selfAttr, Variadic:$reductionOperands, OptionalAttr:$reductionRecipes, - Variadic:$gangPrivateOperands, + Variadic:$privateOperands, OptionalAttr:$privatizations, - Variadic:$gangFirstPrivateOperands, + Variadic:$firstprivateOperands, OptionalAttr:$firstprivatizations, Variadic:$dataClauseOperands, OptionalAttr:$defaultAttr, @@ -1134,8 +1134,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", CArg<"mlir::Value", "{}">:$ifCond, CArg<"mlir::Value", "{}">:$selfCond, CArg<"mlir::ValueRange", "{}">:$reductionOperands, - CArg<"mlir::ValueRange", "{}">:$gangPrivateOperands, - CArg<"mlir::ValueRange", "{}">:$gangFirstPrivateOperands, + CArg<"mlir::ValueRange", "{}">:$privateOperands, + CArg<"mlir::ValueRange", "{}">:$firstprivateOperands, CArg<"mlir::ValueRange", "{}">:$dataClauseOperands)>]; let extraClassDeclaration = [{ @@ -1145,6 +1145,9 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", /// The i-th data operand passed. Value getDataOperand(unsigned i); + /// Used to retrieve the block inside the op's region. + Block &getBody() { return getRegion().front(); } + /// Return true if the op has the async attribute for the /// mlir::acc::DeviceType::None device_type. bool hasAsyncOnly(); @@ -1202,15 +1205,15 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` | `async` `(` custom($asyncOperands, type($asyncOperands), $asyncOperandsDeviceType) `)` - | `firstprivate` `(` custom($gangFirstPrivateOperands, - type($gangFirstPrivateOperands), $firstprivatizations) + | `firstprivate` `(` custom($firstprivateOperands, + type($firstprivateOperands), $firstprivatizations) `)` | `num_gangs` `(` custom($numGangs, type($numGangs), $numGangsDeviceType, $numGangsSegments) `)` | `num_workers` `(` custom($numWorkers, type($numWorkers), $numWorkersDeviceType) `)` | `private` `(` custom( - $gangPrivateOperands, type($gangPrivateOperands), $privatizations) + $privateOperands, type($privateOperands), $privatizations) `)` | `vector_length` `(` custom($vectorLength, type($vectorLength), $vectorLengthDeviceType) `)` @@ -1271,9 +1274,9 @@ def OpenACC_SerialOp : OpenACC_Op<"serial", UnitAttr:$selfAttr, Variadic:$reductionOperands, OptionalAttr:$reductionRecipes, - Variadic:$gangPrivateOperands, + Variadic:$privateOperands, OptionalAttr:$privatizations, - Variadic:$gangFirstPrivateOperands, + Variadic:$firstprivateOperands, OptionalAttr:$firstprivatizations, Variadic:$dataClauseOperands, OptionalAttr:$defaultAttr, @@ -1288,6 +1291,9 @@ def OpenACC_SerialOp : OpenACC_Op<"serial", /// The i-th data operand passed. Value getDataOperand(unsigned i); + /// Used to retrieve the block inside the op's region. + Block &getBody() { return getRegion().front(); } + /// Return true if the op has the async attribute for the /// mlir::acc::DeviceType::None device_type. bool hasAsyncOnly(); @@ -1326,11 +1332,11 @@ def OpenACC_SerialOp : OpenACC_Op<"serial", `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` | `async` `(` custom($asyncOperands, type($asyncOperands), $asyncOperandsDeviceType) `)` - | `firstprivate` `(` custom($gangFirstPrivateOperands, - type($gangFirstPrivateOperands), $firstprivatizations) + | `firstprivate` `(` custom($firstprivateOperands, + type($firstprivateOperands), $firstprivatizations) `)` | `private` `(` custom( - $gangPrivateOperands, type($gangPrivateOperands), $privatizations) + $privateOperands, type($privateOperands), $privatizations) `)` | `wait` `` custom($waitOperands, type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum, @@ -1410,6 +1416,9 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", /// The i-th data operand passed. Value getDataOperand(unsigned i); + /// Used to retrieve the block inside the op's region. + Block &getBody() { return getRegion().front(); } + /// Return true if the op has the async attribute for the /// mlir::acc::DeviceType::None device_type. bool hasAsyncOnly(); @@ -1824,6 +1833,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop", /// The i-th data operand passed. Value getDataOperand(unsigned i); + /// Used to retrieve the block inside the op's region. Block &getBody() { return getLoopRegions().front()->front(); } /// Return true if the op has the auto attribute for the diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 919a0853fb604..280260e0485bb 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -730,8 +730,8 @@ checkSymOperandList(Operation *op, std::optional attributes, } unsigned ParallelOp::getNumDataOperands() { - return getReductionOperands().size() + getGangPrivateOperands().size() + - getGangFirstPrivateOperands().size() + getDataClauseOperands().size(); + return getReductionOperands().size() + getPrivateOperands().size() + + getFirstprivateOperands().size() + getDataClauseOperands().size(); } Value ParallelOp::getDataOperand(unsigned i) { @@ -783,9 +783,13 @@ static LogicalResult verifyDeviceTypeAndSegmentCountMatch( LogicalResult acc::ParallelOp::verify() { if (failed(checkSymOperandList( - *this, getPrivatizations(), getGangPrivateOperands(), "private", + *this, getPrivatizations(), getPrivateOperands(), "private", "privatizations", /*checkOperandType=*/false))) return failure(); + if (failed(checkSymOperandList( + *this, getFirstprivatizations(), getFirstprivateOperands(), + "firstprivate", "firstprivatizations", /*checkOperandType=*/false))) + return failure(); if (failed(checkSymOperandList( *this, getReductionRecipes(), getReductionOperands(), "reduction", "reductions", false))) @@ -1361,8 +1365,8 @@ printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, //===----------------------------------------------------------------------===// unsigned SerialOp::getNumDataOperands() { - return getReductionOperands().size() + getGangPrivateOperands().size() + - getGangFirstPrivateOperands().size() + getDataClauseOperands().size(); + return getReductionOperands().size() + getPrivateOperands().size() + + getFirstprivateOperands().size() + getDataClauseOperands().size(); } Value SerialOp::getDataOperand(unsigned i) { @@ -1420,9 +1424,13 @@ mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { LogicalResult acc::SerialOp::verify() { if (failed(checkSymOperandList( - *this, getPrivatizations(), getGangPrivateOperands(), "private", + *this, getPrivatizations(), getPrivateOperands(), "private", "privatizations", /*checkOperandType=*/false))) return failure(); + if (failed(checkSymOperandList( + *this, getFirstprivatizations(), getFirstprivateOperands(), + "firstprivate", "firstprivatizations", /*checkOperandType=*/false))) + return failure(); if (failed(checkSymOperandList( *this, getReductionRecipes(), getReductionOperands(), "reduction", "reductions", false))) diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp index 4038e333adb8b..026b309ce4969 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp +++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp @@ -83,8 +83,8 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) { !std::is_same_v && !std::is_same_v) { collectPtrs(op.getReductionOperands(), values, hostToDevice); - collectPtrs(op.getGangPrivateOperands(), values, hostToDevice); - collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice); + collectPtrs(op.getPrivateOperands(), values, hostToDevice); + collectPtrs(op.getFirstprivateOperands(), values, hostToDevice); } }