Skip to content

Commit dc26c03

Browse files
committed
[mlir][vector] Add insertOp src shape check for BubbleUpBitCastForStridedSliceInsert
Not all shape of vectors can be casted into other types, we add a check to not fold insertOp into bitcast if the shape does not support it. Examples of unsupported shape castings are f16 vectors to f32 if the shape is not multiple of 2s. or int8 to int32 if shapes are not multiple of 4. Reviewed By: antiagainst, ThomasRaoux Differential Revision: https://reviews.llvm.org/D137802
1 parent 97105e5 commit dc26c03

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2503,6 +2503,14 @@ struct BubbleUpBitCastForStridedSliceInsert
25032503
if (rank != insertOp.getDestVectorType().getRank())
25042504
return failure();
25052505

2506+
// Requires that shape of insert op src is castable to dstType.
2507+
unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
2508+
unsigned destinationWidth =
2509+
castDstType.getElementType().getIntOrFloatBitWidth();
2510+
unsigned numElements = destinationWidth / sourceWidth;
2511+
if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
2512+
return failure();
2513+
25062514
ArrayAttr newOffsets = insertOp.getOffsets();
25072515
assert(newOffsets.size() == rank);
25082516
SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);

mlir/test/Dialect/Vector/vector-transforms.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,3 +507,21 @@ func.func @bubble_up_bitcast_in_strided_slice_insert_different_rank(%dst: vector
507507
%cast = vector.bitcast %0: vector<16x4x8xf16> to vector<16x4x4xf32>
508508
return %cast: vector<16x4x4xf32>
509509
}
510+
511+
// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_odd_shape
512+
func.func @bubble_up_bitcast_in_strided_slice_insert_odd_shape(%dst: vector<2xf16>, %src: vector<1xf16>) -> vector<1xf32> {
513+
// CHECK: vector.insert_strided_slice
514+
// CHECK-NEXT: vector.bitcast
515+
%0 = vector.insert_strided_slice %src, %dst {offsets = [0], strides = [1]} : vector<1xf16> into vector<2xf16>
516+
%cast = vector.bitcast %0: vector<2xf16> to vector<1xf32>
517+
return %cast: vector<1xf32>
518+
}
519+
520+
// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_larger_odd_shape
521+
func.func @bubble_up_bitcast_in_strided_slice_insert_larger_odd_shape(%dst: vector<8xf16>, %src: vector<3xf16>) -> vector<4xf32> {
522+
// CHECK: vector.insert_strided_slice
523+
// CHECK-NEXT: vector.bitcast
524+
%0 = vector.insert_strided_slice %src, %dst {offsets = [0], strides = [1]} : vector<3xf16> into vector<8xf16>
525+
%cast = vector.bitcast %0: vector<8xf16> to vector<4xf32>
526+
return %cast: vector<4xf32>
527+
}

0 commit comments

Comments
 (0)