Skip to content

Commit 08655a9

Browse files
committed
[mlir] Add vector.{to_elements,from_elements} unrolling to VectorToSPIRV
1 parent 488ce6b commit 08655a9

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,6 +1495,8 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
14951495
RewritePatternSet patterns(context);
14961496
auto options = vector::UnrollVectorOptions().setNativeShapeFn(
14971497
[](auto op) { return mlir::spirv::getNativeVectorShape(op); });
1498+
vector::populateVectorFromElementsLoweringPatterns(patterns);
1499+
vector::populateVectorToElementsLoweringPatterns(patterns);
14981500
populateVectorUnrollPatterns(patterns, options);
14991501
if (failed(applyPatternsGreedily(op, std::move(patterns))))
15001502
return failure();

mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,47 @@ func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) {
9696
%0 = vector.transpose %arg0, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
9797
return %0 : vector<3x2xi32>
9898
}
99+
100+
// -----
101+
102+
// In order to verify that the pattern is applied,
103+
// we need to make sure that the the 2d vector does not
104+
// come from the parameters. Otherwise, the pattern
105+
// in unrollVectorsInSignatures which splits the 2d vector
106+
// parameter will take precedent. Similarly, let's avoid
107+
// returning a vector as another pattern would take precendence.
108+
109+
// CHECK-LABEL: @unroll_to_elements_2d
110+
func.func @unroll_to_elements_2d() -> (f32, f32, f32, f32) {
111+
%1 = "test.op"() : () -> (vector<2x2xf32>)
112+
// CHECK: %[[VEC2D:.+]] = "test.op"
113+
// CHECK: %[[VEC0:.+]] = vector.extract %[[VEC2D]][0] : vector<2xf32> from vector<2x2xf32>
114+
// CHECK: %[[VEC1:.+]] = vector.extract %[[VEC2D]][1] : vector<2xf32> from vector<2x2xf32>
115+
// CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]]
116+
// CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]]
117+
%2:4 = vector.to_elements %1 : vector<2x2xf32>
118+
return %2#0, %2#1, %2#2, %2#3 : f32, f32, f32, f32
119+
}
120+
121+
// -----
122+
123+
// In order to verify that the pattern is applied,
124+
// we need to make sure that the the 2d vector is used
125+
// by an operation and that extracts are not folded away.
126+
// In other words we can't use "test.op" nor return the
127+
// value `%0 = vector.from_elements`
128+
129+
// CHECK-LABEL: @unroll_from_elements_2d
130+
// CHECK-SAME: (%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32, %[[ARG3:.+]]: f32)
131+
func.func @unroll_from_elements_2d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> (vector<2x2xf32>) {
132+
// CHECK: %[[VEC0:.+]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
133+
// CHECK: %[[VEC1:.+]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
134+
%0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
135+
136+
// CHECK: %[[RES0:.+]] = arith.addf %[[VEC0]], %[[VEC0]]
137+
// CHECK: %[[RES1:.+]] = arith.addf %[[VEC1]], %[[VEC1]]
138+
%1 = arith.addf %0, %0 : vector<2x2xf32>
139+
140+
// return %[[RES0]], %%[[RES1]] : vector<2xf32>, vector<2xf32>
141+
return %1 : vector<2x2xf32>
142+
}

0 commit comments

Comments
 (0)