diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 877bd226a0352..919a0853fb604 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -759,20 +759,23 @@ static LogicalResult verifyDeviceTypeAndSegmentCountMatch( Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) { std::size_t numOperandsInSegments = 0; - - if (!segments) - return success(); - - for (auto segCount : segments.asArrayRef()) { - if (maxInSegment != 0 && segCount > maxInSegment) - return op.emitOpError() << keyword << " expects a maximum of " - << maxInSegment << " values per segment"; - numOperandsInSegments += segCount; + std::size_t nbOfSegments = 0; + + if (segments) { + for (auto segCount : segments.asArrayRef()) { + if (maxInSegment != 0 && segCount > maxInSegment) + return op.emitOpError() << keyword << " expects a maximum of " + << maxInSegment << " values per segment"; + numOperandsInSegments += segCount; + ++nbOfSegments; + } } - if (numOperandsInSegments != operands.size()) + + if ((numOperandsInSegments != operands.size()) || + (!deviceTypes && !operands.empty())) return op.emitOpError() << keyword << " operand count does not match count in segments"; - if (deviceTypes.getValue().size() != (size_t)segments.size()) + if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments) return op.emitOpError() << keyword << " segment count does not match device_type count"; return success(); diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir index ec5430420524c..96edb585ae21a 100644 --- a/mlir/test/Dialect/OpenACC/invalid.mlir +++ b/mlir/test/Dialect/OpenACC/invalid.mlir @@ -507,6 +507,13 @@ acc.parallel num_gangs({%i64value: i64, %i64value : i64, %i64value : i64, %i64va // ----- +%0 = "arith.constant"() <{value = 1 : i64}> : () -> i64 +// expected-error@+1 {{num_gangs operand count does not match count in segments}} +"acc.parallel"(%0) <{numGangsSegments = array, operandSegmentSizes = array}> ({ +}) : (i64) -> () + +// ----- + %i64value = arith.constant 1 : i64 acc.parallel { // expected-error@+1 {{'acc.set' op cannot be nested in a compute operation}}