@@ -2563,8 +2563,9 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
2563
2563
2564
2564
// Parse the optional initial iteration arguments.
2565
2565
llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
2566
- llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
2567
2566
llvm::SmallVector<mlir::Type> argTypes;
2567
+
2568
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
2568
2569
bool prependCount = false ;
2569
2570
regionArgs.push_back (inductionVariable);
2570
2571
@@ -2589,15 +2590,6 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
2589
2590
prependCount = true ;
2590
2591
}
2591
2592
2592
- // Set the operandSegmentSizes attribute
2593
- result.addAttribute (getOperandSegmentSizeAttr (),
2594
- builder.getDenseI32ArrayAttr (
2595
- {1 , 1 , 1 , static_cast <int32_t >(reduceOperands.size ()),
2596
- static_cast <int32_t >(iterOperands.size ()), 0 }));
2597
-
2598
- if (parser.parseOptionalAttrDictWithKeyword (result.attributes ))
2599
- return mlir::failure ();
2600
-
2601
2593
// Induction variable.
2602
2594
if (prependCount)
2603
2595
result.addAttribute (DoLoopOp::getFinalValueAttrName (result.name ),
@@ -2606,15 +2598,77 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
2606
2598
argTypes.push_back (indexType);
2607
2599
// Loop carried variables
2608
2600
argTypes.append (result.types .begin (), result.types .end ());
2609
- // Parse the body region.
2610
- auto *body = result.addRegion ();
2601
+
2611
2602
if (regionArgs.size () != argTypes.size ())
2612
2603
return parser.emitError (
2613
2604
parser.getNameLoc (),
2614
2605
" mismatch in number of loop-carried values and defined values" );
2606
+
2607
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> privateOperands;
2608
+ if (succeeded (parser.parseOptionalKeyword (" private" ))) {
2609
+ std::size_t oldArgTypesSize = argTypes.size ();
2610
+ if (failed (parser.parseLParen ()))
2611
+ return mlir::failure ();
2612
+
2613
+ llvm::SmallVector<mlir::SymbolRefAttr> privateSymbolVec;
2614
+ if (failed (parser.parseCommaSeparatedList ([&]() {
2615
+ if (failed (parser.parseAttribute (privateSymbolVec.emplace_back ())))
2616
+ return mlir::failure ();
2617
+
2618
+ if (parser.parseOperand (privateOperands.emplace_back ()) ||
2619
+ parser.parseArrow () ||
2620
+ parser.parseArgument (regionArgs.emplace_back ()))
2621
+ return mlir::failure ();
2622
+
2623
+ return mlir::success ();
2624
+ })))
2625
+ return mlir::failure ();
2626
+
2627
+ if (failed (parser.parseColon ()))
2628
+ return mlir::failure ();
2629
+
2630
+ if (failed (parser.parseCommaSeparatedList ([&]() {
2631
+ if (failed (parser.parseType (argTypes.emplace_back ())))
2632
+ return mlir::failure ();
2633
+
2634
+ return mlir::success ();
2635
+ })))
2636
+ return mlir::failure ();
2637
+
2638
+ if (regionArgs.size () != argTypes.size ())
2639
+ return parser.emitError (parser.getNameLoc (),
2640
+ " mismatch in number of private arg and types" );
2641
+
2642
+ if (failed (parser.parseRParen ()))
2643
+ return mlir::failure ();
2644
+
2645
+ for (auto operandType : llvm::zip_equal (
2646
+ privateOperands, llvm::drop_begin (argTypes, oldArgTypesSize)))
2647
+ if (parser.resolveOperand (std::get<0 >(operandType),
2648
+ std::get<1 >(operandType), result.operands ))
2649
+ return mlir::failure ();
2650
+
2651
+ llvm::SmallVector<mlir::Attribute> symbolAttrs (privateSymbolVec.begin (),
2652
+ privateSymbolVec.end ());
2653
+ result.addAttribute (getPrivateSymsAttrName (result.name ),
2654
+ builder.getArrayAttr (symbolAttrs));
2655
+ }
2656
+
2657
+ if (parser.parseOptionalAttrDictWithKeyword (result.attributes ))
2658
+ return mlir::failure ();
2659
+
2660
+ // Set the operandSegmentSizes attribute
2661
+ result.addAttribute (getOperandSegmentSizeAttr (),
2662
+ builder.getDenseI32ArrayAttr (
2663
+ {1 , 1 , 1 , static_cast <int32_t >(reduceOperands.size ()),
2664
+ static_cast <int32_t >(iterOperands.size ()),
2665
+ static_cast <int32_t >(privateOperands.size ())}));
2666
+
2615
2667
for (size_t i = 0 , e = regionArgs.size (); i != e; ++i)
2616
2668
regionArgs[i].type = argTypes[i];
2617
2669
2670
+ // Parse the body region.
2671
+ auto *body = result.addRegion ();
2618
2672
if (parser.parseRegion (*body, regionArgs))
2619
2673
return mlir::failure ();
2620
2674
@@ -2708,9 +2762,25 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
2708
2762
p << " -> " << getResultTypes ();
2709
2763
printBlockTerminators = true ;
2710
2764
}
2711
- p.printOptionalAttrDictWithKeyword (
2712
- (*this )->getAttrs (),
2713
- {" unordered" , " finalValue" , " reduceAttrs" , " operandSegmentSizes" });
2765
+
2766
+ if (numPrivateBlockArgs () > 0 ) {
2767
+ p << " private(" ;
2768
+ llvm::interleaveComma (llvm::zip_equal (getPrivateSymsAttr (),
2769
+ getPrivateVars (),
2770
+ getRegionPrivateArgs ()),
2771
+ p, [&](auto it) {
2772
+ p << std::get<0 >(it) << " " << std::get<1 >(it)
2773
+ << " -> " << std::get<2 >(it);
2774
+ });
2775
+ p << " : " ;
2776
+ llvm::interleaveComma (getPrivateVars (), p,
2777
+ [&](auto it) { p << it.getType (); });
2778
+ p << " )" ;
2779
+ }
2780
+
2781
+ p.printOptionalAttrDictWithKeyword ((*this )->getAttrs (),
2782
+ {" unordered" , " finalValue" , " reduceAttrs" ,
2783
+ " operandSegmentSizes" , " private_syms" });
2714
2784
p << ' ' ;
2715
2785
p.printRegion (getRegion (), /* printEntryBlockArgs=*/ false ,
2716
2786
printBlockTerminators);
0 commit comments