Skip to content

Commit 93b05dd

Browse files
[mlir][vector] Fix crashes in from_elements folder + broadcast verifier (#155393)
This PR fixes two crashes / failures. 1. The `vector.broadcast` verifier did not take into account `VectorElementTypeInterface` and was looking for int/index/float types. 2. The `vector.from_elements` folder attempted to create an invalid `DenseElementsAttr`. Only int/float/index/complex types are supported.
1 parent 1b6875e commit 93b05dd

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2466,8 +2466,12 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
24662466
if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
24672467
return {};
24682468

2469+
// DenseElementsAttr only supports int/index/float/complex types.
24692470
auto destVecType = fromElementsOp.getDest().getType();
24702471
auto destEltType = destVecType.getElementType();
2472+
if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2473+
return {};
2474+
24712475
// Constant attributes might have a different type than the return type.
24722476
// Convert them before creating the dense elements attribute.
24732477
auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
@@ -2778,8 +2782,8 @@ BroadcastableToResult mlir::vector::isBroadcastableTo(
27782782
Type srcType, VectorType dstVectorType,
27792783
std::pair<VectorDim, VectorDim> *mismatchingDims) {
27802784
// Broadcast scalar to vector of the same element type.
2781-
if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
2782-
getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
2785+
if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
2786+
srcType == getElementTypeOrSelf(dstVectorType))
27832787
return BroadcastableToResult::Success;
27842788
// From now on, only vectors broadcast.
27852789
VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3726,3 +3726,17 @@ func.func @no_fold_insert_use_chain_mismatch_static_position(%arg : vector<4xf32
37263726
%v_1 = vector.insert %val, %v_0[1] : f32 into vector<4xf32>
37273727
return %v_1 : vector<4xf32>
37283728
}
3729+
3730+
// -----
3731+
3732+
llvm.mlir.global constant @my_symbol() : i32
3733+
3734+
// CHECK-LABEL: func @from_address_of_regression
3735+
// CHECK: %[[a:.*]] = llvm.mlir.addressof @my_symbol
3736+
// CHECK: %[[b:.*]] = vector.broadcast %[[a]] : !llvm.ptr to vector<1x!llvm.ptr>
3737+
// CHECK: return %[[b]]
3738+
func.func @from_address_of_regression() -> vector<1x!llvm.ptr> {
3739+
%a = llvm.mlir.addressof @my_symbol : !llvm.ptr
3740+
%b = vector.from_elements %a : vector<1x!llvm.ptr>
3741+
return %b : vector<1x!llvm.ptr>
3742+
}

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ func.func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
149149
}
150150

151151
// CHECK-LABEL: @vector_broadcast
152-
func.func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: vector<1x16xf32>, %e: vector<8x1xf32>, %f: vector<8x1x!llvm.ptr<1>>) {
152+
func.func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: vector<1x16xf32>, %e: vector<8x1xf32>, %f: vector<8x1x!llvm.ptr<1>>, %g: !llvm.ptr<1>) {
153153
// CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>
154154
%0 = vector.broadcast %a : f32 to vector<f32>
155155
// CHECK: vector.broadcast %{{.*}} : vector<f32> to vector<4xf32>
@@ -164,6 +164,8 @@ func.func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: ve
164164
%5 = vector.broadcast %e : vector<8x1xf32> to vector<8x16xf32>
165165
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1x!llvm.ptr<1>> to vector<8x16x!llvm.ptr<1>>
166166
%6 = vector.broadcast %f : vector<8x1x!llvm.ptr<1>> to vector<8x16x!llvm.ptr<1>>
167+
// CHECK-NEXT: vector.broadcast %{{.*}} : !llvm.ptr<1> to vector<8x16x!llvm.ptr<1>>
168+
%7 = vector.broadcast %g : !llvm.ptr<1> to vector<8x16x!llvm.ptr<1>>
167169
return
168170
}
169171

0 commit comments

Comments
 (0)