Skip to content

Commit 339450f

Browse files
authored
[mlir][vector] Fix crash in vector.from_elements folding with poison values (#158528)
The vector.from_elements constant folding was crashing when poison values were present in the element list. The convertIntegerAttr function was not properly handling poison attributes, leading to assertion failures in dyn_cast operations. This patch refactors convertIntegerAttr to take IntegerAttr directly, moving poison detection to the caller using explicit isa<ub::PoisonAttrInterface> checks. The function signature change provides compile-time type safety while the early poison validation in foldFromElementsToConstant prevents unsafe casting operations. The folding now gracefully aborts when poison attributes are encountered, preventing the crash while preserving correct folding for legitimate mixed-type constants (int/float). Fixes assertion: "dyn_cast on a non-existent value" when processing ub.poison values in vector.from_elements operations.
1 parent 3f52e97 commit 339450f

File tree

2 files changed

+67
-19
lines changed

2 files changed

+67
-19
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -397,15 +397,13 @@ std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
397397
}
398398

399399
/// Converts an IntegerAttr to have the specified type if needed.
400-
/// This handles cases where constant attributes have a different type than the
401-
/// target element type. If the input attribute is not an IntegerAttr or already
402-
/// has the correct type, returns it unchanged.
403-
static Attribute convertIntegerAttr(Attribute attr, Type expectedType) {
404-
if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
405-
if (intAttr.getType() != expectedType)
406-
return IntegerAttr::get(expectedType, intAttr.getInt());
407-
}
408-
return attr;
400+
/// This handles cases where integer constant attributes have a different type
401+
/// than the target element type.
402+
static IntegerAttr convertIntegerAttr(IntegerAttr intAttr, Type expectedType) {
403+
if (intAttr.getType() == expectedType)
404+
return intAttr; // Already correct type
405+
406+
return IntegerAttr::get(expectedType, intAttr.getInt());
409407
}
410408

411409
//===----------------------------------------------------------------------===//
@@ -2463,7 +2461,10 @@ static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
24632461
///
24642462
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
24652463
ArrayRef<Attribute> elements) {
2466-
if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
2464+
// Check for null or poison attributes before any processing.
2465+
if (llvm::any_of(elements, [](Attribute attr) {
2466+
return !attr || isa<ub::PoisonAttrInterface>(attr);
2467+
}))
24672468
return {};
24682469

24692470
// DenseElementsAttr only supports int/index/float/complex types.
@@ -2472,11 +2473,16 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
24722473
if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
24732474
return {};
24742475

2475-
// Constant attributes might have a different type than the return type.
2476-
// Convert them before creating the dense elements attribute.
2477-
auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
2478-
return convertIntegerAttr(attr, destEltType);
2479-
});
2476+
// Convert integer attributes to the target type if needed, leave others
2477+
// unchanged.
2478+
auto convertedElements =
2479+
llvm::map_to_vector(elements, [&](Attribute attr) -> Attribute {
2480+
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
2481+
return convertIntegerAttr(intAttr, destEltType);
2482+
}
2483+
return attr; // Non-integer attributes (FloatAttr, etc.) returned
2484+
// unchanged
2485+
});
24802486

24812487
return DenseElementsAttr::get(destVecType, convertedElements);
24822488
}
@@ -3497,13 +3503,19 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
34973503
SmallVector<Attribute> insertedValues;
34983504
Type destEltType = destTy.getElementType();
34993505

3500-
/// Converts the expected type to an IntegerAttr if there's
3501-
/// a mismatch.
3506+
/// Converts integer attributes to the expected type if there's a mismatch.
3507+
/// Non-integer attributes are left unchanged.
35023508
if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
35033509
for (auto value : denseSource.getValues<Attribute>())
3504-
insertedValues.push_back(convertIntegerAttr(value, destEltType));
3510+
if (auto intAttr = dyn_cast<IntegerAttr>(value))
3511+
insertedValues.push_back(convertIntegerAttr(intAttr, destEltType));
3512+
else
3513+
insertedValues.push_back(value); // Non-integer attributes unchanged
35053514
} else {
3506-
insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
3515+
if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
3516+
insertedValues.push_back(convertIntegerAttr(intAttr, destEltType));
3517+
else
3518+
insertedValues.push_back(srcAttr); // Non-integer attributes unchanged
35073519
}
35083520

35093521
auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3375,6 +3375,42 @@ func.func @negative_from_elements_to_constant() -> vector<1x!llvm.ptr> {
33753375
return %b : vector<1x!llvm.ptr>
33763376
}
33773377

3378+
// -----
3379+
3380+
// CHECK-LABEL: @negative_from_elements_poison
3381+
// CHECK: %[[VAL:.*]] = ub.poison : vector<2xf32>
3382+
// CHECK: return %[[VAL]] : vector<2xf32>
3383+
func.func @negative_from_elements_poison_f32() -> vector<2xf32> {
3384+
%0 = ub.poison : f32
3385+
%1 = vector.from_elements %0, %0 : vector<2xf32>
3386+
return %1 : vector<2xf32>
3387+
}
3388+
3389+
// -----
3390+
3391+
// CHECK-LABEL: @negative_from_elements_poison_i32
3392+
// CHECK: %[[VAL:.*]] = ub.poison : vector<2xi32>
3393+
// CHECK: return %[[VAL]] : vector<2xi32>
3394+
func.func @negative_from_elements_poison_i32() -> vector<2xi32> {
3395+
%0 = ub.poison : i32
3396+
%1 = vector.from_elements %0, %0 : vector<2xi32>
3397+
return %1 : vector<2xi32>
3398+
}
3399+
3400+
// -----
3401+
3402+
// CHECK-LABEL: @negative_from_elements_poison_constant_mix
3403+
// CHECK: %[[POISON:.*]] = ub.poison : f32
3404+
// CHECK: %[[CONST:.*]] = arith.constant 1.000000e+00 : f32
3405+
// CHECK: %[[RES:.*]] = vector.from_elements %[[POISON]], %[[CONST]] : vector<2xf32>
3406+
// CHECK: return %[[RES]] : vector<2xf32>
3407+
func.func @negative_from_elements_poison_constant_mix() -> vector<2xf32> {
3408+
%0 = ub.poison : f32
3409+
%c = arith.constant 1.0 : f32
3410+
%1 = vector.from_elements %0, %c : vector<2xf32>
3411+
return %1 : vector<2xf32>
3412+
}
3413+
33783414
// +---------------------------------------------------------------------------
33793415
// End of Tests for foldFromElementsToConstant
33803416
// +---------------------------------------------------------------------------

0 commit comments

Comments
 (0)