Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 33 additions & 27 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,14 +396,31 @@ std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
return {};
}

/// Converts an IntegerAttr to have the specified type if needed.
/// This handles cases where integer constant attributes have a different type
/// than the target element type.
static IntegerAttr convertIntegerAttr(IntegerAttr intAttr, Type expectedType) {
if (intAttr.getType() == expectedType)
return intAttr; // Already correct type
/// Converts numeric attributes to the expected type. Supports
/// integer-to-integer and float-to-integer conversions. Returns the original
/// attribute if no conversion is needed or supported.
static Attribute convertNumericAttr(Attribute attr, Type expectedType) {
// Integer-to-integer conversion
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
if (auto intType = dyn_cast<IntegerType>(expectedType)) {
if (intAttr.getType() != expectedType)
return IntegerAttr::get(expectedType, intAttr.getInt());
}
return attr;
}

// Float-to-integer bitcast (preserves bit representation)
if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
auto intType = dyn_cast<IntegerType>(expectedType);
if (!intType)
return attr;

APFloat floatVal = floatAttr.getValue();
APInt intVal = floatVal.bitcastToAPInt();
return IntegerAttr::get(expectedType, intVal);
}

return IntegerAttr::get(expectedType, intAttr.getInt());
return attr;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2473,16 +2490,11 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
return {};

// Convert integer attributes to the target type if needed, leave others
// unchanged.
auto convertedElements =
llvm::map_to_vector(elements, [&](Attribute attr) -> Attribute {
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
return convertIntegerAttr(intAttr, destEltType);
}
return attr; // Non-integer attributes (FloatAttr, etc.) returned
// unchanged
});
// Constant attributes might have a different type than the return type.
// Convert them before creating the dense elements attribute.
auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
return convertNumericAttr(attr, destEltType);
});

return DenseElementsAttr::get(destVecType, convertedElements);
}
Expand Down Expand Up @@ -3503,19 +3515,13 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
SmallVector<Attribute> insertedValues;
Type destEltType = destTy.getElementType();

/// Converts integer attributes to the expected type if there's a mismatch.
/// Non-integer attributes are left unchanged.
/// Converts attribute to the expected type if there's
/// a mismatch.
if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
for (auto value : denseSource.getValues<Attribute>())
if (auto intAttr = dyn_cast<IntegerAttr>(value))
insertedValues.push_back(convertIntegerAttr(intAttr, destEltType));
else
insertedValues.push_back(value); // Non-integer attributes unchanged
insertedValues.push_back(convertNumericAttr(value, destEltType));
} else {
if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
insertedValues.push_back(convertIntegerAttr(intAttr, destEltType));
else
insertedValues.push_back(srcAttr); // Non-integer attributes unchanged
insertedValues.push_back(convertNumericAttr(srcAttr, destEltType));
}

auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
Expand Down
68 changes: 68 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3411,6 +3411,74 @@ func.func @negative_from_elements_poison_constant_mix() -> vector<2xf32> {
return %1 : vector<2xf32>
}

// -----

// CHECK-LABEL: func @from_elements_float8_to_i8_conversion(
// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<[0, 56, -72, 69, 127, -1]> : vector<6xi8>
// CHECK-NEXT: return %[[CST]] : vector<6xi8>
func.func @from_elements_float8_to_i8_conversion() -> vector<6xi8> {
%cst0 = llvm.mlir.constant(0.0 : f8E4M3FN) : i8
%cst1 = llvm.mlir.constant(1.0 : f8E4M3FN) : i8
%cst_neg1 = llvm.mlir.constant(-1.0 : f8E4M3FN) : i8
%cst_pi = llvm.mlir.constant(3.14 : f8E4M3FN) : i8
%cst_inf = llvm.mlir.constant(0x7F : f8E4M3FN) : i8
%cst_neg_inf = llvm.mlir.constant(0xFF : f8E4M3FN) : i8
%v = vector.from_elements %cst0, %cst1, %cst_neg1, %cst_pi, %cst_inf, %cst_neg_inf : vector<6xi8>
return %v : vector<6xi8>
}

// CHECK-LABEL: func @from_elements_float16_to_i16_conversion(
// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<[0, 15360, -17408, 16968, 31743, -1025]> : vector<6xi16>
// CHECK-NEXT: return %[[CST]] : vector<6xi16>
func.func @from_elements_float16_to_i16_conversion() -> vector<6xi16> {
%cst0 = llvm.mlir.constant(0.0 : f16) : i16
%cst1 = llvm.mlir.constant(1.0 : f16) : i16
%cst_neg1 = llvm.mlir.constant(-1.0 : f16) : i16
%cst_pi = llvm.mlir.constant(3.14 : f16) : i16
%cst_max = llvm.mlir.constant(65504.0 : f16) : i16
%cst_min = llvm.mlir.constant(-65504.0 : f16) : i16
%v = vector.from_elements %cst0, %cst1, %cst_neg1, %cst_pi, %cst_max, %cst_min : vector<6xi16>
return %v : vector<6xi16>
}

// CHECK-LABEL: func @from_elements_f64_to_i64_conversion(
// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<[0, 4607182418800017408, -4616189618054758400, 4614253070214989087, 9218868437227405311, -4503599627370497]> : vector<6xi64>
// CHECK-NEXT: return %[[CST]] : vector<6xi64>
func.func @from_elements_f64_to_i64_conversion() -> vector<6xi64> {
%cst0 = llvm.mlir.constant(0.0 : f64) : i64
%cst1 = llvm.mlir.constant(1.0 : f64) : i64
%cst_neg1 = llvm.mlir.constant(-1.0 : f64) : i64
%cst_pi = llvm.mlir.constant(3.14 : f64) : i64
%cst_max = llvm.mlir.constant(1.7976931348623157e+308 : f64) : i64
%cst_min = llvm.mlir.constant(-1.7976931348623157e+308 : f64) : i64
%v = vector.from_elements %cst0, %cst1, %cst_neg1, %cst_pi, %cst_max, %cst_min : vector<6xi64>
return %v : vector<6xi64>
}

// -----

// CHECK-LABEL: func @from_elements_i1_to_i8_conversion(
// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<0> : vector<1xi8>
// CHECK-NEXT: return %[[CST]] : vector<1xi8>
func.func @from_elements_i1_to_i8_conversion() -> vector<1xi8> {
%cst = llvm.mlir.constant(0: i1) : i8
%v = vector.from_elements %cst : vector<1xi8>
return %v : vector<1xi8>
}

// -----

// CHECK-LABEL: func @from_elements_index_to_i64_conversion(
// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<[0, 1, 42]> : vector<3xi64>
// CHECK-NEXT: return %[[CST]] : vector<3xi64>
func.func @from_elements_index_to_i64_conversion() -> vector<3xi64> {
%cst0 = llvm.mlir.constant(0 : index) : i64
%cst1 = llvm.mlir.constant(1 : index) : i64
%cst42 = llvm.mlir.constant(42 : index) : i64
%v = vector.from_elements %cst0, %cst1, %cst42 : vector<3xi64>
return %v : vector<3xi64>
}

// +---------------------------------------------------------------------------
// End of Tests for foldFromElementsToConstant
// +---------------------------------------------------------------------------
Expand Down