1515#include " mlir/Dialect/Func/IR/FuncOps.h"
1616#include " mlir/Dialect/LLVMIR/LLVMTypes.h"
1717#include " mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h"
18+ #include " mlir/Dialect/OpenMP/Utils.h"
1819#include " mlir/IR/Attributes.h"
1920#include " mlir/IR/BuiltinAttributes.h"
2021#include " mlir/IR/DialectImplementation.h"
@@ -487,9 +488,11 @@ struct PrivateParseArgs {
487488 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
488489 llvm::SmallVectorImpl<Type> &types;
489490 ArrayAttr &syms;
491+ ArrayAttr *mapIndices;
490492 PrivateParseArgs (SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
491- SmallVectorImpl<Type> &types, ArrayAttr &syms)
492- : vars(vars), types(types), syms(syms) {}
493+ SmallVectorImpl<Type> &types, ArrayAttr &syms,
494+ ArrayAttr *mapIndices = nullptr )
495+ : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
493496};
494497struct ReductionParseArgs {
495498 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
@@ -517,8 +520,10 @@ static ParseResult parseClauseWithRegionArgs(
517520 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
518521 SmallVectorImpl<Type> &types,
519522 SmallVectorImpl<OpAsmParser::Argument> ®ionPrivateArgs,
520- ArrayAttr *symbols = nullptr , DenseBoolArrayAttr *byref = nullptr ) {
523+ ArrayAttr *symbols = nullptr , ArrayAttr *mapIndices = nullptr ,
524+ DenseBoolArrayAttr *byref = nullptr ) {
521525 SmallVector<SymbolRefAttr> symbolVec;
526+ SmallVector<int64_t > mapIndicesVec;
522527 SmallVector<bool > isByRefVec;
523528 unsigned regionArgOffset = regionPrivateArgs.size ();
524529
@@ -538,6 +543,16 @@ static ParseResult parseClauseWithRegionArgs(
538543 parser.parseArgument (regionPrivateArgs.emplace_back ()))
539544 return failure ();
540545
546+ if (mapIndices) {
547+ if (parser.parseOptionalLSquare ().succeeded ()) {
548+ if (parser.parseKeyword (" map_idx" ) || parser.parseEqual () ||
549+ parser.parseInteger (mapIndicesVec.emplace_back ()) ||
550+ parser.parseRSquare ())
551+ return failure ();
552+ } else
553+ mapIndicesVec.push_back (-1 );
554+ }
555+
541556 return success ();
542557 }))
543558 return failure ();
@@ -571,6 +586,9 @@ static ParseResult parseClauseWithRegionArgs(
571586 *symbols = ArrayAttr::get (parser.getContext (), symbolAttrs);
572587 }
573588
589+ if (!mapIndicesVec.empty ())
590+ *mapIndices = utils::makeI64ArrayAttr (mapIndicesVec, parser.getContext ());
591+
574592 if (byref)
575593 *byref = makeDenseBoolArrayAttr (parser.getContext (), isByRefVec);
576594
@@ -595,14 +613,14 @@ static ParseResult parseBlockArgClause(
595613static ParseResult parseBlockArgClause (
596614 OpAsmParser &parser,
597615 llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
598- StringRef keyword, std::optional<PrivateParseArgs> reductionArgs ) {
616+ StringRef keyword, std::optional<PrivateParseArgs> privateArgs ) {
599617 if (succeeded (parser.parseOptionalKeyword (keyword))) {
600- if (!reductionArgs )
618+ if (!privateArgs )
601619 return failure ();
602620
603- if (failed (parseClauseWithRegionArgs (parser, reductionArgs-> vars ,
604- reductionArgs ->types , entryBlockArgs,
605- &reductionArgs ->syms )))
621+ if (failed (parseClauseWithRegionArgs (
622+ parser, privateArgs-> vars , privateArgs ->types , entryBlockArgs,
623+ &privateArgs ->syms , privateArgs-> mapIndices )))
606624 return failure ();
607625 }
608626 return success ();
@@ -618,7 +636,8 @@ static ParseResult parseBlockArgClause(
618636
619637 if (failed (parseClauseWithRegionArgs (
620638 parser, reductionArgs->vars , reductionArgs->types , entryBlockArgs,
621- &reductionArgs->syms , &reductionArgs->byref )))
639+ &reductionArgs->syms , /* mapIndices=*/ nullptr ,
640+ &reductionArgs->byref )))
622641 return failure ();
623642 }
624643 return success ();
@@ -674,12 +693,14 @@ static ParseResult parseInReductionMapPrivateRegion(
674693 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars,
675694 SmallVectorImpl<Type> &mapTypes,
676695 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
677- llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
696+ llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
697+ ArrayAttr &privateMaps) {
678698 AllRegionParseArgs args;
679699 args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
680700 inReductionByref, inReductionSyms);
681701 args.mapArgs .emplace (mapVars, mapTypes);
682- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
702+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms,
703+ &privateMaps);
683704 return parseBlockArgRegion (parser, region, args);
684705}
685706
@@ -776,8 +797,10 @@ struct PrivatePrintArgs {
776797 ValueRange vars;
777798 TypeRange types;
778799 ArrayAttr syms;
779- PrivatePrintArgs (ValueRange vars, TypeRange types, ArrayAttr syms)
780- : vars(vars), types(types), syms(syms) {}
800+ ArrayAttr mapIndices;
801+ PrivatePrintArgs (ValueRange vars, TypeRange types, ArrayAttr syms,
802+ ArrayAttr mapIndices)
803+ : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
781804};
782805struct ReductionPrintArgs {
783806 ValueRange vars;
@@ -804,6 +827,7 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
804827 ValueRange argsSubrange,
805828 ValueRange operands, TypeRange types,
806829 ArrayAttr symbols = nullptr ,
830+ ArrayAttr mapIndices = nullptr ,
807831 DenseBoolArrayAttr byref = nullptr ) {
808832 if (argsSubrange.empty ())
809833 return ;
@@ -815,21 +839,31 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
815839 symbols = ArrayAttr::get (ctx, values);
816840 }
817841
842+ if (!mapIndices) {
843+ llvm::SmallVector<Attribute> values (operands.size (), nullptr );
844+ mapIndices = ArrayAttr::get (ctx, values);
845+ }
846+
818847 if (!byref) {
819848 mlir::SmallVector<bool > values (operands.size (), false );
820849 byref = DenseBoolArrayAttr::get (ctx, values);
821850 }
822851
823- llvm::interleaveComma (
824- llvm::zip_equal (operands, argsSubrange, symbols, byref.asArrayRef ()), p,
825- [&p](auto t) {
826- auto [op, arg, sym, isByRef] = t;
827- if (isByRef)
828- p << " byref " ;
829- if (sym)
830- p << sym << " " ;
831- p << op << " -> " << arg;
832- });
852+ llvm::interleaveComma (llvm::zip_equal (operands, argsSubrange, symbols,
853+ mapIndices, byref.asArrayRef ()),
854+ p, [&p](auto t) {
855+ auto [op, arg, sym, map, isByRef] = t;
856+ if (isByRef)
857+ p << " byref " ;
858+ if (sym)
859+ p << sym << " " ;
860+
861+ p << op << " -> " << arg;
862+
863+ if (map)
864+ p << " [map_idx="
865+ << llvm::cast<IntegerAttr>(map).getInt () << " ]" ;
866+ });
833867 p << " : " ;
834868 llvm::interleaveComma (types, p);
835869 p << " ) " ;
@@ -849,7 +883,7 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
849883 if (privateArgs)
850884 printClauseWithRegionArgs (p, ctx, clauseName, argsSubrange,
851885 privateArgs->vars , privateArgs->types ,
852- privateArgs->syms );
886+ privateArgs->syms , privateArgs-> mapIndices );
853887}
854888
855889static void
@@ -859,7 +893,8 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
859893 if (reductionArgs)
860894 printClauseWithRegionArgs (p, ctx, clauseName, argsSubrange,
861895 reductionArgs->vars , reductionArgs->types ,
862- reductionArgs->syms , reductionArgs->byref );
896+ reductionArgs->syms , nullptr ,
897+ reductionArgs->byref );
863898}
864899
865900static void printBlockArgRegion (OpAsmPrinter &p, Operation *op, Region ®ion,
@@ -891,12 +926,13 @@ static void printInReductionMapPrivateRegion(
891926 OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars,
892927 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
893928 ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
894- ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) {
929+ ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
930+ ArrayAttr privateMaps) {
895931 AllRegionPrintArgs args;
896932 args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
897933 inReductionByref, inReductionSyms);
898934 args.mapArgs .emplace (mapVars, mapTypes);
899- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
935+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms, privateMaps );
900936 printBlockArgRegion (p, op, region, args);
901937}
902938
@@ -908,7 +944,7 @@ static void printInReductionPrivateRegion(
908944 AllRegionPrintArgs args;
909945 args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
910946 inReductionByref, inReductionSyms);
911- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
947+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms, nullptr );
912948 printBlockArgRegion (p, op, region, args);
913949}
914950
@@ -921,7 +957,7 @@ static void printInReductionPrivateReductionRegion(
921957 AllRegionPrintArgs args;
922958 args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
923959 inReductionByref, inReductionSyms);
924- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
960+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms, nullptr );
925961 args.reductionArgs .emplace (reductionVars, reductionTypes, reductionByref,
926962 reductionSyms);
927963 printBlockArgRegion (p, op, region, args);
@@ -931,7 +967,7 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
931967 ValueRange privateVars, TypeRange privateTypes,
932968 ArrayAttr privateSyms) {
933969 AllRegionPrintArgs args;
934- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
970+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms, nullptr );
935971 printBlockArgRegion (p, op, region, args);
936972}
937973
@@ -941,7 +977,7 @@ static void printPrivateReductionRegion(
941977 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
942978 ArrayAttr reductionSyms) {
943979 AllRegionPrintArgs args;
944- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
980+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms, nullptr );
945981 args.reductionArgs .emplace (reductionVars, reductionTypes, reductionByref,
946982 reductionSyms);
947983 printBlockArgRegion (p, op, region, args);
@@ -1656,7 +1692,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
16561692 /* in_reduction_vars=*/ {}, /* in_reduction_byref=*/ nullptr ,
16571693 /* in_reduction_syms=*/ nullptr , clauses.isDevicePtrVars ,
16581694 clauses.mapVars , clauses.nowait , clauses.privateVars ,
1659- makeArrayAttr (ctx, clauses.privateSyms ), clauses.threadLimit );
1695+ makeArrayAttr (ctx, clauses.privateSyms ), clauses.threadLimit ,
1696+ /* private_maps=*/ nullptr );
16601697}
16611698
16621699LogicalResult TargetOp::verify () {
0 commit comments