Skip to content

Commit ddd6040

Browse files
committed
1 parent a368667 commit ddd6040

File tree

1 file changed

+28
-29
lines changed

1 file changed

+28
-29
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -397,20 +397,11 @@ std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
397397
}
398398

399399
/// Converts an IntegerAttr to have the specified type if needed.
400-
/// This handles cases where constant attributes have a different type than the
401-
/// target element type. Returns null if the attribute is poison/invalid or
402-
/// conversion fails.
403-
static Attribute convertIntegerAttr(Attribute attr, Type expectedType) {
404-
// Check for poison attributes before any casting operations
405-
if (!attr || isa<ub::PoisonAttrInterface>(attr))
406-
return {}; // Poison or invalid attribute
407-
408-
auto intAttr = mlir::dyn_cast<IntegerAttr>(attr);
409-
if (!intAttr)
410-
return attr; // Not an IntegerAttr, return unchanged (e.g., FloatAttr)
411-
400+
/// This handles cases where integer constant attributes have a different type
401+
/// than the target element type.
402+
static IntegerAttr convertIntegerAttr(IntegerAttr intAttr, Type expectedType) {
412403
if (intAttr.getType() == expectedType)
413-
return attr; // Already correct type
404+
return intAttr; // Already correct type
414405

415406
return IntegerAttr::get(expectedType, intAttr.getInt());
416407
}
@@ -2470,7 +2461,10 @@ static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
24702461
///
24712462
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
24722463
ArrayRef<Attribute> elements) {
2473-
if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
2464+
// Check for null or poison attributes before any processing.
2465+
if (llvm::any_of(elements, [](Attribute attr) {
2466+
return !attr || isa<ub::PoisonAttrInterface>(attr);
2467+
}))
24742468
return {};
24752469

24762470
// DenseElementsAttr only supports int/index/float/complex types.
@@ -2479,18 +2473,14 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
24792473
if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
24802474
return {};
24812475

2482-
// Constant attributes might have a different type than the return type.
2483-
// Convert them before creating the dense elements attribute.
2484-
auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
2485-
return convertIntegerAttr(attr, destEltType);
2476+
// Convert integer attributes to the target type if needed, leave others unchanged.
2477+
auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) -> Attribute {
2478+
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
2479+
return convertIntegerAttr(intAttr, destEltType);
2480+
}
2481+
return attr; // Non-integer attributes (FloatAttr, etc.) returned unchanged
24862482
});
24872483

2488-
// Check if any attributes are poison/invalid (indicated by null attributes).
2489-
// Note: convertIntegerAttr returns valid non-integer attributes unchanged,
2490-
// only returns null for poison/invalid attributes.
2491-
if (llvm::any_of(convertedElements, [](Attribute attr) { return !attr; }))
2492-
return {};
2493-
24942484
return DenseElementsAttr::get(destVecType, convertedElements);
24952485
}
24962486

@@ -3510,13 +3500,22 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
35103500
SmallVector<Attribute> insertedValues;
35113501
Type destEltType = destTy.getElementType();
35123502

3513-
/// Converts the expected type to an IntegerAttr if there's
3514-
/// a mismatch.
3503+
/// Converts integer attributes to the expected type if there's a mismatch.
3504+
/// Non-integer attributes are left unchanged.
35153505
if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3516-
for (auto value : denseSource.getValues<Attribute>())
3517-
insertedValues.push_back(convertIntegerAttr(value, destEltType));
3506+
for (auto value : denseSource.getValues<Attribute>()) {
3507+
if (auto intAttr = dyn_cast<IntegerAttr>(value)) {
3508+
insertedValues.push_back(convertIntegerAttr(intAttr, destEltType));
3509+
} else {
3510+
insertedValues.push_back(value); // Non-integer attributes unchanged
3511+
}
3512+
}
35183513
} else {
3519-
insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
3514+
if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr)) {
3515+
insertedValues.push_back(convertIntegerAttr(intAttr, destEltType));
3516+
} else {
3517+
insertedValues.push_back(srcAttr); // Non-integer attributes unchanged
3518+
}
35203519
}
35213520

35223521
auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());

0 commit comments

Comments
 (0)