@@ -398,6 +398,18 @@ std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
398398 return {};
399399}
400400
401+ // / Converts an IntegerAttr to have the specified type if needed.
402+ // / This handles cases where constant attributes have a different type than the
403+ // / target element type. If the input attribute is not an IntegerAttr or already
404+ // / has the correct type, returns it unchanged.
405+ static Attribute convertIntegerAttr (Attribute attr, Type expectedType) {
406+ if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
407+ if (intAttr.getType () != expectedType)
408+ return IntegerAttr::get (expectedType, intAttr.getInt ());
409+ }
410+ return attr;
411+ }
412+
401413// ===----------------------------------------------------------------------===//
402414// CombiningKindAttr
403415// ===----------------------------------------------------------------------===//
@@ -2464,8 +2476,37 @@ static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
24642476 return {};
24652477}
24662478
2479+ // / Fold vector.from_elements to a constant when all operands are constants.
2480+ // / Example:
2481+ // / %c1 = arith.constant 1 : i32
2482+ // / %c2 = arith.constant 2 : i32
2483+ // / %v = vector.from_elements %c1, %c2 : vector<2xi32>
2484+ // / =>
2485+ // / %v = arith.constant dense<[1, 2]> : vector<2xi32>
2486+ // /
2487+ static OpFoldResult foldFromElementsToConstant (FromElementsOp fromElementsOp,
2488+ ArrayRef<Attribute> elements) {
2489+ if (llvm::any_of (elements, [](Attribute attr) { return !attr; }))
2490+ return {};
2491+
2492+ auto destVecType = fromElementsOp.getDest ().getType ();
2493+ auto destEltType = destVecType.getElementType ();
2494+ // Constant attributes might have a different type than the return type.
2495+ // Convert them before creating the dense elements attribute.
2496+ auto convertedElements = llvm::map_to_vector (elements, [&](Attribute attr) {
2497+ return convertIntegerAttr (attr, destEltType);
2498+ });
2499+
2500+ return DenseElementsAttr::get (destVecType, convertedElements);
2501+ }
2502+
24672503OpFoldResult FromElementsOp::fold (FoldAdaptor adaptor) {
2468- return foldFromElementsToElements (*this );
2504+ if (auto res = foldFromElementsToElements (*this ))
2505+ return res;
2506+ if (auto res = foldFromElementsToConstant (*this , adaptor.getElements ()))
2507+ return res;
2508+
2509+ return {};
24692510}
24702511
24712512// / Rewrite a vector.from_elements into a vector.splat if all elements are the
@@ -3332,17 +3373,6 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
33323373
33333374 // / Converts the expected type to an IntegerAttr if there's
33343375 // / a mismatch.
3335- auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
3336- if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
3337- if (intAttr.getType () != expectedType)
3338- return IntegerAttr::get (expectedType, intAttr.getInt ());
3339- }
3340- return attr;
3341- };
3342-
3343- // The `convertIntegerAttr` method specifically handles the case
3344- // for `llvm.mlir.constant` which can hold an attribute with a
3345- // different type than the return type.
33463376 if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
33473377 for (auto value : denseSource.getValues <Attribute>())
33483378 insertedValues.push_back (convertIntegerAttr (value, destEltType));
0 commit comments