Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 31 additions & 19 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,15 +397,13 @@ std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
}

/// Converts an IntegerAttr to have the specified type if needed.
/// This handles cases where constant attributes have a different type than the
/// target element type. If the input attribute is not an IntegerAttr or already
/// has the correct type, returns it unchanged.
static Attribute convertIntegerAttr(Attribute attr, Type expectedType) {
if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
if (intAttr.getType() != expectedType)
return IntegerAttr::get(expectedType, intAttr.getInt());
}
return attr;
/// This handles cases where integer constant attributes have a different type
/// than the target element type.
static IntegerAttr convertIntegerAttr(IntegerAttr intAttr, Type expectedType) {
if (intAttr.getType() == expectedType)
return intAttr; // Already correct type

return IntegerAttr::get(expectedType, intAttr.getInt());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2463,7 +2461,10 @@ static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
///
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
ArrayRef<Attribute> elements) {
if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
// Check for null or poison attributes before any processing.
if (llvm::any_of(elements, [](Attribute attr) {
return !attr || isa<ub::PoisonAttrInterface>(attr);
}))
return {};

// DenseElementsAttr only supports int/index/float/complex types.
Expand All @@ -2472,11 +2473,16 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
return {};

// Constant attributes might have a different type than the return type.
// Convert them before creating the dense elements attribute.
auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
return convertIntegerAttr(attr, destEltType);
});
// Convert integer attributes to the target type if needed, leave others
// unchanged.
auto convertedElements =
llvm::map_to_vector(elements, [&](Attribute attr) -> Attribute {
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
return convertIntegerAttr(intAttr, destEltType);
}
return attr; // Non-integer attributes (FloatAttr, etc.) returned
// unchanged
});

return DenseElementsAttr::get(destVecType, convertedElements);
}
Expand Down Expand Up @@ -3497,13 +3503,19 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
SmallVector<Attribute> insertedValues;
Type destEltType = destTy.getElementType();

/// Converts the expected type to an IntegerAttr if there's
/// a mismatch.
/// Converts integer attributes to the expected type if there's a mismatch.
/// Non-integer attributes are left unchanged.
if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
for (auto value : denseSource.getValues<Attribute>())
insertedValues.push_back(convertIntegerAttr(value, destEltType));
if (auto intAttr = dyn_cast<IntegerAttr>(value))
insertedValues.push_back(convertIntegerAttr(intAttr, destEltType));
else
insertedValues.push_back(value); // Non-integer attributes unchanged
} else {
insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
insertedValues.push_back(convertIntegerAttr(intAttr, destEltType));
else
insertedValues.push_back(srcAttr); // Non-integer attributes unchanged
}

auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
Expand Down
36 changes: 36 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3375,6 +3375,42 @@ func.func @negative_from_elements_to_constant() -> vector<1x!llvm.ptr> {
return %b : vector<1x!llvm.ptr>
}

// -----

// CHECK-LABEL: @negative_from_elements_poison
// CHECK: %[[VAL:.*]] = ub.poison : vector<2xf32>
// CHECK: return %[[VAL]] : vector<2xf32>
func.func @negative_from_elements_poison_f32() -> vector<2xf32> {
%0 = ub.poison : f32
%1 = vector.from_elements %0, %0 : vector<2xf32>
return %1 : vector<2xf32>
}

// -----

// CHECK-LABEL: @negative_from_elements_poison_i32
// CHECK: %[[VAL:.*]] = ub.poison : vector<2xi32>
// CHECK: return %[[VAL]] : vector<2xi32>
func.func @negative_from_elements_poison_i32() -> vector<2xi32> {
%0 = ub.poison : i32
%1 = vector.from_elements %0, %0 : vector<2xi32>
return %1 : vector<2xi32>
}

// -----

// CHECK-LABEL: @negative_from_elements_poison_constant_mix
// CHECK: %[[POISON:.*]] = ub.poison : f32
// CHECK: %[[CONST:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[RES:.*]] = vector.from_elements %[[POISON]], %[[CONST]] : vector<2xf32>
// CHECK: return %[[RES]] : vector<2xf32>
func.func @negative_from_elements_poison_constant_mix() -> vector<2xf32> {
%0 = ub.poison : f32
%c = arith.constant 1.0 : f32
%1 = vector.from_elements %0, %c : vector<2xf32>
return %1 : vector<2xf32>
}

// +---------------------------------------------------------------------------
// End of Tests for foldFromElementsToConstant
// +---------------------------------------------------------------------------
Expand Down