diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp index 8682294c8a697..f3413c1c30fad 100644 --- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp +++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp @@ -42,6 +42,12 @@ void arith::ConstantOp::inferResultRanges(ArrayRef argRanges, } if (auto arrayCstAttr = llvm::dyn_cast_or_null(getValue())) { + if (arrayCstAttr.isSplat()) { + setResultRange(getResult(), ConstantIntRanges::constant( + arrayCstAttr.getSplatValue())); + return; + } + std::optional result; for (const APInt &val : arrayCstAttr) { auto range = ConstantIntRanges::constant(val); diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir index 09dfe932a5232..0263193b20401 100644 --- a/mlir/test/Dialect/Vector/int-range-interface.mlir +++ b/mlir/test/Dialect/Vector/int-range-interface.mlir @@ -17,6 +17,13 @@ func.func @constant_splat() -> vector<8xi32> { func.return %1 : vector<8xi32> } +// CHECK-LABEL: func @float_constant_splat +// Don't crash on splat floats. +func.func @float_constant_splat() -> vector<8xf32> { + %0 = arith.constant dense<3.0> : vector<8xf32> + func.return %0: vector<8xf32> +} + // CHECK-LABEL: func @vector_splat // CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} func.func @vector_splat() -> vector<4xindex> {