Skip to content

Commit bb165a2

Browse files
committed
[flang] Parsing and printing for fir.do_loop with private specifiers
1 parent 4dd5222 commit bb165a2

File tree

2 files changed

+106
-19
lines changed

2 files changed

+106
-19
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2203,20 +2203,37 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
22032203
];
22042204

22052205
defvar opExtraClassDeclaration = [{
2206-
mlir::Value getInductionVar() { return getBody()->getArgument(0); }
22072206
mlir::OpBuilder getBodyBuilder() {
22082207
return mlir::OpBuilder(getBody(), std::prev(getBody()->end()));
22092208
}
2209+
2210+
/// Region argument accessors.
2211+
mlir::Value getInductionVar() { return getBody()->getArgument(0); }
22102212
mlir::Block::BlockArgListType getRegionIterArgs() {
2211-
return getBody()->getArguments().drop_front();
2213+
// 1 for skipping the induction variable.
2214+
return getBody()->getArguments().slice(1, getNumIterOperands());
22122215
}
2216+
mlir::Block::BlockArgListType getRegionPrivateArgs() {
2217+
return getBody()->getArguments().slice(1 + getNumIterOperands(),
2218+
numPrivateBlockArgs());
2219+
}
2220+
2221+
/// Operation operand accessors.
22132222
mlir::Operation::operand_range getIterOperands() {
22142223
return getOperands()
2215-
.drop_front(getNumControlOperands() + getNumReduceOperands());
2224+
.slice(getNumControlOperands() + getNumReduceOperands(),
2225+
getNumIterOperands());
22162226
}
22172227
llvm::MutableArrayRef<mlir::OpOperand> getInitsMutable() {
22182228
return getOperation()->getOpOperands()
2219-
.drop_front(getNumControlOperands() + getNumReduceOperands());
2229+
.slice(getNumControlOperands() + getNumReduceOperands(),
2230+
getNumIterOperands());
2231+
}
2232+
mlir::Operation::operand_range getPrivateOperands() {
2233+
return getOperands()
2234+
.slice(getNumControlOperands() + getNumReduceOperands()
2235+
+ getNumIterOperands(),
2236+
numPrivateBlockArgs());
22202237
}
22212238

22222239
void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); }

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 85 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2563,8 +2563,9 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
25632563

25642564
// Parse the optional initial iteration arguments.
25652565
llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
2566-
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
25672566
llvm::SmallVector<mlir::Type> argTypes;
2567+
2568+
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
25682569
bool prependCount = false;
25692570
regionArgs.push_back(inductionVariable);
25702571

@@ -2589,15 +2590,6 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
25892590
prependCount = true;
25902591
}
25912592

2592-
// Set the operandSegmentSizes attribute
2593-
result.addAttribute(getOperandSegmentSizeAttr(),
2594-
builder.getDenseI32ArrayAttr(
2595-
{1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
2596-
static_cast<int32_t>(iterOperands.size()), 0}));
2597-
2598-
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
2599-
return mlir::failure();
2600-
26012593
// Induction variable.
26022594
if (prependCount)
26032595
result.addAttribute(DoLoopOp::getFinalValueAttrName(result.name),
@@ -2606,15 +2598,77 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
26062598
argTypes.push_back(indexType);
26072599
// Loop carried variables
26082600
argTypes.append(result.types.begin(), result.types.end());
2609-
// Parse the body region.
2610-
auto *body = result.addRegion();
2601+
26112602
if (regionArgs.size() != argTypes.size())
26122603
return parser.emitError(
26132604
parser.getNameLoc(),
26142605
"mismatch in number of loop-carried values and defined values");
2606+
2607+
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> privateOperands;
2608+
if (succeeded(parser.parseOptionalKeyword("private"))) {
2609+
std::size_t oldArgTypesSize = argTypes.size();
2610+
if (failed(parser.parseLParen()))
2611+
return mlir::failure();
2612+
2613+
llvm::SmallVector<mlir::SymbolRefAttr> privateSymbolVec;
2614+
if (failed(parser.parseCommaSeparatedList([&]() {
2615+
if (failed(parser.parseAttribute(privateSymbolVec.emplace_back())))
2616+
return mlir::failure();
2617+
2618+
if (parser.parseOperand(privateOperands.emplace_back()) ||
2619+
parser.parseArrow() ||
2620+
parser.parseArgument(regionArgs.emplace_back()))
2621+
return mlir::failure();
2622+
2623+
return mlir::success();
2624+
})))
2625+
return mlir::failure();
2626+
2627+
if (failed(parser.parseColon()))
2628+
return mlir::failure();
2629+
2630+
if (failed(parser.parseCommaSeparatedList([&]() {
2631+
if (failed(parser.parseType(argTypes.emplace_back())))
2632+
return mlir::failure();
2633+
2634+
return mlir::success();
2635+
})))
2636+
return mlir::failure();
2637+
2638+
if (regionArgs.size() != argTypes.size())
2639+
return parser.emitError(parser.getNameLoc(),
2640+
"mismatch in number of private arg and types");
2641+
2642+
if (failed(parser.parseRParen()))
2643+
return mlir::failure();
2644+
2645+
for (auto operandType : llvm::zip_equal(
2646+
privateOperands, llvm::drop_begin(argTypes, oldArgTypesSize)))
2647+
if (parser.resolveOperand(std::get<0>(operandType),
2648+
std::get<1>(operandType), result.operands))
2649+
return mlir::failure();
2650+
2651+
llvm::SmallVector<mlir::Attribute> symbolAttrs(privateSymbolVec.begin(),
2652+
privateSymbolVec.end());
2653+
result.addAttribute(getPrivateSymsAttrName(result.name),
2654+
builder.getArrayAttr(symbolAttrs));
2655+
}
2656+
2657+
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
2658+
return mlir::failure();
2659+
2660+
// Set the operandSegmentSizes attribute
2661+
result.addAttribute(getOperandSegmentSizeAttr(),
2662+
builder.getDenseI32ArrayAttr(
2663+
{1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
2664+
static_cast<int32_t>(iterOperands.size()),
2665+
static_cast<int32_t>(privateOperands.size())}));
2666+
26152667
for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
26162668
regionArgs[i].type = argTypes[i];
26172669

2670+
// Parse the body region.
2671+
auto *body = result.addRegion();
26182672
if (parser.parseRegion(*body, regionArgs))
26192673
return mlir::failure();
26202674

@@ -2708,9 +2762,25 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
27082762
p << " -> " << getResultTypes();
27092763
printBlockTerminators = true;
27102764
}
2711-
p.printOptionalAttrDictWithKeyword(
2712-
(*this)->getAttrs(),
2713-
{"unordered", "finalValue", "reduceAttrs", "operandSegmentSizes"});
2765+
2766+
if (numPrivateBlockArgs() > 0) {
2767+
p << " private(";
2768+
llvm::interleaveComma(llvm::zip_equal(getPrivateSymsAttr(),
2769+
getPrivateVars(),
2770+
getRegionPrivateArgs()),
2771+
p, [&](auto it) {
2772+
p << std::get<0>(it) << " " << std::get<1>(it)
2773+
<< " -> " << std::get<2>(it);
2774+
});
2775+
p << " : ";
2776+
llvm::interleaveComma(getPrivateVars(), p,
2777+
[&](auto it) { p << it.getType(); });
2778+
p << ")";
2779+
}
2780+
2781+
p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
2782+
{"unordered", "finalValue", "reduceAttrs",
2783+
"operandSegmentSizes", "private_syms"});
27142784
p << ' ';
27152785
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
27162786
printBlockTerminators);

0 commit comments

Comments
 (0)