Skip to content

Commit c56479d

Browse files
committed
[mlir][Vector] Fix vector.insert folder for scalar to 0-d inserts
1 parent 7fe149c commit c56479d

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2951,11 +2951,11 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
29512951
InsertOpConstantFolder>(context);
29522952
}
29532953

2954-
// Eliminates insert operations that produce values identical to their source
2955-
// value. This happens when the source and destination vectors have identical
2956-
// sizes.
29572954
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
2958-
if (getNumIndices() == 0)
2955+
// Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
2956+
// %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
2957+
// (type mismatch).
2958+
if (getNumIndices() == 0 && getSourceType() == getResult().getType())
29592959
return getSource();
29602960
return {};
29612961
}

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2745,6 +2745,18 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
27452745

27462746
// -----
27472747

2748+
// CHECK-LABEL: func @insert_into_0d_regression(
2749+
// CHECK-SAME: %[[v:.*]]: vector<f32>)
2750+
// CHECK: %[[extract:.*]] = vector.insert %{{.*}}, %[[v]] [] : f32 into vector<f32>
2751+
// CHECK: return %[[extract]]
2752+
func.func @insert_into_0d_regression(%v: vector<f32>) -> vector<f32> {
2753+
%cst = arith.constant 0.000000e+00 : f32
2754+
%0 = vector.insert %cst, %v [] : f32 into vector<f32>
2755+
return %0 : vector<f32>
2756+
}
2757+
2758+
// -----
2759+
27482760
// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract
27492761
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
27502762
// CHECK-NEXT: return %[[EXTRACT]] : vector<4xi32>

0 commit comments

Comments
 (0)