@@ -396,8 +396,9 @@ static LogicalResult verifyReductionVarList(Operation *op,
396396// /
397397// / hint-clause = `hint` `(` hint-value `)`
398398static ParseResult parseSynchronizationHint (OpAsmParser &parser,
399- IntegerAttr &hintAttr) {
400- if (failed (parser.parseOptionalKeyword (" hint" ))) {
399+ IntegerAttr &hintAttr,
400+ bool parseKeyword = true ) {
401+ if (parseKeyword && failed (parser.parseOptionalKeyword (" hint" ))) {
401402 hintAttr = IntegerAttr::get (parser.getBuilder ().getI64Type (), 0 );
402403 return success ();
403404 }
@@ -455,11 +456,11 @@ static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
455456
456457 p << " hint(" ;
457458 llvm::interleaveComma (hints, p);
458- p << " )" ;
459+ p << " ) " ;
459460}
460461
461462// / Verifies a synchronization hint clause
462- static LogicalResult verifySynchronizationHint (Operation *op, int32_t hint) {
463+ static LogicalResult verifySynchronizationHint (Operation *op, uint64_t hint) {
463464
464465 // Helper function to get n-th bit from the right end of `value`
465466 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
@@ -497,6 +498,8 @@ enum ClauseType {
497498 orderClause,
498499 orderedClause,
499500 inclusiveClause,
501+ memoryOrderClause,
502+ hintClause,
500503 COUNT
501504};
502505
@@ -725,6 +728,19 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
725728 return failure ();
726729 auto attr = UnitAttr::get (parser.getBuilder ().getContext ());
727730 result.addAttribute (" inclusive" , attr);
731+ } else if (clauseKeyword == " memory_order" ) {
732+ StringRef memoryOrder;
733+ if (checkAllowed (memoryOrderClause) || parser.parseLParen () ||
734+ parser.parseKeyword (&memoryOrder) || parser.parseRParen ())
735+ return failure ();
736+ result.addAttribute (" memory_order" ,
737+ parser.getBuilder ().getStringAttr (memoryOrder));
738+ } else if (clauseKeyword == " hint" ) {
739+ IntegerAttr hint;
740+ if (checkAllowed (hintClause) ||
741+ parseSynchronizationHint (parser, hint, false ))
742+ return failure ();
743+ result.addAttribute (" hint" , hint);
728744 } else {
729745 return parser.emitError (parser.getNameLoc ())
730746 << clauseKeyword << " is not a valid clause" ;
@@ -1194,5 +1210,105 @@ static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) {
11941210 return success ();
11951211}
11961212
1213+ // ===----------------------------------------------------------------------===//
1214+ // AtomicReadOp
1215+ // ===----------------------------------------------------------------------===//
1216+
1217+ // / Parser for AtomicReadOp
1218+ // /
1219+ // / operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type
1220+ // / address ::= operand `:` type
1221+ static ParseResult parseAtomicReadOp (OpAsmParser &parser,
1222+ OperationState &result) {
1223+ OpAsmParser::OperandType address;
1224+ Type addressType;
1225+ SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1226+ SmallVector<int > segments;
1227+
1228+ if (parser.parseOperand (address) ||
1229+ parseClauses (parser, result, clauses, segments) ||
1230+ parser.parseColonType (addressType) ||
1231+ parser.resolveOperand (address, addressType, result.operands ))
1232+ return failure ();
1233+
1234+ SmallVector<Type> resultType;
1235+ if (parser.parseArrowTypeList (resultType))
1236+ return failure ();
1237+ result.addTypes (resultType);
1238+ return success ();
1239+ }
1240+
1241+ // / Printer for AtomicReadOp
1242+ static void printAtomicReadOp (OpAsmPrinter &p, AtomicReadOp op) {
1243+ p << " " << op.address () << " " ;
1244+ if (op.memory_order ())
1245+ p << " memory_order(" << op.memory_order ().getValue () << " ) " ;
1246+ if (op.hintAttr ())
1247+ printSynchronizationHint (p << " " , op, op.hintAttr ());
1248+ p << " : " << op.address ().getType () << " -> " << op.getType ();
1249+ return ;
1250+ }
1251+
1252+ // / Verifier for AtomicReadOp
1253+ static LogicalResult verifyAtomicReadOp (AtomicReadOp op) {
1254+ if (op.memory_order ()) {
1255+ StringRef memOrder = op.memory_order ().getValue ();
1256+ if (memOrder.equals (" acq_rel" ) || memOrder.equals (" release" ))
1257+ return op.emitError (
1258+ " memory-order must not be acq_rel or release for atomic reads" );
1259+ }
1260+ return verifySynchronizationHint (op, op.hint ());
1261+ }
1262+
1263+ // ===----------------------------------------------------------------------===//
1264+ // AtomicWriteOp
1265+ // ===----------------------------------------------------------------------===//
1266+
1267+ // / Parser for AtomicWriteOp
1268+ // /
1269+ // / operation ::= `omp.atomic.write` atomic-clause-list operands
1270+ // / operands ::= address `,` value
1271+ // / address ::= operand `:` type
1272+ // / value ::= operand `:` type
1273+ static ParseResult parseAtomicWriteOp (OpAsmParser &parser,
1274+ OperationState &result) {
1275+ OpAsmParser::OperandType address, value;
1276+ Type addrType, valueType;
1277+ SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1278+ SmallVector<int > segments;
1279+
1280+ if (parser.parseOperand (address) || parser.parseComma () ||
1281+ parser.parseOperand (value) ||
1282+ parseClauses (parser, result, clauses, segments) ||
1283+ parser.parseColonType (addrType) || parser.parseComma () ||
1284+ parser.parseType (valueType) ||
1285+ parser.resolveOperand (address, addrType, result.operands ) ||
1286+ parser.resolveOperand (value, valueType, result.operands ))
1287+ return failure ();
1288+ return success ();
1289+ }
1290+
1291+ // / Printer for AtomicWriteOp
1292+ static void printAtomicWriteOp (OpAsmPrinter &p, AtomicWriteOp op) {
1293+ p << " " << op.address () << " , " << op.value () << " " ;
1294+ if (op.memory_order ())
1295+ p << " memory_order(" << op.memory_order () << " ) " ;
1296+ if (op.hintAttr ())
1297+ printSynchronizationHint (p, op, op.hintAttr ());
1298+ p << " : " << op.address ().getType () << " , " << op.value ().getType ();
1299+ return ;
1300+ }
1301+
1302+ // / Verifier for AtomicWriteOp
1303+ static LogicalResult verifyAtomicWriteOp (AtomicWriteOp op) {
1304+ if (op.memory_order ()) {
1305+ StringRef memoryOrder = op.memory_order ().getValue ();
1306+ if (memoryOrder.equals (" acq_rel" ) || memoryOrder.equals (" acquire" ))
1307+ return op.emitError (
1308+ " memory-order must not be acq_rel or acquire for atomic writes" );
1309+ }
1310+ return verifySynchronizationHint (op, op.hint ());
1311+ }
1312+
11971313#define GET_OP_CLASSES
11981314#include " mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
0 commit comments