@@ -508,6 +508,7 @@ struct ReductionParseArgs {
508508};
509509
510510struct AllRegionParseArgs {
511+ std::optional<MapParseArgs> hasDeviceAddrArgs;
511512 std::optional<MapParseArgs> hostEvalArgs;
512513 std::optional<ReductionParseArgs> inReductionArgs;
513514 std::optional<MapParseArgs> mapArgs;
@@ -666,6 +667,11 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion,
666667 AllRegionParseArgs args) {
667668 llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs;
668669
670+ if (failed (parseBlockArgClause (parser, entryBlockArgs, " has_device_addr" ,
671+ args.hasDeviceAddrArgs )))
672+ return parser.emitError (parser.getCurrentLocation ())
673+ << " invalid `has_device_addr` format" ;
674+
669675 if (failed (parseBlockArgClause (parser, entryBlockArgs, " host_eval" ,
670676 args.hostEvalArgs )))
671677 return parser.emitError (parser.getCurrentLocation ())
@@ -709,8 +715,12 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion,
709715 return parser.parseRegion (region, entryBlockArgs);
710716}
711717
712- static ParseResult parseHostEvalInReductionMapPrivateRegion (
718+ // See custom<HasDeviceAddrHostEvalInReductionMapPrivateRegion> in the
719+ // definition of TargetOp.
720+ static ParseResult parseHasDeviceAddrHostEvalInReductionMapPrivateRegion (
713721 OpAsmParser &parser, Region ®ion,
722+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hasDeviceAddrVars,
723+ SmallVectorImpl<Type> &hasDeviceAddrTypes,
714724 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
715725 SmallVectorImpl<Type> &hostEvalTypes,
716726 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
@@ -722,6 +732,7 @@ static ParseResult parseHostEvalInReductionMapPrivateRegion(
722732 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
723733 DenseI64ArrayAttr &privateMaps) {
724734 AllRegionParseArgs args;
735+ args.hasDeviceAddrArgs .emplace (hasDeviceAddrVars, hasDeviceAddrTypes);
725736 args.hostEvalArgs .emplace (hostEvalVars, hostEvalTypes);
726737 args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
727738 inReductionByref, inReductionSyms);
@@ -731,6 +742,7 @@ static ParseResult parseHostEvalInReductionMapPrivateRegion(
731742 return parseBlockArgRegion (parser, region, args);
732743}
733744
745+ // See custom<InReductionPrivateRegion> in the definition of TaskOp.
734746static ParseResult parseInReductionPrivateRegion (
735747 OpAsmParser &parser, Region ®ion,
736748 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
@@ -745,6 +757,8 @@ static ParseResult parseInReductionPrivateRegion(
745757 return parseBlockArgRegion (parser, region, args);
746758}
747759
760+ // See custom<InReductionPrivateReductionRegion> in the definition of
761+ // TaskloopOp.
748762static ParseResult parseInReductionPrivateReductionRegion (
749763 OpAsmParser &parser, Region ®ion,
750764 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
@@ -765,6 +779,7 @@ static ParseResult parseInReductionPrivateReductionRegion(
765779 return parseBlockArgRegion (parser, region, args);
766780}
767781
782+ // See custom<PrivateRegion> in the definition of SingleOp.
768783static ParseResult parsePrivateRegion (
769784 OpAsmParser &parser, Region ®ion,
770785 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
@@ -774,6 +789,7 @@ static ParseResult parsePrivateRegion(
774789 return parseBlockArgRegion (parser, region, args);
775790}
776791
792+ // See custom<PrivateReductionRegion> in the definition of LoopOp.
777793static ParseResult parsePrivateReductionRegion (
778794 OpAsmParser &parser, Region ®ion,
779795 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
@@ -789,6 +805,7 @@ static ParseResult parsePrivateReductionRegion(
789805 return parseBlockArgRegion (parser, region, args);
790806}
791807
808+ // See custom<TaskReductionRegion> in the definition of TaskgroupOp.
792809static ParseResult parseTaskReductionRegion (
793810 OpAsmParser &parser, Region ®ion,
794811 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &taskReductionVars,
@@ -800,6 +817,8 @@ static ParseResult parseTaskReductionRegion(
800817 return parseBlockArgRegion (parser, region, args);
801818}
802819
820+ // See custom<UseDeviceAddrUseDevicePtrRegion> in the definition of
821+ // TargetDataOp.
803822static ParseResult parseUseDeviceAddrUseDevicePtrRegion (
804823 OpAsmParser &parser, Region ®ion,
805824 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDeviceAddrVars,
@@ -842,6 +861,7 @@ struct ReductionPrintArgs {
842861 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
843862};
844863struct AllRegionPrintArgs {
864+ std::optional<MapPrintArgs> hasDeviceAddrArgs;
845865 std::optional<MapPrintArgs> hostEvalArgs;
846866 std::optional<ReductionPrintArgs> inReductionArgs;
847867 std::optional<MapPrintArgs> mapArgs;
@@ -935,6 +955,9 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
935955 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
936956 MLIRContext *ctx = op->getContext ();
937957
958+ printBlockArgClause (p, ctx, " has_device_addr" ,
959+ iface.getHasDeviceAddrBlockArgs (),
960+ args.hasDeviceAddrArgs );
938961 printBlockArgClause (p, ctx, " host_eval" , iface.getHostEvalBlockArgs (),
939962 args.hostEvalArgs );
940963 printBlockArgClause (p, ctx, " in_reduction" , iface.getInReductionBlockArgs (),
@@ -957,14 +980,19 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
957980 p.printRegion (region, /* printEntryBlockArgs=*/ false );
958981}
959982
960- static void printHostEvalInReductionMapPrivateRegion (
961- OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange hostEvalVars,
962- TypeRange hostEvalTypes, ValueRange inReductionVars,
963- TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
964- ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
965- ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
983+ // See custom<HasDeviceAddrHostEvalInReductionMapPrivateRegion> in the
984+ // definition of TargetOp.
985+ static void printHasDeviceAddrHostEvalInReductionMapPrivateRegion (
986+ OpAsmPrinter &p, Operation *op, Region ®ion,
987+ ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
988+ ValueRange hostEvalVars, TypeRange hostEvalTypes,
989+ ValueRange inReductionVars, TypeRange inReductionTypes,
990+ DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
991+ ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
992+ TypeRange privateTypes, ArrayAttr privateSyms,
966993 DenseI64ArrayAttr privateMaps) {
967994 AllRegionPrintArgs args;
995+ args.hasDeviceAddrArgs .emplace (hasDeviceAddrVars, hasDeviceAddrTypes);
968996 args.hostEvalArgs .emplace (hostEvalVars, hostEvalTypes);
969997 args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
970998 inReductionByref, inReductionSyms);
@@ -973,6 +1001,7 @@ static void printHostEvalInReductionMapPrivateRegion(
9731001 printBlockArgRegion (p, op, region, args);
9741002}
9751003
1004+ // See custom<InReductionPrivateRegion> in the definition of TaskOp.
9761005static void printInReductionPrivateRegion (
9771006 OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars,
9781007 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
@@ -986,6 +1015,8 @@ static void printInReductionPrivateRegion(
9861015 printBlockArgRegion (p, op, region, args);
9871016}
9881017
1018+ // See custom<InReductionPrivateReductionRegion> in the definition of
1019+ // TaskloopOp.
9891020static void printInReductionPrivateReductionRegion (
9901021 OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars,
9911022 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
@@ -1003,6 +1034,7 @@ static void printInReductionPrivateReductionRegion(
10031034 printBlockArgRegion (p, op, region, args);
10041035}
10051036
1037+ // See custom<PrivateRegion> in the definition of SingleOp.
10061038static void printPrivateRegion (OpAsmPrinter &p, Operation *op, Region ®ion,
10071039 ValueRange privateVars, TypeRange privateTypes,
10081040 ArrayAttr privateSyms) {
@@ -1012,6 +1044,7 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
10121044 printBlockArgRegion (p, op, region, args);
10131045}
10141046
1047+ // See custom<PrivateReductionRegion> in the definition of LoopOp.
10151048static void printPrivateReductionRegion (
10161049 OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange privateVars,
10171050 TypeRange privateTypes, ArrayAttr privateSyms,
@@ -1026,6 +1059,7 @@ static void printPrivateReductionRegion(
10261059 printBlockArgRegion (p, op, region, args);
10271060}
10281061
1062+ // See custom<TaskReductionRegion> in the definition of TaskgroupOp.
10291063static void printTaskReductionRegion (OpAsmPrinter &p, Operation *op,
10301064 Region ®ion,
10311065 ValueRange taskReductionVars,
@@ -1038,6 +1072,8 @@ static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op,
10381072 printBlockArgRegion (p, op, region, args);
10391073}
10401074
1075+ // See custom<UseDeviceAddrUseDevicePtrRegion> in the definition of
1076+ // TargetDataOp.
10411077static void printUseDeviceAddrUseDevicePtrRegion (OpAsmPrinter &p, Operation *op,
10421078 Region ®ion,
10431079 ValueRange useDeviceAddrVars,
0 commit comments