@@ -2688,8 +2688,9 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
2688
2688
2689
2689
// Parse the optional initial iteration arguments.
2690
2690
llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
2691
- llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
2692
2691
llvm::SmallVector<mlir::Type> argTypes;
2692
+
2693
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
2693
2694
bool prependCount = false ;
2694
2695
regionArgs.push_back (inductionVariable);
2695
2696
@@ -2714,15 +2715,6 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
2714
2715
prependCount = true ;
2715
2716
}
2716
2717
2717
- // Set the operandSegmentSizes attribute
2718
- result.addAttribute (getOperandSegmentSizeAttr (),
2719
- builder.getDenseI32ArrayAttr (
2720
- {1 , 1 , 1 , static_cast <int32_t >(reduceOperands.size ()),
2721
- static_cast <int32_t >(iterOperands.size ()), 0 }));
2722
-
2723
- if (parser.parseOptionalAttrDictWithKeyword (result.attributes ))
2724
- return mlir::failure ();
2725
-
2726
2718
// Induction variable.
2727
2719
if (prependCount)
2728
2720
result.addAttribute (DoLoopOp::getFinalValueAttrName (result.name ),
@@ -2731,15 +2723,77 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
2731
2723
argTypes.push_back (indexType);
2732
2724
// Loop carried variables
2733
2725
argTypes.append (result.types .begin (), result.types .end ());
2734
- // Parse the body region.
2735
- auto *body = result.addRegion ();
2726
+
2736
2727
if (regionArgs.size () != argTypes.size ())
2737
2728
return parser.emitError (
2738
2729
parser.getNameLoc (),
2739
2730
" mismatch in number of loop-carried values and defined values" );
2731
+
2732
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> privateOperands;
2733
+ if (succeeded (parser.parseOptionalKeyword (" private" ))) {
2734
+ std::size_t oldArgTypesSize = argTypes.size ();
2735
+ if (failed (parser.parseLParen ()))
2736
+ return mlir::failure ();
2737
+
2738
+ llvm::SmallVector<mlir::SymbolRefAttr> privateSymbolVec;
2739
+ if (failed (parser.parseCommaSeparatedList ([&]() {
2740
+ if (failed (parser.parseAttribute (privateSymbolVec.emplace_back ())))
2741
+ return mlir::failure ();
2742
+
2743
+ if (parser.parseOperand (privateOperands.emplace_back ()) ||
2744
+ parser.parseArrow () ||
2745
+ parser.parseArgument (regionArgs.emplace_back ()))
2746
+ return mlir::failure ();
2747
+
2748
+ return mlir::success ();
2749
+ })))
2750
+ return mlir::failure ();
2751
+
2752
+ if (failed (parser.parseColon ()))
2753
+ return mlir::failure ();
2754
+
2755
+ if (failed (parser.parseCommaSeparatedList ([&]() {
2756
+ if (failed (parser.parseType (argTypes.emplace_back ())))
2757
+ return mlir::failure ();
2758
+
2759
+ return mlir::success ();
2760
+ })))
2761
+ return mlir::failure ();
2762
+
2763
+ if (regionArgs.size () != argTypes.size ())
2764
+ return parser.emitError (parser.getNameLoc (),
2765
+ " mismatch in number of private arg and types" );
2766
+
2767
+ if (failed (parser.parseRParen ()))
2768
+ return mlir::failure ();
2769
+
2770
+ for (auto operandType : llvm::zip_equal (
2771
+ privateOperands, llvm::drop_begin (argTypes, oldArgTypesSize)))
2772
+ if (parser.resolveOperand (std::get<0 >(operandType),
2773
+ std::get<1 >(operandType), result.operands ))
2774
+ return mlir::failure ();
2775
+
2776
+ llvm::SmallVector<mlir::Attribute> symbolAttrs (privateSymbolVec.begin (),
2777
+ privateSymbolVec.end ());
2778
+ result.addAttribute (getPrivateSymsAttrName (result.name ),
2779
+ builder.getArrayAttr (symbolAttrs));
2780
+ }
2781
+
2782
+ if (parser.parseOptionalAttrDictWithKeyword (result.attributes ))
2783
+ return mlir::failure ();
2784
+
2785
+ // Set the operandSegmentSizes attribute
2786
+ result.addAttribute (getOperandSegmentSizeAttr (),
2787
+ builder.getDenseI32ArrayAttr (
2788
+ {1 , 1 , 1 , static_cast <int32_t >(reduceOperands.size ()),
2789
+ static_cast <int32_t >(iterOperands.size ()),
2790
+ static_cast <int32_t >(privateOperands.size ())}));
2791
+
2740
2792
for (size_t i = 0 , e = regionArgs.size (); i != e; ++i)
2741
2793
regionArgs[i].type = argTypes[i];
2742
2794
2795
+ // Parse the body region.
2796
+ auto *body = result.addRegion ();
2743
2797
if (parser.parseRegion (*body, regionArgs))
2744
2798
return mlir::failure ();
2745
2799
@@ -2833,9 +2887,25 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
2833
2887
p << " -> " << getResultTypes ();
2834
2888
printBlockTerminators = true ;
2835
2889
}
2836
- p.printOptionalAttrDictWithKeyword (
2837
- (*this )->getAttrs (),
2838
- {" unordered" , " finalValue" , " reduceAttrs" , " operandSegmentSizes" });
2890
+
2891
+ if (numPrivateBlockArgs () > 0 ) {
2892
+ p << " private(" ;
2893
+ llvm::interleaveComma (llvm::zip_equal (getPrivateSymsAttr (),
2894
+ getPrivateVars (),
2895
+ getRegionPrivateArgs ()),
2896
+ p, [&](auto it) {
2897
+ p << std::get<0 >(it) << " " << std::get<1 >(it)
2898
+ << " -> " << std::get<2 >(it);
2899
+ });
2900
+ p << " : " ;
2901
+ llvm::interleaveComma (getPrivateVars (), p,
2902
+ [&](auto it) { p << it.getType (); });
2903
+ p << " )" ;
2904
+ }
2905
+
2906
+ p.printOptionalAttrDictWithKeyword ((*this )->getAttrs (),
2907
+ {" unordered" , " finalValue" , " reduceAttrs" ,
2908
+ " operandSegmentSizes" , " private_syms" });
2839
2909
p << ' ' ;
2840
2910
p.printRegion (getRegion (), /* printEntryBlockArgs=*/ false ,
2841
2911
printBlockTerminators);
0 commit comments