@@ -5033,29 +5033,33 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
50335033 mlir::OperationState &result) {
50345034 auto &builder = parser.getBuilder ();
50355035 // Parse an opening `(` followed by induction variables followed by `)`
5036- llvm::SmallVector<mlir::OpAsmParser::Argument, 4 > ivs;
5037- if (parser.parseArgumentList (ivs, mlir::OpAsmParser::Delimiter::Paren))
5036+ llvm::SmallVector<mlir::OpAsmParser::Argument, 4 > regionArgs;
5037+
5038+ if (parser.parseArgumentList (regionArgs, mlir::OpAsmParser::Delimiter::Paren))
50385039 return mlir::failure ();
50395040
5041+ llvm::SmallVector<mlir::Type> argTypes (regionArgs.size (),
5042+ builder.getIndexType ());
5043+
50405044 // Parse loop bounds.
50415045 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4 > lower;
50425046 if (parser.parseEqual () ||
5043- parser.parseOperandList (lower, ivs .size (),
5047+ parser.parseOperandList (lower, regionArgs .size (),
50445048 mlir::OpAsmParser::Delimiter::Paren) ||
50455049 parser.resolveOperands (lower, builder.getIndexType (), result.operands ))
50465050 return mlir::failure ();
50475051
50485052 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4 > upper;
50495053 if (parser.parseKeyword (" to" ) ||
5050- parser.parseOperandList (upper, ivs .size (),
5054+ parser.parseOperandList (upper, regionArgs .size (),
50515055 mlir::OpAsmParser::Delimiter::Paren) ||
50525056 parser.resolveOperands (upper, builder.getIndexType (), result.operands ))
50535057 return mlir::failure ();
50545058
50555059 // Parse step values.
50565060 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4 > steps;
50575061 if (parser.parseKeyword (" step" ) ||
5058- parser.parseOperandList (steps, ivs .size (),
5062+ parser.parseOperandList (steps, regionArgs .size (),
50595063 mlir::OpAsmParser::Delimiter::Paren) ||
50605064 parser.resolveOperands (steps, builder.getIndexType (), result.operands ))
50615065 return mlir::failure ();
@@ -5086,20 +5090,72 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
50865090 builder.getArrayAttr (arrayAttr));
50875091 }
50885092
5089- // Now parse the body.
5090- mlir::Region *body = result.addRegion ();
5091- for (auto &iv : ivs)
5092- iv.type = builder.getIndexType ();
5093- if (parser.parseRegion (*body, ivs))
5094- return mlir::failure ();
5093+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> localOperands;
5094+ if (succeeded (parser.parseOptionalKeyword (" local" ))) {
5095+ std::size_t oldArgTypesSize = argTypes.size ();
5096+ if (failed (parser.parseLParen ()))
5097+ return mlir::failure ();
5098+
5099+ llvm::SmallVector<mlir::SymbolRefAttr> localSymbolVec;
5100+ if (failed (parser.parseCommaSeparatedList ([&]() {
5101+ if (failed (parser.parseAttribute (localSymbolVec.emplace_back ())))
5102+ return mlir::failure ();
5103+
5104+ if (parser.parseOperand (localOperands.emplace_back ()) ||
5105+ parser.parseArrow () ||
5106+ parser.parseArgument (regionArgs.emplace_back ()))
5107+ return mlir::failure ();
5108+
5109+ return mlir::success ();
5110+ })))
5111+ return mlir::failure ();
5112+
5113+ if (failed (parser.parseColon ()))
5114+ return mlir::failure ();
5115+
5116+ if (failed (parser.parseCommaSeparatedList ([&]() {
5117+ if (failed (parser.parseType (argTypes.emplace_back ())))
5118+ return mlir::failure ();
5119+
5120+ return mlir::success ();
5121+ })))
5122+ return mlir::failure ();
5123+
5124+ if (regionArgs.size () != argTypes.size ())
5125+ return parser.emitError (parser.getNameLoc (),
5126+ " mismatch in number of local arg and types" );
5127+
5128+ if (failed (parser.parseRParen ()))
5129+ return mlir::failure ();
5130+
5131+ for (auto operandType : llvm::zip_equal (
5132+ localOperands, llvm::drop_begin (argTypes, oldArgTypesSize)))
5133+ if (parser.resolveOperand (std::get<0 >(operandType),
5134+ std::get<1 >(operandType), result.operands ))
5135+ return mlir::failure ();
5136+
5137+ llvm::SmallVector<mlir::Attribute> symbolAttrs (localSymbolVec.begin (),
5138+ localSymbolVec.end ());
5139+ result.addAttribute (getLocalSymsAttrName (result.name ),
5140+ builder.getArrayAttr (symbolAttrs));
5141+ }
50955142
50965143 // Set `operandSegmentSizes` attribute.
50975144 result.addAttribute (DoConcurrentLoopOp::getOperandSegmentSizeAttr (),
50985145 builder.getDenseI32ArrayAttr (
50995146 {static_cast <int32_t >(lower.size ()),
51005147 static_cast <int32_t >(upper.size ()),
51015148 static_cast <int32_t >(steps.size ()),
5102- static_cast <int32_t >(reduceOperands.size ())}));
5149+ static_cast <int32_t >(reduceOperands.size ()),
5150+ static_cast <int32_t >(localOperands.size ())}));
5151+
5152+ // Now parse the body.
5153+ for (auto [arg, type] : llvm::zip_equal (regionArgs, argTypes))
5154+ arg.type = type;
5155+
5156+ mlir::Region *body = result.addRegion ();
5157+ if (parser.parseRegion (*body, regionArgs))
5158+ return mlir::failure ();
51035159
51045160 // Parse attributes.
51055161 if (parser.parseOptionalAttrDict (result.attributes ))
@@ -5109,8 +5165,9 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
51095165}
51105166
51115167void fir::DoConcurrentLoopOp::print (mlir::OpAsmPrinter &p) {
5112- p << " (" << getBody ()->getArguments () << " ) = (" << getLowerBound ()
5113- << " ) to (" << getUpperBound () << " ) step (" << getStep () << " )" ;
5168+ p << " (" << getBody ()->getArguments ().slice (0 , getNumInductionVars ())
5169+ << " ) = (" << getLowerBound () << " ) to (" << getUpperBound () << " ) step ("
5170+ << getStep () << " )" ;
51145171
51155172 if (!getReduceOperands ().empty ()) {
51165173 p << " reduce(" ;
@@ -5123,12 +5180,27 @@ void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) {
51235180 p << ' )' ;
51245181 }
51255182
5183+ if (!getLocalVars ().empty ()) {
5184+ p << " local(" ;
5185+ llvm::interleaveComma (llvm::zip_equal (getLocalSymsAttr (), getLocalVars (),
5186+ getRegionLocalArgs ()),
5187+ p, [&](auto it) {
5188+ p << std::get<0 >(it) << " " << std::get<1 >(it)
5189+ << " -> " << std::get<2 >(it);
5190+ });
5191+ p << " : " ;
5192+ llvm::interleaveComma (getLocalVars (), p,
5193+ [&](auto it) { p << it.getType (); });
5194+ p << " )" ;
5195+ }
5196+
51265197 p << ' ' ;
51275198 p.printRegion (getRegion (), /* printEntryBlockArgs=*/ false );
51285199 p.printOptionalAttrDict (
51295200 (*this )->getAttrs (),
51305201 /* elidedAttrs=*/ {DoConcurrentLoopOp::getOperandSegmentSizeAttr (),
5131- DoConcurrentLoopOp::getReduceAttrsAttrName ()});
5202+ DoConcurrentLoopOp::getReduceAttrsAttrName (),
5203+ DoConcurrentLoopOp::getLocalSymsAttrName ()});
51325204}
51335205
51345206llvm::SmallVector<mlir::Region *> fir::DoConcurrentLoopOp::getLoopRegions () {
@@ -5139,6 +5211,7 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
51395211 mlir::Operation::operand_range lbValues = getLowerBound ();
51405212 mlir::Operation::operand_range ubValues = getUpperBound ();
51415213 mlir::Operation::operand_range stepValues = getStep ();
5214+ mlir::Operation::operand_range localVars = getLocalVars ();
51425215
51435216 if (lbValues.empty ())
51445217 return emitOpError (
@@ -5152,11 +5225,13 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
51525225 // Check that the body defines the same number of block arguments as the
51535226 // number of tuple elements in step.
51545227 mlir::Block *body = getBody ();
5155- if (body->getNumArguments () != stepValues.size ())
5228+ unsigned numIndVarArgs = body->getNumArguments () - localVars.size ();
5229+
5230+ if (numIndVarArgs != stepValues.size ())
51565231 return emitOpError () << " expects the same number of induction variables: "
51575232 << body->getNumArguments ()
51585233 << " as bound and step values: " << stepValues.size ();
5159- for (auto arg : body->getArguments ())
5234+ for (auto arg : body->getArguments (). slice ( 0 , numIndVarArgs) )
51605235 if (!arg.getType ().isIndex ())
51615236 return emitOpError (
51625237 " expects arguments for the induction variable to be of index type" );
@@ -5171,7 +5246,8 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
51715246
51725247std::optional<llvm::SmallVector<mlir::Value>>
51735248fir::DoConcurrentLoopOp::getLoopInductionVars () {
5174- return llvm::SmallVector<mlir::Value>{getBody ()->getArguments ()};
5249+ return llvm::SmallVector<mlir::Value>{
5250+ getBody ()->getArguments ().slice (0 , getLowerBound ().size ())};
51755251}
51765252
51775253// ===----------------------------------------------------------------------===//
0 commit comments