Skip to content

Commit 7627fce

Browse files
committed
vector handling
1 parent b79e3f3 commit 7627fce

File tree

2 files changed

+64
-2
lines changed

2 files changed

+64
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,8 +1092,14 @@ SLPGraph::vectorize(IRRewriter &rewriter,
10921092
Value elem;
10931093

10941094
if (auto vecType = dyn_cast<VectorType>(originalResultType)) {
1095-
elem = rewriter.create<vector::ExtractStridedSliceOp>(
1096-
loc, newResult, offset, vecType.getNumElements(), 1);
1095+
assert(vecType.getRank() <= 1);
1096+
if (vecType.getRank() == 0) {
1097+
elem = rewriter.create<vector::ExtractOp>(loc, newResult, offset);
1098+
elem = rewriter.create<vector::SplatOp>(loc, vecType, elem);
1099+
} else {
1100+
elem = rewriter.create<vector::ExtractStridedSliceOp>(
1101+
loc, newResult, offset, vecType.getNumElements(), 1);
1102+
}
10971103
} else {
10981104
elem = rewriter.create<vector::ExtractOp>(loc, newResult, offset);
10991105
}

mlir/test/Dialect/Vector/slp-vectorize.mlir

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,62 @@ func.func @read_read_add_add_vec(%arg0: memref<8xi32>, %arg1: memref<8xi32>) ->
699699
}
700700

701701

702+
// Test: rank-1 single-element vectors (vector<1xi32>). The expected output
// below merges the two adjacent loads from each memref into one vector<2xi32>
// load, performs a single vector<2xi32> addi, and recovers each original
// vector<1xi32> result with vector.extract_strided_slice at offsets 0 and 1.
// CHECK-LABEL: func @read_read_add_add_vec1
703+
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
704+
func.func @read_read_add_add_vec1(%arg0: memref<8xi32>, %arg1: memref<8xi32>) ->
705+
(vector<1xi32>, vector<1xi32>){
706+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
707+
// CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
708+
// CHECK: %[[V1:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<2xi32>
709+
// CHECK: %[[V2:.*]] = arith.addi %[[V0]], %[[V1]] : vector<2xi32>
710+
// CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xi32> to vector<1xi32>
711+
// CHECK: %[[V4:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [1], sizes = [1], strides = [1]} : vector<2xi32> to vector<1xi32>
712+
// CHECK: return %[[V3]], %[[V4]] : vector<1xi32>, vector<1xi32>
713+
%c0 = arith.constant 0 : index
714+
%c1 = arith.constant 1 : index
715+
716+
// Input: two loads from %arg0 at consecutive indices 0 and 1.
%0 = vector.load %arg0[%c0] : memref<8xi32>, vector<1xi32>
717+
%2 = vector.load %arg0[%c1] : memref<8xi32>, vector<1xi32>
718+
719+
// Input: two loads from %arg1 at the same consecutive indices.
%4 = vector.load %arg1[%c0] : memref<8xi32>, vector<1xi32>
720+
%6 = vector.load %arg1[%c1] : memref<8xi32>, vector<1xi32>
721+
722+
// Input: element-wise adds pairing the index-0 loads and the index-1 loads.
%8 = arith.addi %0, %4 : vector<1xi32>
723+
%10 = arith.addi %2, %6 : vector<1xi32>
724+
725+
return %8, %10 : vector<1xi32>, vector<1xi32>
726+
}
727+
728+
729+
// Test: rank-0 vectors (vector<i32>). The expected output below merges the
// adjacent loads into vector<2xi32> loads and a single vector<2xi32> addi,
// then recovers each original vector<i32> result by extracting one scalar
// with vector.extract (positions 0 and 1) and rebuilding the 0-D vector with
// vector.splat, rather than using vector.extract_strided_slice.
// CHECK-LABEL: func @read_read_add_add_vec0d
730+
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
731+
func.func @read_read_add_add_vec0d(%arg0: memref<8xi32>, %arg1: memref<8xi32>) ->
732+
(vector<i32>, vector<i32>){
733+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
734+
// CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
735+
// CHECK: %[[V1:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<2xi32>
736+
// CHECK: %[[V2:.*]] = arith.addi %[[V0]], %[[V1]] : vector<2xi32>
737+
// CHECK: %[[V3:.*]] = vector.extract %[[V2]][0] : i32 from vector<2xi32>
738+
// CHECK: %[[V4:.*]] = vector.splat %[[V3]] : vector<i32>
739+
// CHECK: %[[V5:.*]] = vector.extract %[[V2]][1] : i32 from vector<2xi32>
740+
// CHECK: %[[V6:.*]] = vector.splat %[[V5]] : vector<i32>
741+
// CHECK: return %[[V4]], %[[V6]] : vector<i32>, vector<i32>
742+
%c0 = arith.constant 0 : index
743+
%c1 = arith.constant 1 : index
744+
745+
// Input: two 0-D loads from %arg0 at consecutive indices 0 and 1.
%0 = vector.load %arg0[%c0] : memref<8xi32>, vector<i32>
746+
%2 = vector.load %arg0[%c1] : memref<8xi32>, vector<i32>
747+
748+
// Input: two 0-D loads from %arg1 at the same consecutive indices.
%4 = vector.load %arg1[%c0] : memref<8xi32>, vector<i32>
749+
%6 = vector.load %arg1[%c1] : memref<8xi32>, vector<i32>
750+
751+
// Input: 0-D adds pairing the index-0 loads and the index-1 loads.
%8 = arith.addi %0, %4 : vector<i32>
752+
%10 = arith.addi %2, %6 : vector<i32>
753+
754+
return %8, %10 : vector<i32>, vector<i32>
755+
}
756+
757+
702758
func.func private @use(i32)
703759

704760
// CHECK-LABEL: func @read_read_add_write_interleaved_use

0 commit comments

Comments
 (0)