@@ -4886,29 +4886,33 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
48864886 mlir::OperationState &result) {
48874887 auto &builder = parser.getBuilder ();
48884888 // Parse an opening `(` followed by induction variables followed by `)`
4889- llvm::SmallVector<mlir::OpAsmParser::Argument, 4 > ivs;
4890- if (parser.parseArgumentList (ivs, mlir::OpAsmParser::Delimiter::Paren))
4889+ llvm::SmallVector<mlir::OpAsmParser::Argument, 4 > regionArgs;
4890+
4891+ if (parser.parseArgumentList (regionArgs, mlir::OpAsmParser::Delimiter::Paren))
48914892 return mlir::failure ();
48924893
4894+ llvm::SmallVector<mlir::Type> argTypes (regionArgs.size (),
4895+ builder.getIndexType ());
4896+
48934897 // Parse loop bounds.
48944898 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4 > lower;
48954899 if (parser.parseEqual () ||
4896- parser.parseOperandList (lower, ivs .size (),
4900+ parser.parseOperandList (lower, regionArgs .size (),
48974901 mlir::OpAsmParser::Delimiter::Paren) ||
48984902 parser.resolveOperands (lower, builder.getIndexType (), result.operands ))
48994903 return mlir::failure ();
49004904
49014905 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4 > upper;
49024906 if (parser.parseKeyword (" to" ) ||
4903- parser.parseOperandList (upper, ivs .size (),
4907+ parser.parseOperandList (upper, regionArgs .size (),
49044908 mlir::OpAsmParser::Delimiter::Paren) ||
49054909 parser.resolveOperands (upper, builder.getIndexType (), result.operands ))
49064910 return mlir::failure ();
49074911
49084912 // Parse step values.
49094913 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4 > steps;
49104914 if (parser.parseKeyword (" step" ) ||
4911- parser.parseOperandList (steps, ivs .size (),
4915+ parser.parseOperandList (steps, regionArgs .size (),
49124916 mlir::OpAsmParser::Delimiter::Paren) ||
49134917 parser.resolveOperands (steps, builder.getIndexType (), result.operands ))
49144918 return mlir::failure ();
@@ -4939,20 +4943,72 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
49394943 builder.getArrayAttr (arrayAttr));
49404944 }
49414945
4942- // Now parse the body.
4943- mlir::Region *body = result.addRegion ();
4944- for (auto &iv : ivs)
4945- iv.type = builder.getIndexType ();
4946- if (parser.parseRegion (*body, ivs))
4947- return mlir::failure ();
4946+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> privateOperands;
4947+ if (succeeded (parser.parseOptionalKeyword (" private" ))) {
4948+ std::size_t oldArgTypesSize = argTypes.size ();
4949+ if (failed (parser.parseLParen ()))
4950+ return mlir::failure ();
4951+
4952+ llvm::SmallVector<mlir::SymbolRefAttr> privateSymbolVec;
4953+ if (failed (parser.parseCommaSeparatedList ([&]() {
4954+ if (failed (parser.parseAttribute (privateSymbolVec.emplace_back ())))
4955+ return mlir::failure ();
4956+
4957+ if (parser.parseOperand (privateOperands.emplace_back ()) ||
4958+ parser.parseArrow () ||
4959+ parser.parseArgument (regionArgs.emplace_back ()))
4960+ return mlir::failure ();
4961+
4962+ return mlir::success ();
4963+ })))
4964+ return mlir::failure ();
4965+
4966+ if (failed (parser.parseColon ()))
4967+ return mlir::failure ();
4968+
4969+ if (failed (parser.parseCommaSeparatedList ([&]() {
4970+ if (failed (parser.parseType (argTypes.emplace_back ())))
4971+ return mlir::failure ();
4972+
4973+ return mlir::success ();
4974+ })))
4975+ return mlir::failure ();
4976+
4977+ if (regionArgs.size () != argTypes.size ())
4978+ return parser.emitError (parser.getNameLoc (),
4979+ " mismatch in number of private arg and types" );
4980+
4981+ if (failed (parser.parseRParen ()))
4982+ return mlir::failure ();
4983+
4984+ for (auto operandType : llvm::zip_equal (
4985+ privateOperands, llvm::drop_begin (argTypes, oldArgTypesSize)))
4986+ if (parser.resolveOperand (std::get<0 >(operandType),
4987+ std::get<1 >(operandType), result.operands ))
4988+ return mlir::failure ();
4989+
4990+ llvm::SmallVector<mlir::Attribute> symbolAttrs (privateSymbolVec.begin (),
4991+ privateSymbolVec.end ());
4992+ result.addAttribute (getPrivateSymsAttrName (result.name ),
4993+ builder.getArrayAttr (symbolAttrs));
4994+ }
49484995
49494996 // Set `operandSegmentSizes` attribute.
49504997 result.addAttribute (DoConcurrentLoopOp::getOperandSegmentSizeAttr (),
49514998 builder.getDenseI32ArrayAttr (
49524999 {static_cast <int32_t >(lower.size ()),
49535000 static_cast <int32_t >(upper.size ()),
49545001 static_cast <int32_t >(steps.size ()),
4955- static_cast <int32_t >(reduceOperands.size ()), 0 }));
5002+ static_cast <int32_t >(reduceOperands.size ()),
5003+ static_cast <int32_t >(privateOperands.size ())}));
5004+
5005+ // Now parse the body.
5006+ for (auto [arg, type] : llvm::zip_equal (regionArgs, argTypes))
5007+ arg.type = type;
5008+
5009+ mlir::Region *body = result.addRegion ();
5010+ if (parser.parseRegion (*body, regionArgs))
5011+ return mlir::failure ();
49565012
49575013 // Parse attributes.
49585014 if (parser.parseOptionalAttrDict (result.attributes ))
@@ -4962,8 +5018,9 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
49625018}
49635019
49645020void fir::DoConcurrentLoopOp::print (mlir::OpAsmPrinter &p) {
4965- p << " (" << getBody ()->getArguments () << " ) = (" << getLowerBound ()
4966- << " ) to (" << getUpperBound () << " ) step (" << getStep () << " )" ;
5021+ p << " (" << getBody ()->getArguments ().slice (0 , getNumInductionVars ())
5022+ << " ) = (" << getLowerBound () << " ) to (" << getUpperBound () << " ) step ("
5023+ << getStep () << " )" ;
49675024
49685025 if (!getReduceOperands ().empty ()) {
49695026 p << " reduce(" ;
@@ -4976,12 +5033,28 @@ void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) {
49765033 p << ' )' ;
49775034 }
49785035
5036+ if (!getPrivateVars ().empty ()) {
5037+ p << " private(" ;
5038+ llvm::interleaveComma (llvm::zip_equal (getPrivateSymsAttr (),
5039+ getPrivateVars (),
5040+ getRegionPrivateArgs ()),
5041+ p, [&](auto it) {
5042+ p << std::get<0 >(it) << " " << std::get<1 >(it)
5043+ << " -> " << std::get<2 >(it);
5044+ });
5045+ p << " : " ;
5046+ llvm::interleaveComma (getPrivateVars (), p,
5047+ [&](auto it) { p << it.getType (); });
5048+ p << " )" ;
5049+ }
5050+
49795051 p << ' ' ;
49805052 p.printRegion (getRegion (), /* printEntryBlockArgs=*/ false );
49815053 p.printOptionalAttrDict (
49825054 (*this )->getAttrs (),
49835055 /* elidedAttrs=*/ {DoConcurrentLoopOp::getOperandSegmentSizeAttr (),
4984- DoConcurrentLoopOp::getReduceAttrsAttrName ()});
5056+ DoConcurrentLoopOp::getReduceAttrsAttrName (),
5057+ DoConcurrentLoopOp::getPrivateSymsAttrName ()});
49855058}
49865059
49875060llvm::SmallVector<mlir::Region *> fir::DoConcurrentLoopOp::getLoopRegions () {
@@ -4992,6 +5065,7 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
49925065 mlir::Operation::operand_range lbValues = getLowerBound ();
49935066 mlir::Operation::operand_range ubValues = getUpperBound ();
49945067 mlir::Operation::operand_range stepValues = getStep ();
5068+ mlir::Operation::operand_range privateVars = getPrivateVars ();
49955069
49965070 if (lbValues.empty ())
49975071 return emitOpError (
@@ -5005,11 +5079,13 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
50055079 // Check that the body defines the same number of block arguments as the
50065080 // number of tuple elements in step.
50075081 mlir::Block *body = getBody ();
5008- if (body->getNumArguments () != stepValues.size ())
5082+ unsigned numIndVarArgs = body->getNumArguments () - privateVars.size ();
5083+
5084+ if (numIndVarArgs != stepValues.size ())
50095085 return emitOpError () << " expects the same number of induction variables: "
50105086 << body->getNumArguments ()
50115087 << " as bound and step values: " << stepValues.size ();
5012- for (auto arg : body->getArguments ())
5088+ for (auto arg : body->getArguments (). slice ( 0 , numIndVarArgs) )
50135089 if (!arg.getType ().isIndex ())
50145090 return emitOpError (
50155091 " expects arguments for the induction variable to be of index type" );
@@ -5024,7 +5100,8 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
50245100
50255101std::optional<llvm::SmallVector<mlir::Value>>
50265102fir::DoConcurrentLoopOp::getLoopInductionVars () {
5027- return llvm::SmallVector<mlir::Value>{getBody ()->getArguments ()};
5103+ return llvm::SmallVector<mlir::Value>{
5104+ getBody ()->getArguments ().slice (0 , getLowerBound ().size ())};
50285105}
50295106
50305107// ===----------------------------------------------------------------------===//
0 commit comments