@@ -2563,8 +2563,9 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
25632563
25642564 // Parse the optional initial iteration arguments.
25652565 llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
2566- llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
25672566 llvm::SmallVector<mlir::Type> argTypes;
2567+
2568+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
25682569 bool prependCount = false ;
25692570 regionArgs.push_back (inductionVariable);
25702571
@@ -2589,15 +2590,6 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
25892590 prependCount = true ;
25902591 }
25912592
2592- // Set the operandSegmentSizes attribute
2593- result.addAttribute (getOperandSegmentSizeAttr (),
2594- builder.getDenseI32ArrayAttr (
2595- {1 , 1 , 1 , static_cast <int32_t >(reduceOperands.size ()),
2596- static_cast <int32_t >(iterOperands.size ()), 0 }));
2597-
2598- if (parser.parseOptionalAttrDictWithKeyword (result.attributes ))
2599- return mlir::failure ();
2600-
26012593 // Induction variable.
26022594 if (prependCount)
26032595 result.addAttribute (DoLoopOp::getFinalValueAttrName (result.name ),
@@ -2606,15 +2598,77 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
26062598 argTypes.push_back (indexType);
26072599 // Loop carried variables
26082600 argTypes.append (result.types .begin (), result.types .end ());
2609- // Parse the body region.
2610- auto *body = result.addRegion ();
2601+
26112602 if (regionArgs.size () != argTypes.size ())
26122603 return parser.emitError (
26132604 parser.getNameLoc (),
26142605 " mismatch in number of loop-carried values and defined values" );
2606+
2607+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> privateOperands;
2608+ if (succeeded (parser.parseOptionalKeyword (" private" ))) {
2609+ std::size_t oldArgTypesSize = argTypes.size ();
2610+ if (failed (parser.parseLParen ()))
2611+ return mlir::failure ();
2612+
2613+ llvm::SmallVector<mlir::SymbolRefAttr> privateSymbolVec;
2614+ if (failed (parser.parseCommaSeparatedList ([&]() {
2615+ if (failed (parser.parseAttribute (privateSymbolVec.emplace_back ())))
2616+ return mlir::failure ();
2617+
2618+ if (parser.parseOperand (privateOperands.emplace_back ()) ||
2619+ parser.parseArrow () ||
2620+ parser.parseArgument (regionArgs.emplace_back ()))
2621+ return mlir::failure ();
2622+
2623+ return mlir::success ();
2624+ })))
2625+ return mlir::failure ();
2626+
2627+ if (failed (parser.parseColon ()))
2628+ return mlir::failure ();
2629+
2630+ if (failed (parser.parseCommaSeparatedList ([&]() {
2631+ if (failed (parser.parseType (argTypes.emplace_back ())))
2632+ return mlir::failure ();
2633+
2634+ return mlir::success ();
2635+ })))
2636+ return mlir::failure ();
2637+
2638+ if (regionArgs.size () != argTypes.size ())
2639+ return parser.emitError (parser.getNameLoc (),
2640+ " mismatch in number of private arg and types" );
2641+
2642+ if (failed (parser.parseRParen ()))
2643+ return mlir::failure ();
2644+
2645+ for (auto operandType : llvm::zip_equal (
2646+ privateOperands, llvm::drop_begin (argTypes, oldArgTypesSize)))
2647+ if (parser.resolveOperand (std::get<0 >(operandType),
2648+ std::get<1 >(operandType), result.operands ))
2649+ return mlir::failure ();
2650+
2651+ llvm::SmallVector<mlir::Attribute> symbolAttrs (privateSymbolVec.begin (),
2652+ privateSymbolVec.end ());
2653+ result.addAttribute (getPrivateSymsAttrName (result.name ),
2654+ builder.getArrayAttr (symbolAttrs));
2655+ }
2656+
2657+ if (parser.parseOptionalAttrDictWithKeyword (result.attributes ))
2658+ return mlir::failure ();
2659+
2660+ // Set the operandSegmentSizes attribute
2661+ result.addAttribute (getOperandSegmentSizeAttr (),
2662+ builder.getDenseI32ArrayAttr (
2663+ {1 , 1 , 1 , static_cast <int32_t >(reduceOperands.size ()),
2664+ static_cast <int32_t >(iterOperands.size ()),
2665+ static_cast <int32_t >(privateOperands.size ())}));
2666+
26152667 for (size_t i = 0 , e = regionArgs.size (); i != e; ++i)
26162668 regionArgs[i].type = argTypes[i];
26172669
2670+ // Parse the body region.
2671+ auto *body = result.addRegion ();
26182672 if (parser.parseRegion (*body, regionArgs))
26192673 return mlir::failure ();
26202674
@@ -2708,9 +2762,25 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
27082762 p << " -> " << getResultTypes ();
27092763 printBlockTerminators = true ;
27102764 }
2711- p.printOptionalAttrDictWithKeyword (
2712- (*this )->getAttrs (),
2713- {" unordered" , " finalValue" , " reduceAttrs" , " operandSegmentSizes" });
2765+
2766+ if (numPrivateBlockArgs () > 0 ) {
2767+ p << " private(" ;
2768+ llvm::interleaveComma (llvm::zip_equal (getPrivateSymsAttr (),
2769+ getPrivateVars (),
2770+ getRegionPrivateArgs ()),
2771+ p, [&](auto it) {
2772+ p << std::get<0 >(it) << " " << std::get<1 >(it)
2773+ << " -> " << std::get<2 >(it);
2774+ });
2775+ p << " : " ;
2776+ llvm::interleaveComma (getPrivateVars (), p,
2777+ [&](auto it) { p << it.getType (); });
2778+ p << " )" ;
2779+ }
2780+
2781+ p.printOptionalAttrDictWithKeyword ((*this )->getAttrs (),
2782+ {" unordered" , " finalValue" , " reduceAttrs" ,
2783+ " operandSegmentSizes" , " private_syms" });
27142784 p << ' ' ;
27152785 p.printRegion (getRegion (), /* printEntryBlockArgs=*/ false ,
27162786 printBlockTerminators);
0 commit comments