diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 41cec89fdf598..df7ff28ca5926 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -2197,7 +2197,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop", let hasCustomAssemblyFormat = 1; let assemblyFormat = [{ - custom($combined) + ( `combined` `(` custom($combined)^ `)` )? oilist( `gang` `` custom($gangOperands, type($gangOperands), $gangOperandsArgType, $gangOperandsDeviceType, diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index c4cc560e42f6a..91025e90b8e76 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -1686,25 +1686,19 @@ static void printDeviceTypeOperandsWithKeywordOnly( static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr) { - if (succeeded(parser.parseOptionalKeyword("combined"))) { - if (parser.parseLParen()) - return failure(); - if (succeeded(parser.parseOptionalKeyword("kernels"))) { - attr = mlir::acc::CombinedConstructsTypeAttr::get( - parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop); - } else if (succeeded(parser.parseOptionalKeyword("parallel"))) { - attr = mlir::acc::CombinedConstructsTypeAttr::get( - parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop); - } else if (succeeded(parser.parseOptionalKeyword("serial"))) { - attr = mlir::acc::CombinedConstructsTypeAttr::get( - parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop); - } else { - parser.emitError(parser.getCurrentLocation(), - "expected compute construct name"); - return failure(); - } - if (parser.parseRParen()) - return failure(); + if (succeeded(parser.parseOptionalKeyword("kernels"))) { + attr = mlir::acc::CombinedConstructsTypeAttr::get( + parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop); + } else if (succeeded(parser.parseOptionalKeyword("parallel"))) { + attr = mlir::acc::CombinedConstructsTypeAttr::get( + parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop); + } else if (succeeded(parser.parseOptionalKeyword("serial"))) { + attr = mlir::acc::CombinedConstructsTypeAttr::get( + parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop); + } else { + parser.emitError(parser.getCurrentLocation(), + "expected compute construct name"); + return failure(); } return success(); } @@ -1715,13 +1709,13 @@ printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, if (attr) { switch (attr.getValue()) { case mlir::acc::CombinedConstructsType::KernelsLoop: - p << "combined(kernels)"; + p << "kernels"; break; case mlir::acc::CombinedConstructsType::ParallelLoop: - p << "combined(parallel)"; + p << "parallel"; break; case mlir::acc::CombinedConstructsType::SerialLoop: - p << "combined(serial)"; + p << "serial"; break; }; }