Skip to content

Commit 2b82f23

Browse files
committed
fix offset
1 parent ef51465 commit 2b82f23

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -965,13 +965,13 @@ SLPGraph::vectorize(IRRewriter &rewriter,
965965
}
966966
};
967967

968-
auto handleVecSizeMismatch = [&](Value arg) -> Value {
968+
auto handleVecSizeMismatch = [&](Value arg, int64_t offset = 0) -> Value {
969969
auto srcType = cast<VectorType>(arg.getType());
970970
assert(srcType.getRank() == 1);
971971
if (srcType.getDimSize(0) == numElements)
972972
return arg;
973973

974-
return rewriter.create<vector::ExtractStridedSliceOp>(loc, arg, 0,
974+
return rewriter.create<vector::ExtractStridedSliceOp>(loc, arg, offset,
975975
numElements, 1);
976976
};
977977

@@ -1007,7 +1007,9 @@ SLPGraph::vectorize(IRRewriter &rewriter,
10071007
mapping.map(op->getResults(), newOp->getResults());
10081008
handleNonVectorOutputs(newOp->getResult(0));
10091009
} else if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
1010-
Value val = handleVecSizeMismatch(extract.getVector());
1010+
// We alredy verified index is valid during graph construction.
1011+
int64_t offset = *getExtractIndex(extract);
1012+
Value val = handleVecSizeMismatch(extract.getVector(), offset);
10111013
mapping.map(extract.getResult(), val);
10121014
} else {
10131015
op->emitError("unsupported operation");

mlir/test/Dialect/Vector/slp-vectorize.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -380,9 +380,9 @@ func.func @read_read_add_write_attrs_mismatch(%arg0: memref<8xi32>, %arg1: memre
380380
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
381381
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
382382
// CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
383-
// CHECK: %[[V1:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
383+
// CHECK: %[[V1:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
384384
// CHECK: %[[V2:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
385-
// CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
385+
// CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
386386
// CHECK: %[[V4:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
387387
// CHECK: %[[V5:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
388388
// CHECK: %[[V6:.*]] = arith.addi %[[V4]], %[[V5]] overflow<nsw> : vector<2xi32>
@@ -670,9 +670,9 @@ func.func @different_blocks(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
670670
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
671671
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
672672
// CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
673-
// CHECK: %[[V1:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
673+
// CHECK: %[[V1:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
674674
// CHECK: %[[V2:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
675-
// CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
675+
// CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
676676
// CHECK: cf.br ^bb1
677677
// CHECK: ^bb1:
678678
// CHECK: %[[V4:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>

0 commit comments

Comments
 (0)