Skip to content

Commit 261d867

Browse files
committed
[flang] Parsing and printing for fir.do_concurrent.loop with private specifiers
1 parent 1b7ea4f commit 261d867

File tree

3 files changed

+194
-18
lines changed

3 files changed

+194
-18
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3588,10 +3588,32 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
35883588
let hasVerifier = 1;
35893589

35903590
// Extra C++ accessors for fir.do_concurrent.loop. They expose the counts of
// induction variables, private operands, control operands and reduce operands,
// plus the matching slices of the body block arguments and of the operand list.
defvar opExtraClassDeclaration = [{
3591+
/// Number of induction variables, taken as the number of lower-bound operands.
unsigned getNumInductionVars() { return getLowerBound().size(); }
3592+
3593+
/// Number of private operands, i.e. the size of getPrivateVars().
unsigned getNumPrivateOperands() { return getPrivateVars().size(); }
3594+
3595+
/// Body block arguments for the induction variables: the first
/// getNumInductionVars() arguments of the body block.
mlir::Block::BlockArgListType getInductionVars() {
3596+
return getBody()->getArguments().slice(0, getNumInductionVars());
3597+
}
3598+
3599+
/// Body block arguments for the private variables: getNumPrivateOperands()
/// arguments starting right after the induction-variable arguments.
mlir::Block::BlockArgListType getRegionPrivateArgs() {
3600+
return getBody()->getArguments().slice(getNumInductionVars(),
3601+
getNumPrivateOperands());
3602+
}
3603+
3604+
/// Number of operands controlling the loop, computed as three times the
/// number of lower-bound operands (presumably one lower bound, one upper
/// bound and one step per induction variable -- TODO confirm).
3605+
unsigned getNumControlOperands() { return getLowerBound().size() * 3; }
3606+
35913607
/// Number of reduction operands, i.e. the size of getReduceOperands().
35923608
unsigned getNumReduceOperands() {
35933609
return getReduceOperands().size();
35943610
}
3611+
3612+
/// Operands for the private variables: getNumPrivateOperands() operands
/// located after the control and reduce operands in the operand list.
/// NOTE(review): relies on the operand order being control, reduce, private;
/// confirm against the op's `arguments` definition.
mlir::Operation::operand_range getPrivateOperands() {
3613+
return getOperands()
3614+
.slice(getNumControlOperands() + getNumReduceOperands(),
3615+
getNumPrivateOperands());
3616+
}
35953617
}];
35963618

35973619
let extraClassDeclaration =

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 95 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4886,29 +4886,33 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
48864886
mlir::OperationState &result) {
48874887
auto &builder = parser.getBuilder();
48884888
// Parse an opening `(` followed by induction variables followed by `)`
4889-
llvm::SmallVector<mlir::OpAsmParser::Argument, 4> ivs;
4890-
if (parser.parseArgumentList(ivs, mlir::OpAsmParser::Delimiter::Paren))
4889+
llvm::SmallVector<mlir::OpAsmParser::Argument, 4> regionArgs;
4890+
4891+
if (parser.parseArgumentList(regionArgs, mlir::OpAsmParser::Delimiter::Paren))
48914892
return mlir::failure();
48924893

4894+
llvm::SmallVector<mlir::Type> argTypes(regionArgs.size(),
4895+
builder.getIndexType());
4896+
48934897
// Parse loop bounds.
48944898
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> lower;
48954899
if (parser.parseEqual() ||
4896-
parser.parseOperandList(lower, ivs.size(),
4900+
parser.parseOperandList(lower, regionArgs.size(),
48974901
mlir::OpAsmParser::Delimiter::Paren) ||
48984902
parser.resolveOperands(lower, builder.getIndexType(), result.operands))
48994903
return mlir::failure();
49004904

49014905
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> upper;
49024906
if (parser.parseKeyword("to") ||
4903-
parser.parseOperandList(upper, ivs.size(),
4907+
parser.parseOperandList(upper, regionArgs.size(),
49044908
mlir::OpAsmParser::Delimiter::Paren) ||
49054909
parser.resolveOperands(upper, builder.getIndexType(), result.operands))
49064910
return mlir::failure();
49074911

49084912
// Parse step values.
49094913
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> steps;
49104914
if (parser.parseKeyword("step") ||
4911-
parser.parseOperandList(steps, ivs.size(),
4915+
parser.parseOperandList(steps, regionArgs.size(),
49124916
mlir::OpAsmParser::Delimiter::Paren) ||
49134917
parser.resolveOperands(steps, builder.getIndexType(), result.operands))
49144918
return mlir::failure();
@@ -4939,20 +4943,72 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
49394943
builder.getArrayAttr(arrayAttr));
49404944
}
49414945

4942-
// Now parse the body.
4943-
mlir::Region *body = result.addRegion();
4944-
for (auto &iv : ivs)
4945-
iv.type = builder.getIndexType();
4946-
if (parser.parseRegion(*body, ivs))
4947-
return mlir::failure();
4946+
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> privateOperands;
4947+
if (succeeded(parser.parseOptionalKeyword("private"))) {
4948+
std::size_t oldArgTypesSize = argTypes.size();
4949+
if (failed(parser.parseLParen()))
4950+
return mlir::failure();
4951+
4952+
llvm::SmallVector<mlir::SymbolRefAttr> privateSymbolVec;
4953+
if (failed(parser.parseCommaSeparatedList([&]() {
4954+
if (failed(parser.parseAttribute(privateSymbolVec.emplace_back())))
4955+
return mlir::failure();
4956+
4957+
if (parser.parseOperand(privateOperands.emplace_back()) ||
4958+
parser.parseArrow() ||
4959+
parser.parseArgument(regionArgs.emplace_back()))
4960+
return mlir::failure();
4961+
4962+
return mlir::success();
4963+
})))
4964+
return mlir::failure();
4965+
4966+
if (failed(parser.parseColon()))
4967+
return mlir::failure();
4968+
4969+
if (failed(parser.parseCommaSeparatedList([&]() {
4970+
if (failed(parser.parseType(argTypes.emplace_back())))
4971+
return mlir::failure();
4972+
4973+
return mlir::success();
4974+
})))
4975+
return mlir::failure();
4976+
4977+
if (regionArgs.size() != argTypes.size())
4978+
return parser.emitError(parser.getNameLoc(),
4979+
"mismatch in number of private arg and types");
4980+
4981+
if (failed(parser.parseRParen()))
4982+
return mlir::failure();
4983+
4984+
for (auto operandType : llvm::zip_equal(
4985+
privateOperands, llvm::drop_begin(argTypes, oldArgTypesSize)))
4986+
if (parser.resolveOperand(std::get<0>(operandType),
4987+
std::get<1>(operandType), result.operands))
4988+
return mlir::failure();
4989+
4990+
llvm::SmallVector<mlir::Attribute> symbolAttrs(privateSymbolVec.begin(),
4991+
privateSymbolVec.end());
4992+
result.addAttribute(getPrivateSymsAttrName(result.name),
4993+
builder.getArrayAttr(symbolAttrs));
4994+
}
49484995

49494996
// Set `operandSegmentSizes` attribute.
49504997
result.addAttribute(DoConcurrentLoopOp::getOperandSegmentSizeAttr(),
49514998
builder.getDenseI32ArrayAttr(
49524999
{static_cast<int32_t>(lower.size()),
49535000
static_cast<int32_t>(upper.size()),
49545001
static_cast<int32_t>(steps.size()),
4955-
static_cast<int32_t>(reduceOperands.size()), 0}));
5002+
static_cast<int32_t>(reduceOperands.size()),
5003+
static_cast<int32_t>(privateOperands.size())}));
5004+
5005+
// Now parse the body.
5006+
for (auto [arg, type] : llvm::zip_equal(regionArgs, argTypes))
5007+
arg.type = type;
5008+
5009+
mlir::Region *body = result.addRegion();
5010+
if (parser.parseRegion(*body, regionArgs))
5011+
return mlir::failure();
49565012

49575013
// Parse attributes.
49585014
if (parser.parseOptionalAttrDict(result.attributes))
@@ -4962,8 +5018,9 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
49625018
}
49635019

49645020
void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) {
4965-
p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
4966-
<< ") to (" << getUpperBound() << ") step (" << getStep() << ")";
5021+
p << " (" << getBody()->getArguments().slice(0, getNumInductionVars())
5022+
<< ") = (" << getLowerBound() << ") to (" << getUpperBound() << ") step ("
5023+
<< getStep() << ")";
49675024

49685025
if (!getReduceOperands().empty()) {
49695026
p << " reduce(";
@@ -4976,12 +5033,28 @@ void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) {
49765033
p << ')';
49775034
}
49785035

5036+
if (!getPrivateVars().empty()) {
5037+
p << " private(";
5038+
llvm::interleaveComma(llvm::zip_equal(getPrivateSymsAttr(),
5039+
getPrivateVars(),
5040+
getRegionPrivateArgs()),
5041+
p, [&](auto it) {
5042+
p << std::get<0>(it) << " " << std::get<1>(it)
5043+
<< " -> " << std::get<2>(it);
5044+
});
5045+
p << " : ";
5046+
llvm::interleaveComma(getPrivateVars(), p,
5047+
[&](auto it) { p << it.getType(); });
5048+
p << ")";
5049+
}
5050+
49795051
p << ' ';
49805052
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
49815053
p.printOptionalAttrDict(
49825054
(*this)->getAttrs(),
49835055
/*elidedAttrs=*/{DoConcurrentLoopOp::getOperandSegmentSizeAttr(),
4984-
DoConcurrentLoopOp::getReduceAttrsAttrName()});
5056+
DoConcurrentLoopOp::getReduceAttrsAttrName(),
5057+
DoConcurrentLoopOp::getPrivateSymsAttrName()});
49855058
}
49865059

49875060
llvm::SmallVector<mlir::Region *> fir::DoConcurrentLoopOp::getLoopRegions() {
@@ -4992,6 +5065,7 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
49925065
mlir::Operation::operand_range lbValues = getLowerBound();
49935066
mlir::Operation::operand_range ubValues = getUpperBound();
49945067
mlir::Operation::operand_range stepValues = getStep();
5068+
mlir::Operation::operand_range privateVars = getPrivateVars();
49955069

49965070
if (lbValues.empty())
49975071
return emitOpError(
@@ -5005,11 +5079,13 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
50055079
// Check that, excluding the block arguments for private variables, the body
50065080
// defines the same number of block arguments as the number of elements in step.
50075081
mlir::Block *body = getBody();
5008-
if (body->getNumArguments() != stepValues.size())
5082+
unsigned numIndVarArgs = body->getNumArguments() - privateVars.size();
5083+
5084+
if (numIndVarArgs != stepValues.size())
50095085
return emitOpError() << "expects the same number of induction variables: "
50105086
<< body->getNumArguments()
50115087
<< " as bound and step values: " << stepValues.size();
5012-
for (auto arg : body->getArguments())
5088+
for (auto arg : body->getArguments().slice(0, numIndVarArgs))
50135089
if (!arg.getType().isIndex())
50145090
return emitOpError(
50155091
"expects arguments for the induction variable to be of index type");
@@ -5024,7 +5100,8 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
50245100

50255101
std::optional<llvm::SmallVector<mlir::Value>>
50265102
fir::DoConcurrentLoopOp::getLoopInductionVars() {
5027-
return llvm::SmallVector<mlir::Value>{getBody()->getArguments()};
5103+
return llvm::SmallVector<mlir::Value>{
5104+
getBody()->getArguments().slice(0, getLowerBound().size())};
50285105
}
50295106

50305107
//===----------------------------------------------------------------------===//

flang/test/Fir/do_concurrent.fir

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,80 @@ func.func @dc_2d_reduction(%i_lb: index, %i_ub: index, %i_st: index,
9090
// CHECK: fir.store %[[J_IV_CVT]] to %[[J]] : !fir.ref<i32>
9191
// CHECK: }
9292
// CHECK: }
93+
94+
95+
// Privatizer of kind `private` for i32; referenced by symbol from the
// private(...) clause of the fir.do_concurrent.loop in @_QPdo_concurrent.
omp.private {type = private} @local_privatizer : i32
96+
97+
// Privatizer of kind `firstprivate` for i32; referenced by symbol from the
// private(...) clause of the fir.do_concurrent.loop in @_QPdo_concurrent.
// Its copy region loads the i32 behind %arg0, stores it through %arg1, and
// yields %arg1.
omp.private {type = firstprivate} @local_init_privatizer : i32 copy {
98+
^bb0(%arg0: !fir.ref<i32>, %arg1: !fir.ref<i32>):
99+
%0 = fir.load %arg0 : !fir.ref<i32>
100+
fir.store %0 to %arg1 : !fir.ref<i32>
101+
omp.yield(%arg1 : !fir.ref<i32>)
102+
}
103+
104+
// Test input for fir.do_concurrent.loop with a private(...) clause: a loop
// with one induction variable (%arg0) and two privatized !fir.ref<i32> values
// that are bound to the region arguments %arg1 and %arg2.
func.func @_QPdo_concurrent() {
105+
%3 = fir.alloca i32 {bindc_name = "local_init_var", uniq_name = "_QFdo_concurrentElocal_init_var"}
106+
%4:2 = hlfir.declare %3 {uniq_name = "_QFdo_concurrentElocal_init_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
107+
%5 = fir.alloca i32 {bindc_name = "local_var", uniq_name = "_QFdo_concurrentElocal_var"}
108+
%6:2 = hlfir.declare %5 {uniq_name = "_QFdo_concurrentElocal_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
109+
%c1 = arith.constant 1 : index
110+
// NOTE(review): despite its name, %c10 is defined as the index constant 1.
%c10 = arith.constant 1 : index
111+
fir.do_concurrent {
112+
%9 = fir.alloca i32 {bindc_name = "i"}
113+
%10:2 = hlfir.declare %9 {uniq_name = "_QFdo_concurrentEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
114+
// Each private entry is `@privatizer %host_value -> %region_arg`; the types
// of the entries are listed after the colon.
fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1) private(@local_privatizer %6#0 -> %arg1, @local_init_privatizer %4#0 -> %arg2 : !fir.ref<i32>, !fir.ref<i32>) {
115+
%11 = fir.convert %arg0 : (index) -> i32
116+
fir.store %11 to %10#0 : !fir.ref<i32>
117+
// Re-declare the privatized region arguments so the body can refer to them.
%13:2 = hlfir.declare %arg1 {uniq_name = "_QFdo_concurrentElocal_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
118+
%15:2 = hlfir.declare %arg2 {uniq_name = "_QFdo_concurrentElocal_init_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
119+
%17 = fir.load %10#0 : !fir.ref<i32>
120+
%c5_i32 = arith.constant 5 : i32
121+
%18 = arith.cmpi slt, %17, %c5_i32 : i32
122+
// Assign 42 to the private copy of local_var when i < 5, otherwise assign 84
// to the private copy of local_init_var.
fir.if %18 {
123+
%c42_i32 = arith.constant 42 : i32
124+
hlfir.assign %c42_i32 to %13#0 : i32, !fir.ref<i32>
125+
} else {
126+
%c84_i32 = arith.constant 84 : i32
127+
hlfir.assign %c84_i32 to %15#0 : i32, !fir.ref<i32>
128+
}
129+
}
130+
}
131+
return
132+
}
133+
134+
// CHECK: omp.private {type = private} @[[LOCAL_PRIV_SYM:local_privatizer]] : i32
135+
136+
// CHECK: omp.private {type = firstprivate} @[[LOCAL_INIT_PRIV_SYM:local_init_privatizer]] : i32
137+
138+
// CHECK-LABEL: func.func @_QPdo_concurrent() {
139+
// CHECK: %[[LOC_INIT_ALLOC:.*]] = fir.alloca i32 {bindc_name = "local_init_var", {{.*}}}
140+
// CHECK: %[[LOC_INIT_DECL:.*]]:2 = hlfir.declare %[[LOC_INIT_ALLOC]]
141+
142+
// CHECK: %[[LOC_ALLOC:.*]] = fir.alloca i32 {bindc_name = "local_var", {{.*}}}
143+
// CHECK: %[[LOC_DECL:.*]]:2 = hlfir.declare %[[LOC_ALLOC]]
144+
145+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
146+
// CHECK: %[[C10:.*]] = arith.constant 1 : index
147+
148+
// CHECK: fir.do_concurrent {
149+
// CHECK: %[[DC_I_ALLOC:.*]] = fir.alloca i32 {bindc_name = "i"}
150+
// CHECK: %[[DC_I_DECL:.*]]:2 = hlfir.declare %[[DC_I_ALLOC]]
151+
152+
// CHECK: fir.do_concurrent.loop (%[[IV:.*]]) = (%[[C1]]) to (%[[C10]]) step (%[[C1]]) private(@[[LOCAL_PRIV_SYM]] %[[LOC_DECL]]#0 -> %[[LOC_ARG:.*]], @[[LOCAL_INIT_PRIV_SYM]] %[[LOC_INIT_DECL]]#0 -> %[[LOC_INIT_ARG:.*]] : !fir.ref<i32>, !fir.ref<i32>) {
153+
// CHECK: %[[IV_CVT:.*]] = fir.convert %[[IV]] : (index) -> i32
154+
// CHECK: fir.store %[[IV_CVT]] to %[[DC_I_DECL]]#0 : !fir.ref<i32>
155+
156+
// CHECK: %[[LOC_PRIV_DECL:.*]]:2 = hlfir.declare %[[LOC_ARG]]
157+
// CHECK: %[[LOC_INIT_PRIV_DECL:.*]]:2 = hlfir.declare %[[LOC_INIT_ARG]]
158+
159+
// CHECK: fir.if %{{.*}} {
160+
// CHECK: %[[C42:.*]] = arith.constant 42 : i32
161+
// CHECK: hlfir.assign %[[C42]] to %[[LOC_PRIV_DECL]]#0 : i32, !fir.ref<i32>
162+
// CHECK: } else {
163+
// CHECK: %[[C84:.*]] = arith.constant 84 : i32
164+
// CHECK: hlfir.assign %[[C84]] to %[[LOC_INIT_PRIV_DECL]]#0 : i32, !fir.ref<i32>
165+
// CHECK: }
166+
// CHECK: }
167+
// CHECK: }
168+
// CHECK: return
169+
// CHECK: }

0 commit comments

Comments
 (0)