Skip to content

Commit 94ebcfd

Browse files
authored
[mlir][vector] Fix crash in ReorderCastOpsOnBroadcast with non-vector result (#170985)
Fixes a crash in `ReorderCastOpsOnBroadcast` by ensuring the cast result is a `VectorType` before applying the pattern. A regression test has been added to mlir/test/Dialect/Vector/vector-sink.mlir. Fixes: #126371
1 parent a033183 commit 94ebcfd

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,8 @@ struct ReorderCastOpsOnBroadcast
453453
PatternRewriter &rewriter) const override {
454454
if (op->getNumOperands() != 1)
455455
return failure();
456+
if (!isa<VectorType>(op->getResult(0).getType()))
457+
return failure();
456458
auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
457459
if (!bcastOp)
458460
return failure();

mlir/test/Dialect/Vector/vector-sink.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,21 @@ func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> {
382382
return %r : vector<2x[4]xi32>
383383
}
384384

385+
// -----
386+
387+
// CHECK-LABEL: func.func @negative_broadcast_cast_non_vector_result
388+
// CHECK-SAME: (%[[ARG:.*]]: i64)
389+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG]] : i64 to vector<26x7xi64>
390+
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[BCAST]] : vector<26x7xi64> to !llvm.array<26 x vector<7xi64>>
391+
// CHECK: return %[[CAST]] : !llvm.array<26 x vector<7xi64>>
392+
/// This test ensures that the `ReorderCastOpsOnBroadcast` pattern does not
393+
/// attempt to reorder a cast operation that produces a non-vector result type.
394+
func.func @negative_broadcast_cast_non_vector_result(%arg0: i64) -> !llvm.array<26 x vector<7xi64>> {
395+
%0 = vector.broadcast %arg0 : i64 to vector<26x7xi64>
396+
%1 = builtin.unrealized_conversion_cast %0 : vector<26x7xi64> to !llvm.array<26 x vector<7xi64>>
397+
return %1 : !llvm.array<26 x vector<7xi64>>
398+
}
399+
385400
//===----------------------------------------------------------------------===//
386401
// [Pattern: ReorderElementwiseOpsOnTranspose]
387402
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)