@@ -487,9 +487,11 @@ struct PrivateParseArgs {
487487 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
488488 llvm::SmallVectorImpl<Type> &types;
489489 ArrayAttr &syms;
490+ DenseI64ArrayAttr *mapIndices;
490491 PrivateParseArgs (SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
491- SmallVectorImpl<Type> &types, ArrayAttr &syms)
492- : vars(vars), types(types), syms(syms) {}
492+ SmallVectorImpl<Type> &types, ArrayAttr &syms,
493+ DenseI64ArrayAttr *mapIndices = nullptr )
494+ : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
493495};
494496struct ReductionParseArgs {
495497 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
@@ -517,8 +519,10 @@ static ParseResult parseClauseWithRegionArgs(
517519 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
518520 SmallVectorImpl<Type> &types,
519521 SmallVectorImpl<OpAsmParser::Argument> ®ionPrivateArgs,
520- ArrayAttr *symbols = nullptr , DenseBoolArrayAttr *byref = nullptr ) {
522+ ArrayAttr *symbols = nullptr , DenseI64ArrayAttr *mapIndices = nullptr ,
523+ DenseBoolArrayAttr *byref = nullptr ) {
521524 SmallVector<SymbolRefAttr> symbolVec;
525+ SmallVector<int64_t > mapIndicesVec;
522526 SmallVector<bool > isByRefVec;
523527 unsigned regionArgOffset = regionPrivateArgs.size ();
524528
@@ -538,6 +542,16 @@ static ParseResult parseClauseWithRegionArgs(
538542 parser.parseArgument (regionPrivateArgs.emplace_back ()))
539543 return failure ();
540544
545+ if (mapIndices) {
546+ if (parser.parseOptionalLSquare ().succeeded ()) {
547+ if (parser.parseKeyword (" map_idx" ) || parser.parseEqual () ||
548+ parser.parseInteger (mapIndicesVec.emplace_back ()) ||
549+ parser.parseRSquare ())
550+ return failure ();
551+ } else
552+ mapIndicesVec.push_back (-1 );
553+ }
554+
541555 return success ();
542556 }))
543557 return failure ();
@@ -571,6 +585,10 @@ static ParseResult parseClauseWithRegionArgs(
571585 *symbols = ArrayAttr::get (parser.getContext (), symbolAttrs);
572586 }
573587
588+ if (!mapIndicesVec.empty ())
589+ *mapIndices =
590+ mlir::DenseI64ArrayAttr::get (parser.getContext (), mapIndicesVec);
591+
574592 if (byref)
575593 *byref = makeDenseBoolArrayAttr (parser.getContext (), isByRefVec);
576594
@@ -595,14 +613,14 @@ static ParseResult parseBlockArgClause(
595613static ParseResult parseBlockArgClause (
596614 OpAsmParser &parser,
597615 llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
598- StringRef keyword, std::optional<PrivateParseArgs> reductionArgs ) {
616+ StringRef keyword, std::optional<PrivateParseArgs> privateArgs ) {
599617 if (succeeded (parser.parseOptionalKeyword (keyword))) {
600- if (!reductionArgs )
618+ if (!privateArgs )
601619 return failure ();
602620
603- if (failed (parseClauseWithRegionArgs (parser, reductionArgs-> vars ,
604- reductionArgs ->types , entryBlockArgs,
605- &reductionArgs ->syms )))
621+ if (failed (parseClauseWithRegionArgs (
622+ parser, privateArgs-> vars , privateArgs ->types , entryBlockArgs,
623+ &privateArgs ->syms , privateArgs-> mapIndices )))
606624 return failure ();
607625 }
608626 return success ();
@@ -618,7 +636,8 @@ static ParseResult parseBlockArgClause(
618636
619637 if (failed (parseClauseWithRegionArgs (
620638 parser, reductionArgs->vars , reductionArgs->types , entryBlockArgs,
621- &reductionArgs->syms , &reductionArgs->byref )))
639+ &reductionArgs->syms , /* mapIndices=*/ nullptr ,
640+ &reductionArgs->byref )))
622641 return failure ();
623642 }
624643 return success ();
@@ -674,12 +693,14 @@ static ParseResult parseInReductionMapPrivateRegion(
674693 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars,
675694 SmallVectorImpl<Type> &mapTypes,
676695 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
677- llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
696+ llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
697+ DenseI64ArrayAttr &privateMaps) {
678698 AllRegionParseArgs args;
679699 args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
680700 inReductionByref, inReductionSyms);
681701 args.mapArgs .emplace (mapVars, mapTypes);
682- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
702+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms,
703+ &privateMaps);
683704 return parseBlockArgRegion (parser, region, args);
684705}
685706
@@ -776,8 +797,10 @@ struct PrivatePrintArgs {
776797 ValueRange vars;
777798 TypeRange types;
778799 ArrayAttr syms;
779- PrivatePrintArgs (ValueRange vars, TypeRange types, ArrayAttr syms)
780- : vars(vars), types(types), syms(syms) {}
800+ DenseI64ArrayAttr mapIndices;
801+ PrivatePrintArgs (ValueRange vars, TypeRange types, ArrayAttr syms,
802+ DenseI64ArrayAttr mapIndices)
803+ : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
781804};
782805struct ReductionPrintArgs {
783806 ValueRange vars;
@@ -804,6 +827,7 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
804827 ValueRange argsSubrange,
805828 ValueRange operands, TypeRange types,
806829 ArrayAttr symbols = nullptr ,
830+ DenseI64ArrayAttr mapIndices = nullptr ,
807831 DenseBoolArrayAttr byref = nullptr ) {
808832 if (argsSubrange.empty ())
809833 return ;
@@ -815,21 +839,31 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
815839 symbols = ArrayAttr::get (ctx, values);
816840 }
817841
842+ if (!mapIndices) {
843+ llvm::SmallVector<int64_t > values (operands.size (), -1 );
844+ mapIndices = DenseI64ArrayAttr::get (ctx, values);
845+ }
846+
818847 if (!byref) {
819848 mlir::SmallVector<bool > values (operands.size (), false );
820849 byref = DenseBoolArrayAttr::get (ctx, values);
821850 }
822851
823- llvm::interleaveComma (
824- llvm::zip_equal (operands, argsSubrange, symbols, byref.asArrayRef ()), p,
825- [&p](auto t) {
826- auto [op, arg, sym, isByRef] = t;
827- if (isByRef)
828- p << " byref " ;
829- if (sym)
830- p << sym << " " ;
831- p << op << " -> " << arg;
832- });
852+ llvm::interleaveComma (llvm::zip_equal (operands, argsSubrange, symbols,
853+ mapIndices.asArrayRef (),
854+ byref.asArrayRef ()),
855+ p, [&p](auto t) {
856+ auto [op, arg, sym, map, isByRef] = t;
857+ if (isByRef)
858+ p << " byref " ;
859+ if (sym)
860+ p << sym << " " ;
861+
862+ p << op << " -> " << arg;
863+
864+ if (map != -1 )
865+ p << " [map_idx=" << map << " ]" ;
866+ });
833867 p << " : " ;
834868 llvm::interleaveComma (types, p);
835869 p << " ) " ;
@@ -849,7 +883,7 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
849883 if (privateArgs)
850884 printClauseWithRegionArgs (p, ctx, clauseName, argsSubrange,
851885 privateArgs->vars , privateArgs->types ,
852- privateArgs->syms );
886+ privateArgs->syms , privateArgs-> mapIndices );
853887}
854888
855889static void
@@ -859,7 +893,8 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
859893 if (reductionArgs)
860894 printClauseWithRegionArgs (p, ctx, clauseName, argsSubrange,
861895 reductionArgs->vars , reductionArgs->types ,
862- reductionArgs->syms , reductionArgs->byref );
896+ reductionArgs->syms , /* mapIndices=*/ nullptr ,
897+ reductionArgs->byref );
863898}
864899
865900static void printBlockArgRegion (OpAsmPrinter &p, Operation *op, Region ®ion,
@@ -891,12 +926,13 @@ static void printInReductionMapPrivateRegion(
891926 OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars,
892927 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
893928 ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
894- ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) {
929+ ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
930+ DenseI64ArrayAttr privateMaps) {
895931 AllRegionPrintArgs args;
896932 args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
897933 inReductionByref, inReductionSyms);
898934 args.mapArgs .emplace (mapVars, mapTypes);
899- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
935+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms, privateMaps );
900936 printBlockArgRegion (p, op, region, args);
901937}
902938
@@ -908,7 +944,8 @@ static void printInReductionPrivateRegion(
908944 AllRegionPrintArgs args;
909945 args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
910946 inReductionByref, inReductionSyms);
911- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
947+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms,
948+ /* mapIndices=*/ nullptr );
912949 printBlockArgRegion (p, op, region, args);
913950}
914951
@@ -921,7 +958,8 @@ static void printInReductionPrivateReductionRegion(
921958 AllRegionPrintArgs args;
922959 args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
923960 inReductionByref, inReductionSyms);
924- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
961+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms,
962+ /* mapIndices=*/ nullptr );
925963 args.reductionArgs .emplace (reductionVars, reductionTypes, reductionByref,
926964 reductionSyms);
927965 printBlockArgRegion (p, op, region, args);
@@ -931,7 +969,8 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
931969 ValueRange privateVars, TypeRange privateTypes,
932970 ArrayAttr privateSyms) {
933971 AllRegionPrintArgs args;
934- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
972+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms,
973+ /* mapIndices=*/ nullptr );
935974 printBlockArgRegion (p, op, region, args);
936975}
937976
@@ -941,7 +980,8 @@ static void printPrivateReductionRegion(
941980 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
942981 ArrayAttr reductionSyms) {
943982 AllRegionPrintArgs args;
944- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
983+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms,
984+ /* mapIndices=*/ nullptr );
945985 args.reductionArgs .emplace (reductionVars, reductionTypes, reductionByref,
946986 reductionSyms);
947987 printBlockArgRegion (p, op, region, args);
@@ -1656,7 +1696,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
16561696 /* in_reduction_vars=*/ {}, /* in_reduction_byref=*/ nullptr ,
16571697 /* in_reduction_syms=*/ nullptr , clauses.isDevicePtrVars ,
16581698 clauses.mapVars , clauses.nowait , clauses.privateVars ,
1659- makeArrayAttr (ctx, clauses.privateSyms ), clauses.threadLimit );
1699+ makeArrayAttr (ctx, clauses.privateSyms ), clauses.threadLimit ,
1700+ /* private_maps=*/ nullptr );
16601701}
16611702
16621703LogicalResult TargetOp::verify () {
0 commit comments