diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index eccb3e578458e..5a3983699d5a3 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -787,8 +787,13 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) { // because tests still use the old format when 'iterator_types' attribute is // represented as an array of strings. // TODO: Remove this conversion once tests are fixed. - ArrayAttr iteratorTypes = llvm::cast( + auto iteratorTypes = dyn_cast_or_null( result.attributes.get(getIteratorTypesAttrName(result.name))); + if (!iteratorTypes) { + return parser.emitError(loc) + << "expected " << getIteratorTypesAttrName(result.name) + << " array attribute"; + } SmallVector iteratorTypeAttrs; diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 1b89e8eb5069b..ea6d0021391fb 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1015,6 +1015,14 @@ func.func @contract_with_dim_unused_by_lhs_and_rhs(%arg0 : vector<1x2xi32>, %arg // ----- +func.func @contract_missing_iterator_types(%arg0: vector<1x2xi32>, %arg1: vector<2xi32>, %arg2: vector<1xi32>) -> vector<1xi32> { + // expected-error@+1 {{'vector.contract' expected "iterator_types" array attribute}} + %0 = vector.contract {} %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2xi32> into vector<1xi32> + return %0 : vector<1xi32> +} + +// ----- + func.func @create_mask_0d_no_operands() { %c1 = arith.constant 1 : index // expected-error@+1 {{must specify exactly one operand for 0-D create_mask}}