Skip to content

Conversation

@CoTinker
Copy link
Contributor

This PR adds a check in the parser to prevent a crash when vector.contract lacks the iterator_types attribute.
Fixes #132886.

@llvmbot
Copy link
Member

llvmbot commented Mar 28, 2025

@llvm/pr-subscribers-mlir-vector

Author: Longsheng Mou (CoTinker)

Changes

This PR adds a check in the parser to prevent a crash when vector.contract lacks the iterator_types attribute.
Fixes #132886.


Full diff: https://github.com/llvm/llvm-project/pull/133434.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6-1)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+8)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index eccb3e578458e..5a3983699d5a3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -787,8 +787,13 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
   // because tests still use the old format when 'iterator_types' attribute is
   // represented as an array of strings.
   // TODO: Remove this conversion once tests are fixed.
-  ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
+  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
       result.attributes.get(getIteratorTypesAttrName(result.name)));
+  if (!iteratorTypes) {
+    return parser.emitError(loc)
+           << "expected " << getIteratorTypesAttrName(result.name)
+           << " array attribute";
+  }
 
   SmallVector<Attribute> iteratorTypeAttrs;
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 1b89e8eb5069b..ad18cfda6fe83 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1015,6 +1015,14 @@ func.func @contract_with_dim_unused_by_lhs_and_rhs(%arg0 : vector<1x2xi32>, %arg
 
 // -----
 
+func.func @contract_missing_iterator_types(%arg0: vector<1x2xi32>, %arg1: vector<2xi32>, %arg2: vector<1xi32>) -> vector<1xi32> {
+  // expected-error@+1 {{'vector.contract' op expected "iterator_types" array attribute}}
+  %0 = vector.contract {} %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2xi32> into vector<1xi32>
+  return %0 : vector<1xi32>
+}
+
+// -----
+
 func.func @create_mask_0d_no_operands() {
   %c1 = arith.constant 1 : index
   // expected-error@+1 {{must specify exactly one operand for 0-D create_mask}}

@llvmbot
Copy link
Member

llvmbot commented Mar 28, 2025

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes

This PR adds a check in the parser to prevent a crash when vector.contract lacks the iterator_types attribute.
Fixes #132886.


Full diff: https://github.com/llvm/llvm-project/pull/133434.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6-1)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+8)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index eccb3e578458e..5a3983699d5a3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -787,8 +787,13 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
   // because tests still use the old format when 'iterator_types' attribute is
   // represented as an array of strings.
   // TODO: Remove this conversion once tests are fixed.
-  ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
+  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
       result.attributes.get(getIteratorTypesAttrName(result.name)));
+  if (!iteratorTypes) {
+    return parser.emitError(loc)
+           << "expected " << getIteratorTypesAttrName(result.name)
+           << " array attribute";
+  }
 
   SmallVector<Attribute> iteratorTypeAttrs;
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 1b89e8eb5069b..ad18cfda6fe83 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1015,6 +1015,14 @@ func.func @contract_with_dim_unused_by_lhs_and_rhs(%arg0 : vector<1x2xi32>, %arg
 
 // -----
 
+func.func @contract_missing_iterator_types(%arg0: vector<1x2xi32>, %arg1: vector<2xi32>, %arg2: vector<1xi32>) -> vector<1xi32> {
+  // expected-error@+1 {{'vector.contract' op expected "iterator_types" array attribute}}
+  %0 = vector.contract {} %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2xi32> into vector<1xi32>
+  return %0 : vector<1xi32>
+}
+
+// -----
+
 func.func @create_mask_0d_no_operands() {
   %c1 = arith.constant 1 : index
   // expected-error@+1 {{must specify exactly one operand for 0-D create_mask}}

This PR adds a check in the parser to prevent a crash when
`vector.contract` lacks the `iterator_types` attribute.
Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix!

@CoTinker CoTinker merged commit 22a11be into llvm:main Mar 29, 2025
11 checks passed
@CoTinker CoTinker deleted the parse_contract branch March 29, 2025 01:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[mlir] Parser crash 'vector.contract'

4 participants