diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 347141e2773b8..926d97bf6e38f 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -396,14 +396,31 @@ std::optional vector::getConstantVscaleMultiplier(Value value) { return {}; } -/// Converts an IntegerAttr to have the specified type if needed. -/// This handles cases where integer constant attributes have a different type -/// than the target element type. -static IntegerAttr convertIntegerAttr(IntegerAttr intAttr, Type expectedType) { - if (intAttr.getType() == expectedType) - return intAttr; // Already correct type +/// Converts numeric attributes to the expected type. Supports +/// integer-to-integer and float-to-integer conversions. Returns the original +/// attribute if no conversion is needed or supported. +static Attribute convertNumericAttr(Attribute attr, Type expectedType) { + // Integer-to-integer conversion + if (auto intAttr = dyn_cast(attr)) { + if (auto intType = dyn_cast(expectedType)) { + if (intAttr.getType() != expectedType) + return IntegerAttr::get(expectedType, intAttr.getInt()); + } + return attr; + } + + // Float-to-integer bitcast (preserves bit representation) + if (auto floatAttr = dyn_cast(attr)) { + auto intType = dyn_cast(expectedType); + if (!intType) + return attr; + + APFloat floatVal = floatAttr.getValue(); + APInt intVal = floatVal.bitcastToAPInt(); + return IntegerAttr::get(expectedType, intVal); + } - return IntegerAttr::get(expectedType, intAttr.getInt()); + return attr; } //===----------------------------------------------------------------------===// @@ -2473,16 +2490,11 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, if (!destEltType.isIntOrIndexOrFloat() && !isa(destEltType)) return {}; - // Convert integer attributes to the target type if needed, leave others - // unchanged. - auto convertedElements = - llvm::map_to_vector(elements, [&](Attribute attr) -> Attribute { - if (auto intAttr = dyn_cast(attr)) { - return convertIntegerAttr(intAttr, destEltType); - } - return attr; // Non-integer attributes (FloatAttr, etc.) returned - // unchanged - }); + // Constant attributes might have a different type than the return type. + // Convert them before creating the dense elements attribute. + auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) { + return convertNumericAttr(attr, destEltType); + }); return DenseElementsAttr::get(destVecType, convertedElements); } @@ -3503,19 +3515,13 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, SmallVector insertedValues; Type destEltType = destTy.getElementType(); - /// Converts integer attributes to the expected type if there's a mismatch. - /// Non-integer attributes are left unchanged. + /// Converts attribute to the expected type if there's + /// a mismatch. if (auto denseSource = llvm::dyn_cast(srcAttr)) { for (auto value : denseSource.getValues()) - if (auto intAttr = dyn_cast(value)) - insertedValues.push_back(convertIntegerAttr(intAttr, destEltType)); - else - insertedValues.push_back(value); // Non-integer attributes unchanged + insertedValues.push_back(convertNumericAttr(value, destEltType)); } else { - if (auto intAttr = dyn_cast(srcAttr)) - insertedValues.push_back(convertIntegerAttr(intAttr, destEltType)); - else - insertedValues.push_back(srcAttr); // Non-integer attributes unchanged + insertedValues.push_back(convertNumericAttr(srcAttr, destEltType)); } auto allValues = llvm::to_vector(denseDst.getValues()); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 75c762f38432a..5448976f84760 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -3411,6 +3411,74 @@ func.func @negative_from_elements_poison_constant_mix() -> vector<2xf32> { return %1 : vector<2xf32> } +// ----- + +// CHECK-LABEL: func @from_elements_float8_to_i8_conversion( +// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<[0, 56, -72, 69, 127, -1]> : vector<6xi8> +// CHECK-NEXT: return %[[CST]] : vector<6xi8> +func.func @from_elements_float8_to_i8_conversion() -> vector<6xi8> { + %cst0 = llvm.mlir.constant(0.0 : f8E4M3FN) : i8 + %cst1 = llvm.mlir.constant(1.0 : f8E4M3FN) : i8 + %cst_neg1 = llvm.mlir.constant(-1.0 : f8E4M3FN) : i8 + %cst_pi = llvm.mlir.constant(3.14 : f8E4M3FN) : i8 + %cst_inf = llvm.mlir.constant(0x7F : f8E4M3FN) : i8 + %cst_neg_inf = llvm.mlir.constant(0xFF : f8E4M3FN) : i8 + %v = vector.from_elements %cst0, %cst1, %cst_neg1, %cst_pi, %cst_inf, %cst_neg_inf : vector<6xi8> + return %v : vector<6xi8> +} + +// CHECK-LABEL: func @from_elements_float16_to_i16_conversion( +// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<[0, 15360, -17408, 16968, 31743, -1025]> : vector<6xi16> +// CHECK-NEXT: return %[[CST]] : vector<6xi16> +func.func @from_elements_float16_to_i16_conversion() -> vector<6xi16> { + %cst0 = llvm.mlir.constant(0.0 : f16) : i16 + %cst1 = llvm.mlir.constant(1.0 : f16) : i16 + %cst_neg1 = llvm.mlir.constant(-1.0 : f16) : i16 + %cst_pi = llvm.mlir.constant(3.14 : f16) : i16 + %cst_max = llvm.mlir.constant(65504.0 : f16) : i16 + %cst_min = llvm.mlir.constant(-65504.0 : f16) : i16 + %v = vector.from_elements %cst0, %cst1, %cst_neg1, %cst_pi, %cst_max, %cst_min : vector<6xi16> + return %v : vector<6xi16> +} + +// CHECK-LABEL: func @from_elements_f64_to_i64_conversion( +// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<[0, 4607182418800017408, -4616189618054758400, 4614253070214989087, 9218868437227405311, -4503599627370497]> : vector<6xi64> +// CHECK-NEXT: return %[[CST]] : vector<6xi64> +func.func @from_elements_f64_to_i64_conversion() -> vector<6xi64> { + %cst0 = llvm.mlir.constant(0.0 : f64) : i64 + %cst1 = llvm.mlir.constant(1.0 : f64) : i64 + %cst_neg1 = llvm.mlir.constant(-1.0 : f64) : i64 + %cst_pi = llvm.mlir.constant(3.14 : f64) : i64 + %cst_max = llvm.mlir.constant(1.7976931348623157e+308 : f64) : i64 + %cst_min = llvm.mlir.constant(-1.7976931348623157e+308 : f64) : i64 + %v = vector.from_elements %cst0, %cst1, %cst_neg1, %cst_pi, %cst_max, %cst_min : vector<6xi64> + return %v : vector<6xi64> +} + +// ----- + +// CHECK-LABEL: func @from_elements_i1_to_i8_conversion( +// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<0> : vector<1xi8> +// CHECK-NEXT: return %[[CST]] : vector<1xi8> +func.func @from_elements_i1_to_i8_conversion() -> vector<1xi8> { + %cst = llvm.mlir.constant(0: i1) : i8 + %v = vector.from_elements %cst : vector<1xi8> + return %v : vector<1xi8> +} + +// ----- + +// CHECK-LABEL: func @from_elements_index_to_i64_conversion( +// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<[0, 1, 42]> : vector<3xi64> +// CHECK-NEXT: return %[[CST]] : vector<3xi64> +func.func @from_elements_index_to_i64_conversion() -> vector<3xi64> { + %cst0 = llvm.mlir.constant(0 : index) : i64 + %cst1 = llvm.mlir.constant(1 : index) : i64 + %cst42 = llvm.mlir.constant(42 : index) : i64 + %v = vector.from_elements %cst0, %cst1, %cst42 : vector<3xi64> + return %v : vector<3xi64> +} + // +--------------------------------------------------------------------------- // End of Tests for foldFromElementsToConstant // +---------------------------------------------------------------------------