Skip to content

Commit 68d9c7b

Browse files
committed
[flang] Parsing and printing for fir.do_loop with private specifiers
1 parent 0389be9 commit 68d9c7b

File tree

2 files changed

+106
-19
lines changed

2 files changed

+106
-19
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2258,20 +2258,37 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
22582258
];
22592259

22602260
defvar opExtraClassDeclaration = [{
2261-
mlir::Value getInductionVar() { return getBody()->getArgument(0); }
22622261
mlir::OpBuilder getBodyBuilder() {
22632262
return mlir::OpBuilder(getBody(), std::prev(getBody()->end()));
22642263
}
2264+
2265+
/// Region argument accessors.
2266+
mlir::Value getInductionVar() { return getBody()->getArgument(0); }
22652267
mlir::Block::BlockArgListType getRegionIterArgs() {
2266-
return getBody()->getArguments().drop_front();
2268+
// 1 for skipping the induction variable.
2269+
return getBody()->getArguments().slice(1, getNumIterOperands());
22672270
}
2271+
mlir::Block::BlockArgListType getRegionPrivateArgs() {
2272+
return getBody()->getArguments().slice(1 + getNumIterOperands(),
2273+
numPrivateBlockArgs());
2274+
}
2275+
2276+
/// Operation operand accessors.
22682277
mlir::Operation::operand_range getIterOperands() {
22692278
return getOperands()
2270-
.drop_front(getNumControlOperands() + getNumReduceOperands());
2279+
.slice(getNumControlOperands() + getNumReduceOperands(),
2280+
getNumIterOperands());
22712281
}
22722282
llvm::MutableArrayRef<mlir::OpOperand> getInitsMutable() {
22732283
return getOperation()->getOpOperands()
2274-
.drop_front(getNumControlOperands() + getNumReduceOperands());
2284+
.slice(getNumControlOperands() + getNumReduceOperands(),
2285+
getNumIterOperands());
2286+
}
2287+
mlir::Operation::operand_range getPrivateOperands() {
2288+
return getOperands()
2289+
.slice(getNumControlOperands() + getNumReduceOperands()
2290+
+ getNumIterOperands(),
2291+
numPrivateBlockArgs());
22752292
}
22762293

22772294
void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); }

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 85 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2688,8 +2688,9 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
26882688

26892689
// Parse the optional initial iteration arguments.
26902690
llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
2691-
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
26922691
llvm::SmallVector<mlir::Type> argTypes;
2692+
2693+
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
26932694
bool prependCount = false;
26942695
regionArgs.push_back(inductionVariable);
26952696

@@ -2714,15 +2715,6 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
27142715
prependCount = true;
27152716
}
27162717

2717-
// Set the operandSegmentSizes attribute
2718-
result.addAttribute(getOperandSegmentSizeAttr(),
2719-
builder.getDenseI32ArrayAttr(
2720-
{1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
2721-
static_cast<int32_t>(iterOperands.size()), 0}));
2722-
2723-
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
2724-
return mlir::failure();
2725-
27262718
// Induction variable.
27272719
if (prependCount)
27282720
result.addAttribute(DoLoopOp::getFinalValueAttrName(result.name),
@@ -2731,15 +2723,77 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
27312723
argTypes.push_back(indexType);
27322724
// Loop carried variables
27332725
argTypes.append(result.types.begin(), result.types.end());
2734-
// Parse the body region.
2735-
auto *body = result.addRegion();
2726+
27362727
if (regionArgs.size() != argTypes.size())
27372728
return parser.emitError(
27382729
parser.getNameLoc(),
27392730
"mismatch in number of loop-carried values and defined values");
2731+
2732+
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> privateOperands;
2733+
if (succeeded(parser.parseOptionalKeyword("private"))) {
2734+
std::size_t oldArgTypesSize = argTypes.size();
2735+
if (failed(parser.parseLParen()))
2736+
return mlir::failure();
2737+
2738+
llvm::SmallVector<mlir::SymbolRefAttr> privateSymbolVec;
2739+
if (failed(parser.parseCommaSeparatedList([&]() {
2740+
if (failed(parser.parseAttribute(privateSymbolVec.emplace_back())))
2741+
return mlir::failure();
2742+
2743+
if (parser.parseOperand(privateOperands.emplace_back()) ||
2744+
parser.parseArrow() ||
2745+
parser.parseArgument(regionArgs.emplace_back()))
2746+
return mlir::failure();
2747+
2748+
return mlir::success();
2749+
})))
2750+
return mlir::failure();
2751+
2752+
if (failed(parser.parseColon()))
2753+
return mlir::failure();
2754+
2755+
if (failed(parser.parseCommaSeparatedList([&]() {
2756+
if (failed(parser.parseType(argTypes.emplace_back())))
2757+
return mlir::failure();
2758+
2759+
return mlir::success();
2760+
})))
2761+
return mlir::failure();
2762+
2763+
if (regionArgs.size() != argTypes.size())
2764+
return parser.emitError(parser.getNameLoc(),
2765+
"mismatch in number of private arg and types");
2766+
2767+
if (failed(parser.parseRParen()))
2768+
return mlir::failure();
2769+
2770+
for (auto operandType : llvm::zip_equal(
2771+
privateOperands, llvm::drop_begin(argTypes, oldArgTypesSize)))
2772+
if (parser.resolveOperand(std::get<0>(operandType),
2773+
std::get<1>(operandType), result.operands))
2774+
return mlir::failure();
2775+
2776+
llvm::SmallVector<mlir::Attribute> symbolAttrs(privateSymbolVec.begin(),
2777+
privateSymbolVec.end());
2778+
result.addAttribute(getPrivateSymsAttrName(result.name),
2779+
builder.getArrayAttr(symbolAttrs));
2780+
}
2781+
2782+
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
2783+
return mlir::failure();
2784+
2785+
// Set the operandSegmentSizes attribute
2786+
result.addAttribute(getOperandSegmentSizeAttr(),
2787+
builder.getDenseI32ArrayAttr(
2788+
{1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
2789+
static_cast<int32_t>(iterOperands.size()),
2790+
static_cast<int32_t>(privateOperands.size())}));
2791+
27402792
for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
27412793
regionArgs[i].type = argTypes[i];
27422794

2795+
// Parse the body region.
2796+
auto *body = result.addRegion();
27432797
if (parser.parseRegion(*body, regionArgs))
27442798
return mlir::failure();
27452799

@@ -2833,9 +2887,25 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
28332887
p << " -> " << getResultTypes();
28342888
printBlockTerminators = true;
28352889
}
2836-
p.printOptionalAttrDictWithKeyword(
2837-
(*this)->getAttrs(),
2838-
{"unordered", "finalValue", "reduceAttrs", "operandSegmentSizes"});
2890+
2891+
if (numPrivateBlockArgs() > 0) {
2892+
p << " private(";
2893+
llvm::interleaveComma(llvm::zip_equal(getPrivateSymsAttr(),
2894+
getPrivateVars(),
2895+
getRegionPrivateArgs()),
2896+
p, [&](auto it) {
2897+
p << std::get<0>(it) << " " << std::get<1>(it)
2898+
<< " -> " << std::get<2>(it);
2899+
});
2900+
p << " : ";
2901+
llvm::interleaveComma(getPrivateVars(), p,
2902+
[&](auto it) { p << it.getType(); });
2903+
p << ")";
2904+
}
2905+
2906+
p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
2907+
{"unordered", "finalValue", "reduceAttrs",
2908+
"operandSegmentSizes", "private_syms"});
28392909
p << ' ';
28402910
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
28412911
printBlockTerminators);

0 commit comments

Comments
 (0)