diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 8d6e263934fb4..347141e2773b8 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -397,15 +397,13 @@ std::optional vector::getConstantVscaleMultiplier(Value value) { } /// Converts an IntegerAttr to have the specified type if needed. -/// This handles cases where constant attributes have a different type than the -/// target element type. If the input attribute is not an IntegerAttr or already -/// has the correct type, returns it unchanged. -static Attribute convertIntegerAttr(Attribute attr, Type expectedType) { - if (auto intAttr = mlir::dyn_cast(attr)) { - if (intAttr.getType() != expectedType) - return IntegerAttr::get(expectedType, intAttr.getInt()); - } - return attr; +/// This handles cases where integer constant attributes have a different type +/// than the target element type. +static IntegerAttr convertIntegerAttr(IntegerAttr intAttr, Type expectedType) { + if (intAttr.getType() == expectedType) + return intAttr; // Already correct type + + return IntegerAttr::get(expectedType, intAttr.getInt()); } //===----------------------------------------------------------------------===// @@ -2463,7 +2461,10 @@ static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) { /// static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, ArrayRef elements) { - if (llvm::any_of(elements, [](Attribute attr) { return !attr; })) + // Check for null or poison attributes before any processing. + if (llvm::any_of(elements, [](Attribute attr) { + return !attr || isa(attr); + })) return {}; // DenseElementsAttr only supports int/index/float/complex types. @@ -2472,11 +2473,16 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, if (!destEltType.isIntOrIndexOrFloat() && !isa(destEltType)) return {}; - // Constant attributes might have a different type than the return type. - // Convert them before creating the dense elements attribute. - auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) { - return convertIntegerAttr(attr, destEltType); - }); + // Convert integer attributes to the target type if needed, leave others + // unchanged. + auto convertedElements = + llvm::map_to_vector(elements, [&](Attribute attr) -> Attribute { + if (auto intAttr = dyn_cast(attr)) { + return convertIntegerAttr(intAttr, destEltType); + } + return attr; // Non-integer attributes (FloatAttr, etc.) returned + // unchanged + }); return DenseElementsAttr::get(destVecType, convertedElements); } @@ -3497,13 +3503,19 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, SmallVector insertedValues; Type destEltType = destTy.getElementType(); - /// Converts the expected type to an IntegerAttr if there's - /// a mismatch. + /// Converts integer attributes to the expected type if there's a mismatch. + /// Non-integer attributes are left unchanged. if (auto denseSource = llvm::dyn_cast(srcAttr)) { for (auto value : denseSource.getValues()) - insertedValues.push_back(convertIntegerAttr(value, destEltType)); + if (auto intAttr = dyn_cast(value)) + insertedValues.push_back(convertIntegerAttr(intAttr, destEltType)); + else + insertedValues.push_back(value); // Non-integer attributes unchanged } else { - insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType)); + if (auto intAttr = dyn_cast(srcAttr)) + insertedValues.push_back(convertIntegerAttr(intAttr, destEltType)); + else + insertedValues.push_back(srcAttr); // Non-integer attributes unchanged } auto allValues = llvm::to_vector(denseDst.getValues()); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 05c88b8abfbb0..08d28be3f8f73 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -3375,6 +3375,42 @@ func.func @negative_from_elements_to_constant() -> vector<1x!llvm.ptr> { return %b : vector<1x!llvm.ptr> } +// ----- + +// CHECK-LABEL: @negative_from_elements_poison +// CHECK: %[[VAL:.*]] = ub.poison : vector<2xf32> +// CHECK: return %[[VAL]] : vector<2xf32> +func.func @negative_from_elements_poison_f32() -> vector<2xf32> { + %0 = ub.poison : f32 + %1 = vector.from_elements %0, %0 : vector<2xf32> + return %1 : vector<2xf32> +} + +// ----- + +// CHECK-LABEL: @negative_from_elements_poison_i32 +// CHECK: %[[VAL:.*]] = ub.poison : vector<2xi32> +// CHECK: return %[[VAL]] : vector<2xi32> +func.func @negative_from_elements_poison_i32() -> vector<2xi32> { + %0 = ub.poison : i32 + %1 = vector.from_elements %0, %0 : vector<2xi32> + return %1 : vector<2xi32> +} + +// ----- + +// CHECK-LABEL: @negative_from_elements_poison_constant_mix +// CHECK: %[[POISON:.*]] = ub.poison : f32 +// CHECK: %[[CONST:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[RES:.*]] = vector.from_elements %[[POISON]], %[[CONST]] : vector<2xf32> +// CHECK: return %[[RES]] : vector<2xf32> +func.func @negative_from_elements_poison_constant_mix() -> vector<2xf32> { + %0 = ub.poison : f32 + %c = arith.constant 1.0 : f32 + %1 = vector.from_elements %0, %c : vector<2xf32> + return %1 : vector<2xf32> +} + // +--------------------------------------------------------------------------- // End of Tests for foldFromElementsToConstant // +---------------------------------------------------------------------------