Skip to content

Commit 298fff7

Browse files
committed
[mlir][Vector] Fix scalable Insert/ExtractSlice lowering
It looks like scalable `vector.insert/extractslice` ops made their way through lowering patterns that generate `vector.shuffle` ops. I'm not sure why this wasn't caught by the verifier; probably the shuffle op was folded into something else as part of the same rewrite and the IR wasn't verified. This PR fixes the issue by preventing scalable `vector.insert/extractslice` ops from being lowered to vector shuffles. Instead, they are now lowered to a sequence of insert/extractelement ops using an existing pattern.
1 parent 213a939 commit 298fff7

File tree

2 files changed

+28
-8
lines changed

2 files changed

+28
-8
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,15 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
9696
PatternRewriter &rewriter) const override {
9797
auto srcType = op.getSourceVectorType();
9898
auto dstType = op.getDestVectorType();
99+
int64_t srcRank = srcType.getRank();
100+
101+
// Scalable vectors are not supported by vector shuffle.
102+
if ((srcType.isScalable() || dstType.isScalable()) && srcRank == 1)
103+
return failure();
99104

100105
if (op.getOffsets().getValue().empty())
101106
return failure();
102107

103-
int64_t srcRank = srcType.getRank();
104108
int64_t dstRank = dstType.getRank();
105109
assert(dstRank >= srcRank);
106110
if (dstRank != srcRank)
@@ -184,6 +188,11 @@ class Convert1DExtractStridedSliceIntoShuffle
184188
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
185189
PatternRewriter &rewriter) const override {
186190
auto dstType = op.getType();
191+
auto srcType = op.getSourceVectorType();
192+
193+
// Scalable vectors are not supported by vector shuffle.
194+
if (dstType.isScalable() || srcType.isScalable())
195+
return failure();
187196

188197
assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
189198

@@ -331,4 +340,14 @@ void vector::populateVectorInsertExtractStridedSliceTransforms(
331340
patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
332341
Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext(),
333342
benefit);
343+
// Generate chains of extract/insert ops for scalable vectors only as they
344+
// can't be lowered to vector shuffles.
345+
populateVectorExtractStridedSliceToExtractInsertChainPatterns(
346+
patterns,
347+
/*controlFn=*/
348+
[](ExtractStridedSliceOp op) {
349+
return op.getType().isScalable() ||
350+
op.getSourceVectorType().isScalable();
351+
},
352+
benefit);
334353
}

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2026,13 +2026,14 @@ func.func @extract_strided_slice_f32_1d_from_2d_scalable(%arg0: vector<4x[8]xf32
20262026
// CHECK-LABEL: func.func @extract_strided_slice_f32_1d_from_2d_scalable(
20272027
// CHECK-SAME: %[[ARG:.*]]: vector<4x[8]xf32>)
20282028
// CHECK: %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
2029-
// CHECK: %[[T0:.*]] = llvm.mlir.undef : !llvm.array<2 x vector<[8]xf32>>
2030-
// CHECK: %[[T1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
2031-
// CHECK: %[[T2:.*]] = llvm.insertvalue %[[T1]], %[[T0]][0] : !llvm.array<2 x vector<[8]xf32>>
2032-
// CHECK: %[[T3:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
2033-
// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T2]][1] : !llvm.array<2 x vector<[8]xf32>>
2034-
// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
2035-
// CHECK: return %[[T5]]
2029+
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
2030+
// CHECK: %[[DST:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
2031+
// CHECK: %[[E0:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
2032+
// CHECK: %[[E1:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
2033+
// CHECK: %[[I0:.*]] = llvm.insertvalue %[[E0]], %[[DST]][0] : !llvm.array<2 x vector<[8]xf32>>
2034+
// CHECK: %[[I1:.*]] = llvm.insertvalue %[[E1]], %[[I0]][1] : !llvm.array<2 x vector<[8]xf32>>
2035+
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[I1]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
2036+
// CHECK: return %[[RES]]
20362037

20372038
// -----
20382039

0 commit comments

Comments
 (0)