Skip to content

Conversation

@krzysz00
Copy link
Contributor

@krzysz00 krzysz00 commented Nov 5, 2024

PR #112292 added support for vectors to the integer range inference interface and analysis, but didn't update the getDestWidth() method. This caused crashes when trying to infer the ranges of arith.extsi with vector inputs, as the code would try to sign-extend a N-bit value to a 0-bit one, which would assert and crash.

This commit fixes the issue by adding a getElementTypeOrSelf().

PR llvm#112292 added support for vectors to the integer range inference
interface and analysis, but didn't update the getDestWidth() method.
This caused crashes when trying to infer the ranges of `arith.extsi`
with vector inputs, as the code would try to sign-extend a N-bit value
to a 0-bit one, which would assert and crash.

This commit fixes the issue by adding a getElementTypeOrSelf().
@llvmbot
Copy link
Member

llvmbot commented Nov 5, 2024

@llvm/pr-subscribers-mlir-vector

Author: Krzysztof Drewniak (krzysz00)

Changes

PR #112292 added support for vectors to the integer range inference interface and analysis, but didn't update the getDestWidth() method. This caused crashes when trying to infer the ranges of arith.extsi with vector inputs, as the code would try to sign-extend a N-bit value to a 0-bit one, which would assert and crash.

This commit fixes the issue by adding a getElementTypeOrSelf().


Full diff: https://github.com/llvm/llvm-project/pull/114898.diff

2 Files Affected:

  • (modified) mlir/lib/Interfaces/InferIntRangeInterface.cpp (+2)
  • (modified) mlir/test/Dialect/Vector/int-range-interface.mlir (+10-1)
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index d879b93586899b..63658518dd4a3b 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
 #include <optional>
 
@@ -28,6 +29,7 @@ const APInt &ConstantIntRanges::smin() const { return sminVal; }
 const APInt &ConstantIntRanges::smax() const { return smaxVal; }
 
 unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
+  type = getElementTypeOrSelf(type);
   if (type.isIndex())
     return IndexType::kInternalStorageBitWidth;
   if (auto integerType = dyn_cast<IntegerType>(type))
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index 29282423089ba6..09dfe932a52323 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -96,7 +96,7 @@ func.func @vector_insertelement() -> vector<4xindex> {
 
 // CHECK-LABEL: func @test_loaded_vector_extract
 // No bounds
-// CHECK: test.reflect_bounds %{{.*}} : i32
+// CHECK: test.reflect_bounds {smax = 2147483647 : si32, smin = -2147483648 : si32, umax = 4294967295 : ui32, umin = 0 : ui32} %{{.*}} : i32
 func.func @test_loaded_vector_extract(%memref : memref<16xi32>) -> i32 {
   %c0 = arith.constant 0 : index
   %v = vector.load %memref[%c0] : memref<16xi32>, vector<4xi32>
@@ -104,3 +104,12 @@ func.func @test_loaded_vector_extract(%memref : memref<16xi32>) -> i32 {
   %bounds = test.reflect_bounds %e : i32
   func.return %bounds : i32
 }
+
+// CHECK-LABEL: func @test_vector_extsi
+// CHECK: test.reflect_bounds {smax = 5 : si32, smin = 1 : si32, umax = 5 : ui32, umin = 1 : ui32}
+func.func @test_vector_extsi() -> vector<2xi32> {
+  %0 = test.with_bounds {smax = 5 : si8, smin = 1 : si8, umax = 5 : ui8, umin = 1 : ui8 } : vector<2xi8>
+  %1 = arith.extsi %0 : vector<2xi8> to vector<2xi32>
+  %2 = test.reflect_bounds %1 : vector<2xi32>
+  func.return %2 : vector<2xi32>
+}

@llvmbot
Copy link
Member

llvmbot commented Nov 5, 2024

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

PR #112292 added support for vectors to the integer range inference interface and analysis, but didn't update the getDestWidth() method. This caused crashes when trying to infer the ranges of arith.extsi with vector inputs, as the code would try to sign-extend a N-bit value to a 0-bit one, which would assert and crash.

This commit fixes the issue by adding a getElementTypeOrSelf().


Full diff: https://github.com/llvm/llvm-project/pull/114898.diff

2 Files Affected:

  • (modified) mlir/lib/Interfaces/InferIntRangeInterface.cpp (+2)
  • (modified) mlir/test/Dialect/Vector/int-range-interface.mlir (+10-1)
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index d879b93586899b..63658518dd4a3b 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
 #include <optional>
 
@@ -28,6 +29,7 @@ const APInt &ConstantIntRanges::smin() const { return sminVal; }
 const APInt &ConstantIntRanges::smax() const { return smaxVal; }
 
 unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
+  type = getElementTypeOrSelf(type);
   if (type.isIndex())
     return IndexType::kInternalStorageBitWidth;
   if (auto integerType = dyn_cast<IntegerType>(type))
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index 29282423089ba6..09dfe932a52323 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -96,7 +96,7 @@ func.func @vector_insertelement() -> vector<4xindex> {
 
 // CHECK-LABEL: func @test_loaded_vector_extract
 // No bounds
-// CHECK: test.reflect_bounds %{{.*}} : i32
+// CHECK: test.reflect_bounds {smax = 2147483647 : si32, smin = -2147483648 : si32, umax = 4294967295 : ui32, umin = 0 : ui32} %{{.*}} : i32
 func.func @test_loaded_vector_extract(%memref : memref<16xi32>) -> i32 {
   %c0 = arith.constant 0 : index
   %v = vector.load %memref[%c0] : memref<16xi32>, vector<4xi32>
@@ -104,3 +104,12 @@ func.func @test_loaded_vector_extract(%memref : memref<16xi32>) -> i32 {
   %bounds = test.reflect_bounds %e : i32
   func.return %bounds : i32
 }
+
+// CHECK-LABEL: func @test_vector_extsi
+// CHECK: test.reflect_bounds {smax = 5 : si32, smin = 1 : si32, umax = 5 : ui32, umin = 1 : ui32}
+func.func @test_vector_extsi() -> vector<2xi32> {
+  %0 = test.with_bounds {smax = 5 : si8, smin = 1 : si8, umax = 5 : ui8, umin = 1 : ui8 } : vector<2xi8>
+  %1 = arith.extsi %0 : vector<2xi8> to vector<2xi32>
+  %2 = test.reflect_bounds %1 : vector<2xi32>
+  func.return %2 : vector<2xi32>
+}

@krzysz00 krzysz00 merged commit 616aff1 into llvm:main Nov 5, 2024
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants