@@ -470,13 +470,17 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
470470 ValueRange argsSubrange,
471471 StringRef clauseName, ValueRange operands,
472472 TypeRange types, ArrayAttr symbols) {
473- p << clauseName << " (" ;
473+ if (!clauseName.empty ())
474+ p << clauseName << " (" ;
475+
474476 llvm::interleaveComma (
475477 llvm::zip_equal (symbols, operands, argsSubrange, types), p, [&p](auto t) {
476478 auto [sym, op, arg, type] = t;
477479 p << sym << " " << op << " -> " << arg << " : " << type;
478480 });
479- p << " ) " ;
481+
482+ if (!clauseName.empty ())
483+ p << " ) " ;
480484}
481485
482486static ParseResult parseParallelRegion (
@@ -1048,6 +1052,49 @@ static void printMapEntries(OpAsmPrinter &p, Operation *op,
10481052 }
10491053}
10501054
1055+ static ParseResult parsePrivateList (
1056+ OpAsmParser &parser,
1057+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateOperands,
1058+ SmallVectorImpl<Type> &privateOperandTypes, ArrayAttr &privatizerSymbols) {
1059+ SmallVector<SymbolRefAttr> privateSymRefs;
1060+ SmallVector<OpAsmParser::Argument> regionPrivateArgs;
1061+
1062+ if (failed (parser.parseCommaSeparatedList ([&]() {
1063+ if (parser.parseAttribute (privateSymRefs.emplace_back ()) ||
1064+ parser.parseOperand (privateOperands.emplace_back ()) ||
1065+ parser.parseArrow () ||
1066+ parser.parseArgument (regionPrivateArgs.emplace_back ()) ||
1067+ parser.parseColonType (privateOperandTypes.emplace_back ()))
1068+ return failure ();
1069+ return success ();
1070+ })))
1071+ return failure ();
1072+
1073+ SmallVector<Attribute> privateSymAttrs (privateSymRefs.begin (),
1074+ privateSymRefs.end ());
1075+ privatizerSymbols = ArrayAttr::get (parser.getContext (), privateSymAttrs);
1076+
1077+ return success ();
1078+ }
1079+
1080+ static void printPrivateList (OpAsmPrinter &p, Operation *op,
1081+ ValueRange privateVarOperands,
1082+ TypeRange privateVarTypes,
1083+ ArrayAttr privatizerSymbols) {
1084+ // TODO: Remove target-specific logic from this function.
1085+ auto targetOp = mlir::dyn_cast<mlir::omp::TargetOp>(op);
1086+ assert (targetOp);
1087+
1088+ auto ®ion = op->getRegion (0 );
1089+ auto *argsBegin = region.front ().getArguments ().begin ();
1090+ MutableArrayRef argsSubrange (argsBegin + targetOp.getMapOperands ().size (),
1091+ argsBegin + targetOp.getMapOperands ().size () +
1092+ privateVarTypes.size ());
1093+ printClauseWithRegionArgs (
1094+ p, op, argsSubrange, /* clauseName=*/ llvm::StringRef{}, privateVarOperands,
1095+ privateVarTypes, privatizerSymbols);
1096+ }
1097+
10511098static void printCaptureType (OpAsmPrinter &p, Operation *op,
10521099 VariableCaptureKindAttr mapCaptureType) {
10531100 std::string typeCapStr;
@@ -1256,13 +1303,14 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
12561303 const TargetClauseOps &clauses) {
12571304 MLIRContext *ctx = builder.getContext ();
12581305 // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
1259- // inReductionDeclSymbols, privateVars, privatizers, reductionVars ,
1260- // reductionByRefAttr, reductionDeclSymbols.
1306+ // inReductionDeclSymbols, reductionVars, reductionByRefAttr ,
1307+ // reductionDeclSymbols.
12611308 TargetOp::build (
12621309 builder, state, clauses.ifVar , clauses.deviceVar , clauses.threadLimitVar ,
12631310 makeArrayAttr (ctx, clauses.dependTypeAttrs ), clauses.dependVars ,
12641311 clauses.nowaitAttr , clauses.isDevicePtrVars , clauses.hasDeviceAddrVars ,
1265- clauses.mapVars );
1312+ clauses.mapVars , clauses.privateVars ,
1313+ makeArrayAttr (ctx, clauses.privatizers ));
12661314}
12671315
12681316LogicalResult TargetOp::verify () {
0 commit comments