@@ -96,3 +96,47 @@ func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) {
9696 %0 = vector.transpose %arg0 , [1 , 0 ] : vector <2 x3 xi32 > to vector <3 x2 xi32 >
9797 return %0 : vector <3 x2 xi32 >
9898}
99+
100+ // -----
101+
102+ // In order to verify that the pattern is applied,
103+ // we need to make sure that the the 2d vector does not
104+ // come from the parameters. Otherwise, the pattern
105+ // in unrollVectorsInSignatures which splits the 2d vector
106+ // parameter will take precedent. Similarly, let's avoid
107+ // returning a vector as another pattern would take precendence.
108+
109+ // CHECK-LABEL: @unroll_to_elements_2d
110+ func.func @unroll_to_elements_2d () -> (f32 , f32 , f32 , f32 ) {
111+ %1 = " test.op" () : () -> (vector <2 x2 xf32 >)
112+ // CHECK: %[[VEC2D:.+]] = "test.op"
113+ // CHECK: %[[VEC0:.+]] = vector.extract %[[VEC2D]][0] : vector<2xf32> from vector<2x2xf32>
114+ // CHECK: %[[VEC1:.+]] = vector.extract %[[VEC2D]][1] : vector<2xf32> from vector<2x2xf32>
115+ // CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]]
116+ // CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]]
117+ %2:4 = vector.to_elements %1 : vector <2 x2 xf32 >
118+ return %2#0 , %2#1 , %2#2 , %2#3 : f32 , f32 , f32 , f32
119+ }
120+
121+ // -----
122+
123+ // In order to verify that the pattern is applied,
124+ // we need to make sure that the the 2d vector is used
125+ // by an operation and that extracts are not folded away.
126+ // In other words we can't use "test.op" nor return the
127+ // value `%0 = vector.from_elements`
128+
129+ // CHECK-LABEL: @unroll_from_elements_2d
130+ // CHECK-SAME: (%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32, %[[ARG3:.+]]: f32)
131+ func.func @unroll_from_elements_2d (%arg0: f32 , %arg1: f32 , %arg2: f32 , %arg3: f32 ) -> (vector <2 x2 xf32 >) {
132+ // CHECK: %[[VEC0:.+]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
133+ // CHECK: %[[VEC1:.+]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
134+ %0 = vector.from_elements %arg0 , %arg1 , %arg2 , %arg3 : vector <2 x2 xf32 >
135+
136+ // CHECK: %[[RES0:.+]] = arith.addf %[[VEC0]], %[[VEC0]]
137+ // CHECK: %[[RES1:.+]] = arith.addf %[[VEC1]], %[[VEC1]]
138+ %1 = arith.addf %0 , %0 : vector <2 x2 xf32 >
139+
140+ // return %[[RES0]], %%[[RES1]] : vector<2xf32>, vector<2xf32>
141+ return %1 : vector <2 x2 xf32 >
142+ }
0 commit comments