@@ -472,16 +472,20 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op,
472472// ===----------------------------------------------------------------------===//
473473
474474static ParseResult parseClauseWithRegionArgs (
475- OpAsmParser &parser, Region ®ion,
475+ OpAsmParser &parser,
476476 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
477477 SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref, ArrayAttr &symbols,
478- SmallVectorImpl<OpAsmParser::Argument> ®ionPrivateArgs) {
478+ SmallVectorImpl<OpAsmParser::Argument> ®ionPrivateArgs,
479+ bool parseParens = true ) {
479480 SmallVector<SymbolRefAttr> reductionVec;
480481 SmallVector<bool > isByRefVec;
481482 unsigned regionArgOffset = regionPrivateArgs.size ();
482483
484+ OpAsmParser::Delimiter delimiter = parseParens ? OpAsmParser::Delimiter::Paren
485+ : OpAsmParser::Delimiter::None;
486+
483487 if (failed (
484- parser.parseCommaSeparatedList (OpAsmParser::Delimiter::Paren , [&]() {
488+ parser.parseCommaSeparatedList (delimiter , [&]() {
485489 ParseResult optionalByref = parser.parseOptionalKeyword (" byref" );
486490 if (parser.parseAttribute (reductionVec.emplace_back ()) ||
487491 parser.parseOperand (operands.emplace_back ()) ||
@@ -536,15 +540,15 @@ static ParseResult parseParallelRegion(
536540 llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
537541
538542 if (succeeded (parser.parseOptionalKeyword (" reduction" ))) {
539- if (failed (parseClauseWithRegionArgs (parser, region, reductionVars,
543+ if (failed (parseClauseWithRegionArgs (parser, reductionVars,
540544 reductionTypes, reductionByref,
541545 reductionSyms, regionPrivateArgs)))
542546 return failure ();
543547 }
544548
545549 if (succeeded (parser.parseOptionalKeyword (" private" ))) {
546550 auto privateByref = DenseBoolArrayAttr::get (parser.getContext (), {});
547- if (failed (parseClauseWithRegionArgs (parser, region, privateVars,
551+ if (failed (parseClauseWithRegionArgs (parser, privateVars,
548552 privateTypes, privateByref,
549553 privateSyms, regionPrivateArgs)))
550554 return failure ();
@@ -597,48 +601,26 @@ static ParseResult parseReductionVarList(
597601 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
598602 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
599603 ArrayAttr &reductionSyms) {
600- SmallVector<SymbolRefAttr> reductionVec;
601- SmallVector<bool > isByRefVec;
602- if (failed (parser.parseCommaSeparatedList ([&]() {
603- ParseResult optionalByref = parser.parseOptionalKeyword (" byref" );
604- if (parser.parseAttribute (reductionVec.emplace_back ()) ||
605- parser.parseArrow () ||
606- parser.parseOperand (reductionVars.emplace_back ()) ||
607- parser.parseColonType (reductionTypes.emplace_back ()))
608- return failure ();
609- isByRefVec.push_back (optionalByref.succeeded ());
610- return success ();
611- })))
612- return failure ();
613- reductionByref = makeDenseBoolArrayAttr (parser.getContext (), isByRefVec);
614- SmallVector<Attribute> reductions (reductionVec.begin (), reductionVec.end ());
615- reductionSyms = ArrayAttr::get (parser.getContext (), reductions);
616- return success ();
604+ llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
605+ return parseClauseWithRegionArgs (parser, reductionVars, reductionTypes,
606+ reductionByref, reductionSyms,
607+ regionPrivateArgs, /* parseParens=*/ false );
617608}
618609
619610// / Print Reduction clause
620- static void
621- printReductionVarList (OpAsmPrinter &p, Operation *op,
622- OperandRange reductionVars, TypeRange reductionTypes,
623- std::optional<DenseBoolArrayAttr> reductionByref,
624- std::optional<ArrayAttr> reductionSyms) {
625- auto getByRef = [&](unsigned i) -> const char * {
626- if (!reductionByref || !*reductionByref)
627- return " " ;
628- assert (reductionByref->empty () || i < reductionByref->size ());
629- if (!reductionByref->empty () && (*reductionByref)[i])
630- return " byref " ;
631- return " " ;
632- };
633-
634- for (unsigned i = 0 , e = reductionVars.size (); i < e; ++i) {
635- if (i != 0 )
636- p << " , " ;
637- p << getByRef (i) << (*reductionSyms)[i] << " -> " << reductionVars[i]
638- << " : " << reductionVars[i].getType ();
611+ static void printReductionVarList (OpAsmPrinter &p, Operation *op,
612+ OperandRange reductionVars,
613+ TypeRange reductionTypes,
614+ DenseBoolArrayAttr reductionByref,
615+ ArrayAttr reductionSyms) {
616+ if (reductionSyms) {
617+ auto *argsBegin = op->getRegion (0 ).front ().getArguments ().begin ();
618+ MutableArrayRef argsSubrange (argsBegin, argsBegin + reductionTypes.size ());
619+ printClauseWithRegionArgs (p, op, argsSubrange, llvm::StringRef (),
620+ reductionVars, reductionTypes, reductionByref,
621+ reductionSyms);
639622 }
640623}
641-
642624// / Verifies Reduction Clause
643625static LogicalResult
644626verifyReductionVarList (Operation *op, std::optional<ArrayAttr> reductionSyms,
@@ -1824,7 +1806,7 @@ parseWsloop(OpAsmParser &parser, Region ®ion,
18241806 // Parse an optional reduction clause
18251807 llvm::SmallVector<OpAsmParser::Argument> privates;
18261808 if (succeeded (parser.parseOptionalKeyword (" reduction" ))) {
1827- if (failed (parseClauseWithRegionArgs (parser, region, reductionOperands,
1809+ if (failed (parseClauseWithRegionArgs (parser, reductionOperands,
18281810 reductionTypes, reductionByRef,
18291811 reductionSymbols, privates)))
18301812 return failure ();
0 commit comments