Skip to content

Commit bcc4391

Browse files
committed
address some of reviewer's comments
1 parent cd894d8 commit bcc4391

File tree

2 files changed

+53
-20
lines changed

2 files changed

+53
-20
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,10 @@ static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
143143
/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
144144
/// positions written to are (1,3) and (1,4), which have linearized indices 8
145145
/// and 9. So [8,9] is returned.
146-
SmallVector<int64_t> static getFlattenedStridedSliceIndices(
146+
///
147+
/// The length of the returned vector is equal to the number of elements in
148+
/// the shape `small` (i.e. the product of dimensions of `small`).
149+
SmallVector<int64_t> static getStridedSliceInsertionIndices(
147150
ArrayRef<int64_t> small, ArrayRef<int64_t> large,
148151
ArrayRef<int64_t> offsets) {
149152

@@ -153,8 +156,10 @@ SmallVector<int64_t> static getFlattenedStridedSliceIndices(
153156
// offsets = 2, 3, 0
154157
//
155158
// `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
156-
assert(large.size() >= small.size());
157-
assert(large.size() >= offsets.size());
159+
assert((large.size() >= small.size()) &&
160+
"rank of 'large' cannot be lower than rank of 'small'");
161+
assert((large.size() >= offsets.size()) &&
162+
"rank of 'large' cannot be lower than the number of offsets");
158163
unsigned delta = large.size() - small.size();
159164
unsigned nOffsets = offsets.size();
160165
auto getSmall = [&](int64_t i) { return i >= delta ? small[i - delta] : 1; };
@@ -223,10 +228,8 @@ struct LinearizeVectorExtractStridedSlice final
223228
extractStridedSliceOp.getType());
224229
assert(flatOutputType && "vector type expected");
225230

226-
if (!stridesAllOne(extractStridedSliceOp)) {
227-
return rewriter.notifyMatchFailure(extractStridedSliceOp,
228-
"strides other than 1 not supported");
229-
}
231+
assert(stridesAllOne(extractStridedSliceOp) &&
232+
"has extract_strided_slice's verifier not checked strides are 1?");
230233

231234
FailureOr<SmallVector<int64_t>> offsets =
232235
intsFromArrayAttr(extractStridedSliceOp.getOffsets());
@@ -240,7 +243,7 @@ struct LinearizeVectorExtractStridedSlice final
240243

241244
ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();
242245

243-
SmallVector<int64_t> indices = getFlattenedStridedSliceIndices(
246+
SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
244247
outputShape, inputShape, offsets.value());
245248

246249
Value srcVector = adaptor.getVector();
@@ -287,10 +290,9 @@ struct LinearizeVectorInsertStridedSlice final
287290
OpAdaptor adaptor,
288291
ConversionPatternRewriter &rewriter) const override {
289292

290-
if (!stridesAllOne(insertStridedSliceOp)) {
291-
return rewriter.notifyMatchFailure(insertStridedSliceOp,
292-
"strides other than 1 not supported");
293-
}
293+
// See InsertStridedSliceOp's verify method.
294+
assert(stridesAllOne(insertStridedSliceOp) &&
295+
"has insert_strided_slice's verifier not checked strides are 1?");
294296

295297
VectorType inputType = insertStridedSliceOp.getValueToStore().getType();
296298
ArrayRef<int64_t> inputShape = inputType.getShape();
@@ -305,10 +307,10 @@ struct LinearizeVectorInsertStridedSlice final
305307
return rewriter.notifyMatchFailure(insertStridedSliceOp,
306308
"failed to get integer offsets");
307309
}
308-
SmallVector<int64_t> sliceIndices = getFlattenedStridedSliceIndices(
310+
SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices(
309311
inputShape, outputShape, offsets.value());
310312

311-
SmallVector<int64_t> indices(nOutputElements, 0);
313+
SmallVector<int64_t> indices(nOutputElements);
312314
std::iota(indices.begin(), indices.end(), 0);
313315
for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) {
314316
indices[sliceIndex] = index + nOutputElements;

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s
1+
// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s
22

33
// CHECK-LABEL: test_linearize
44
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
@@ -153,7 +153,7 @@ func.func @test_extract_strided_slice_2D_scalable(%arg0: vector<4x[8]xf32>) -> v
153153

154154
// CHECK-NOT: vector.shuffle
155155
// CHECK-NOT: vector.shape_cast
156-
// CHECK: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]]
156+
// CHECK: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]]
157157
%0 = vector.extract_strided_slice %arg0 { sizes = [2, 8], strides = [1, 1], offsets = [1, 0] } : vector<4x[8]xf32> to vector<2x[8]xf32>
158158

159159
// CHECK: return %[[RES]] : vector<2x[8]xf32>
@@ -179,7 +179,7 @@ func.func @test_extract_strided_slice_3D(%arg0 : vector<2x8x2xf32>) -> vector<1x
179179
// -----
180180

181181
// Test of insert_strided_slice -> shuffle.
182-
// This is a contiguous insertion of 4 elements at offset 6 into a vector of 12 elements.
182+
// This is a contiguous insertion of 4 elements at offset 6 into a vector of 12 elements.
183183
// CHECK-LABEL: insert_strided_slice_2D_into_4D
184184
func.func @insert_strided_slice_2D_into_4D(%arg0 : vector<2x2xi8>, %arg1 : vector<2x1x3x2xi8>) -> vector<2x1x3x2xi8> {
185185

@@ -196,19 +196,19 @@ func.func @insert_strided_slice_2D_into_4D(%arg0 : vector<2x2xi8>, %arg1 : vecto
196196

197197
// -----
198198

199-
// Test of insert_strided_slice -> shuffle.
199+
// Test of insert_strided_slice -> shuffle.
200200
// [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]], [[12, 13], [14, 15]], [[16, 17]]]
201201
// ^ ^
202202
// | |
203203
// where the 2 elements are inserted into the 3x3x2 vector
204204
// CHECK-LABEL: insert_strided_slice_3D
205-
func.func @insert_strided_slice_3D(%arg0 : vector<1x2x1xi8>, %arg2 : vector<3x3x2xi8>) -> vector<3x3x2xi8> {
205+
func.func @insert_strided_slice_3D(%arg0 : vector<1x2x1xi8>, %arg1 : vector<3x3x2xi8>) -> vector<3x3x2xi8> {
206206

207207
// CHECK-DAG: %[[ARG0:.*]] = vector.shape_cast {{.*}} to vector<2xi8>
208208
// CHECK-DAG: %[[ARG1:.*]] = vector.shape_cast {{.*}} to vector<18xi8>
209209
// CHECK: vector.shuffle %[[ARG1]], %[[ARG0]]
210210
// CHECK-SAME: [0, 1, 2, 3, 4, 5, 6, 7, 8, 18, 10, 19, 12, 13, 14, 15, 16, 17] : vector<18xi8>, vector<2xi8>
211-
%0 = vector.insert_strided_slice %arg0, %arg2 {offsets = [1, 1, 1], sizes = [1, 2, 1], strides = [1, 1, 1]} : vector<1x2x1xi8> into vector<3x3x2xi8>
211+
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1, 1, 1], sizes = [1, 2, 1], strides = [1, 1, 1]} : vector<1x2x1xi8> into vector<3x3x2xi8>
212212

213213
// CHECK: %[[RES:.*]] = vector.shape_cast {{.*}} to vector<3x3x2xi8>
214214
// CHECK: return %[[RES]] : vector<3x3x2xi8>
@@ -217,6 +217,37 @@ func.func @insert_strided_slice_3D(%arg0 : vector<1x2x1xi8>, %arg2 : vector<3x3x
217217

218218
// -----
219219

220+
// CHECK-LABEL: insert_strided_slice_2D_higher_offsets
221+
func.func @insert_strided_slice_2D_higher_offsets(%arg0 : vector<2x1xi8>, %arg1 : vector<2x2xi8>, %arg2 : vector<5x2xi8>) -> vector<5x2xi8> {
222+
223+
// CHECK: [0, 1, 2, 3, 10, 11, 12, 13, 8, 9]
224+
// ^^^ ^^^ ^^^ ^^^
225+
// insertion indices
226+
%0 = vector.insert_strided_slice %arg1, %arg2 {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x2xi8> into vector<5x2xi8>
227+
228+
// CHECK: [0, 1, 2, 3, 10, 5, 11, 7, 8, 9]
229+
// ^^^ ^^^
230+
%1 = vector.insert_strided_slice %arg0, %0 {offsets = [2, 0], sizes = [2, 1], strides = [1, 1]} : vector<2x1xi8> into vector<5x2xi8>
231+
232+
// CHECK: [0, 1, 2, 3, 4, 5, 6, 10, 8, 11]
233+
// ^^^ ^^^
234+
%2 = vector.insert_strided_slice %arg0, %1 {offsets = [3, 1], sizes = [2, 1], strides = [1, 1]} : vector<2x1xi8> into vector<5x2xi8>
235+
236+
return %2 : vector<5x2xi8>
237+
}
238+
239+
// -----
240+
241+
// CHECK-LABEL: negative_insert_strided_slice_scalable
242+
// CHECK-NOT: vector.shuffle
243+
// CHECK: return
244+
func.func @negative_insert_strided_slice_scalable(%arg0 : vector<1x[2]xi8>, %arg1 : vector<2x[2]xi8>) -> vector<2x[2]xi8> {
245+
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0], strides = [1,1]} : vector<1x[2]xi8> into vector<2x[2]xi8>
246+
return %0 : vector<2x[2]xi8>
247+
}
248+
249+
// -----
250+
220251
// CHECK-LABEL: test_vector_shuffle
221252
// CHECK-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> {
222253
func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> {

0 commit comments

Comments
 (0)