Skip to content

Commit 2090555

Browse files
authored
Add vector insert to vector linearize pass (#739)
1 parent 313caec commit 2090555

File tree

2 files changed

+91
-1
lines changed

2 files changed

+91
-1
lines changed

lib/Transforms/VectorLinearize.cpp

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,65 @@ struct VectorExtractOpConversion final
220220
}
221221
};
222222

223+
struct VectorInsertOpConversion final
224+
: public mlir::OpConversionPattern<mlir::vector::InsertOp> {
225+
using OpConversionPattern::OpConversionPattern;
226+
mlir::LogicalResult
227+
matchAndRewrite(mlir::vector::InsertOp insertOp, OpAdaptor adaptor,
228+
mlir::ConversionPatternRewriter &rewriter) const override {
229+
auto dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType());
230+
if (!dstTy)
231+
return rewriter.notifyMatchFailure(insertOp, "cannot convert type.");
232+
233+
// dynamic position is not supported
234+
if (insertOp.hasDynamicPosition())
235+
return rewriter.notifyMatchFailure(insertOp,
236+
"dynamic position is not supported.");
237+
auto srcTy = insertOp.getSourceType();
238+
auto src = insertOp.getSource();
239+
auto srcAsVec = mlir::dyn_cast<mlir::VectorType>(srcTy);
240+
uint64_t srcSize = 0;
241+
if (srcAsVec) {
242+
srcSize = srcAsVec.getNumElements();
243+
} else {
244+
return rewriter.notifyMatchFailure(insertOp,
245+
"scalars are not supported.");
246+
}
247+
248+
auto dst = insertOp.getDest();
249+
auto dstShape = insertOp.getDestVectorType().getShape();
250+
const auto dstSize = insertOp.getDestVectorType().getNumElements();
251+
auto dstSizeForOffsets = dstSize;
252+
253+
// compute linearized offset
254+
int64_t linearizedOffset = 0;
255+
auto offsetsNd = insertOp.getStaticPosition();
256+
for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
257+
dstSizeForOffsets /= dstShape[dim];
258+
linearizedOffset += offset * dstSizeForOffsets;
259+
}
260+
261+
llvm::SmallVector<int64_t, 2> indices(dstSize);
262+
auto origValsUntil = indices.begin();
263+
std::advance(origValsUntil, linearizedOffset);
264+
std::iota(indices.begin(), origValsUntil,
265+
0); // original values that remain [0, offset)
266+
auto newValsUntil = origValsUntil;
267+
std::advance(newValsUntil, srcSize);
268+
std::iota(origValsUntil, newValsUntil,
269+
dstSize); // new values [offset, offset+srcNumElements)
270+
std::iota(newValsUntil, indices.end(),
271+
linearizedOffset + srcSize); // the rest of original values
272+
// [offset+srcNumElements, end)
273+
274+
rewriter.replaceOpWithNewOp<mlir::vector::ShuffleOp>(
275+
insertOp, dstTy, adaptor.getDest(), adaptor.getSource(),
276+
rewriter.getI64ArrayAttr(indices));
277+
278+
return mlir::success();
279+
}
280+
};
281+
223282
struct VectorLinearizePass final
224283
: public imex::impl::VectorLinearizeBase<VectorLinearizePass> {
225284

@@ -242,7 +301,8 @@ struct VectorLinearizePass final
242301
target.addLegalOp<mlir::vector::ShapeCastOp>();
243302

244303
patterns.add<VectorExtractStridedSliceConversion, VectorShffleOpConversion,
245-
VectorExtractOpConversion>(typeConverter, context);
304+
VectorExtractOpConversion, VectorInsertOpConversion>(
305+
typeConverter, context);
246306

247307
mlir::vector::populateVectorTransposeLoweringPatterns(
248308
patterns,

test/Transforms/vector-linearize.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,37 @@ func.func @test_vector_extract(%arg0: vector<2x8x4xf32>) -> vector<8x4xf32> {
100100
}
101101

102102
// -----
103+
// CHECK-LABEL: test_vector_insert
104+
// CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32>
105+
// CHECK: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
106+
// CHECK: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
107+
// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
108+
// CHECK: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
109+
// CHECK-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
110+
// CHECK-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32>
111+
// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
112+
// CHECK: return %[[RES]] : vector<2x8x4xf32>
113+
// Inserts the 2-D vector %arg1 at static position [0] of the 3-D vector %arg0.
// The CHECK lines preceding this function expect a single vector.shuffle on
// the shape-cast 1-D operands whose first 32 mask entries (64..95) select the
// source and whose remaining entries (32..63) keep the destination.
func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
114+
%0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32>
115+
return %0 : vector<2x8x4xf32>
116+
}
103117

118+
// -----
119+
// CHECK-LABEL: test_vector_insert_2d_idx
120+
// CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<4xf32>) -> vector<2x8x4xf32>
121+
// CHECK: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
122+
// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[SRC]]
123+
// CHECK: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 64, 65, 66, 67, 16, 17, 18, 19, 20, 21,
124+
// CHECK-SAME: 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
125+
// CHECK-SAME: 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<4xf32>
126+
// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
127+
// CHECK: return %[[RES]] : vector<2x8x4xf32>
128+
// Inserts the 1-D vector %arg1 at static position [0, 3] of the 3-D vector
// %arg0. The CHECK lines preceding this function expect the four source
// elements (mask entries 64..67) to land at flattened positions 12..15, with
// every other mask entry keeping the destination element.
func.func @test_vector_insert_2d_idx(%arg0: vector<2x8x4xf32>, %arg1: vector<4xf32>) -> vector<2x8x4xf32> {
129+
%0 = vector.insert %arg1, %arg0[0, 3]: vector<4xf32> into vector<2x8x4xf32>
130+
return %0 : vector<2x8x4xf32>
131+
}
132+
133+
// -----
104134
// CHECK-LABEL: test_vector_transpose
105135
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8xf32>) -> vector<8x2xf32>
106136
// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8xf32> to vector<16xf32>

0 commit comments

Comments
 (0)