@@ -2688,8 +2688,9 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
26882688
26892689 // Parse the optional initial iteration arguments.
26902690 llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
2691- llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
26922691 llvm::SmallVector<mlir::Type> argTypes;
2692+
2693+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
26932694 bool prependCount = false;
26942695 regionArgs.push_back(inductionVariable);
26952696
@@ -2714,15 +2715,6 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
27142715 prependCount = true;
27152716 }
27162717
2717- // Set the operandSegmentSizes attribute
2718- result.addAttribute(getOperandSegmentSizeAttr(),
2719- builder.getDenseI32ArrayAttr(
2720- {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
2721- static_cast<int32_t>(iterOperands.size()), 0}));
2722-
2723- if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
2724- return mlir::failure();
2725-
27262718 // Induction variable.
27272719 if (prependCount)
27282720 result.addAttribute(DoLoopOp::getFinalValueAttrName(result.name),
@@ -2731,15 +2723,77 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
27312723 argTypes.push_back(indexType);
27322724 // Loop carried variables
27332725 argTypes.append(result.types.begin(), result.types.end());
2734- // Parse the body region.
2735- auto *body = result.addRegion();
2726+
27362727 if (regionArgs.size() != argTypes.size())
27372728 return parser.emitError(
27382729 parser.getNameLoc(),
27392730 "mismatch in number of loop-carried values and defined values");
2731+
2732+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> privateOperands;
2733+ if (succeeded(parser.parseOptionalKeyword("private"))) {
2734+ std::size_t oldArgTypesSize = argTypes.size();
2735+ if (failed(parser.parseLParen()))
2736+ return mlir::failure();
2737+
2738+ llvm::SmallVector<mlir::SymbolRefAttr> privateSymbolVec;
2739+ if (failed(parser.parseCommaSeparatedList([&]() {
2740+ if (failed(parser.parseAttribute(privateSymbolVec.emplace_back())))
2741+ return mlir::failure();
2742+
2743+ if (parser.parseOperand(privateOperands.emplace_back()) ||
2744+ parser.parseArrow() ||
2745+ parser.parseArgument(regionArgs.emplace_back()))
2746+ return mlir::failure();
2747+
2748+ return mlir::success();
2749+ })))
2750+ return mlir::failure();
2751+
2752+ if (failed(parser.parseColon()))
2753+ return mlir::failure();
2754+
2755+ if (failed(parser.parseCommaSeparatedList([&]() {
2756+ if (failed(parser.parseType(argTypes.emplace_back())))
2757+ return mlir::failure();
2758+
2759+ return mlir::success();
2760+ })))
2761+ return mlir::failure();
2762+
2763+ if (regionArgs.size() != argTypes.size())
2764+ return parser.emitError(parser.getNameLoc(),
2765+ "mismatch in number of private arg and types");
2766+
2767+ if (failed(parser.parseRParen()))
2768+ return mlir::failure();
2769+
2770+ for (auto operandType : llvm::zip_equal(
2771+ privateOperands, llvm::drop_begin(argTypes, oldArgTypesSize)))
2772+ if (parser.resolveOperand(std::get<0>(operandType),
2773+ std::get<1>(operandType), result.operands))
2774+ return mlir::failure();
2775+
2776+ llvm::SmallVector<mlir::Attribute> symbolAttrs(privateSymbolVec.begin(),
2777+ privateSymbolVec.end());
2778+ result.addAttribute(getPrivateSymsAttrName(result.name),
2779+ builder.getArrayAttr(symbolAttrs));
2780+ }
2781+
2782+ if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
2783+ return mlir::failure();
2784+
2785+ // Set the operandSegmentSizes attribute
2786+ result.addAttribute(getOperandSegmentSizeAttr(),
2787+ builder.getDenseI32ArrayAttr(
2788+ {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
2789+ static_cast<int32_t>(iterOperands.size()),
2790+ static_cast<int32_t>(privateOperands.size())}));
2791+
27402792 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
27412793 regionArgs[i].type = argTypes[i];
27422794
2795+ // Parse the body region.
2796+ auto *body = result.addRegion();
27432797 if (parser.parseRegion(*body, regionArgs))
27442798 return mlir::failure();
27452799
@@ -2833,9 +2887,25 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
28332887 p << " -> " << getResultTypes();
28342888 printBlockTerminators = true;
28352889 }
2836- p.printOptionalAttrDictWithKeyword(
2837- (*this)->getAttrs(),
2838- {"unordered", "finalValue", "reduceAttrs", "operandSegmentSizes"});
2890+
2891+ if (numPrivateBlockArgs() > 0) {
2892+ p << " private(";
2893+ llvm::interleaveComma(llvm::zip_equal(getPrivateSymsAttr(),
2894+ getPrivateVars(),
2895+ getRegionPrivateArgs()),
2896+ p, [&](auto it) {
2897+ p << std::get<0>(it) << " " << std::get<1>(it)
2898+ << " -> " << std::get<2>(it);
2899+ });
2900+ p << " : ";
2901+ llvm::interleaveComma(getPrivateVars(), p,
2902+ [&](auto it) { p << it.getType(); });
2903+ p << ")";
2904+ }
2905+
2906+ p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
2907+ {"unordered", "finalValue", "reduceAttrs",
2908+ "operandSegmentSizes", "private_syms"});
28392909 p << ' ';
28402910 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
28412911 printBlockTerminators);
0 commit comments