@@ -188,15 +188,38 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
188188// CHECK-LABEL: func @vector_fma
189189// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
190190
191- // TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern.
192- func.func @negative_vector_fma_3d (%a: vector <3 x2 x2 xf32 >) -> vector <3 x2 x2 xf32 >{
191+ func.func @vector_fma_3d (%a: vector <3 x2 x2 xf32 >) -> vector <3 x2 x2 xf32 >{
193192 %0 = vector.fma %a , %a , %a : vector <3 x2 x2 xf32 >
194193 return %0 : vector <3 x2 x2 xf32 >
195194}
196- // CHECK-LABEL: func @negative_vector_fma_3d
197- // CHECK-NOT: vector.extract_strided_slice
198- // CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
199- // CHECK: return
195+ // CHECK-LABEL: func @vector_fma_3d
196+ // CHECK-SAME: (%[[SRC:.*]]: vector<3x2x2xf32>) -> vector<3x2x2xf32> {
197+ // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x2xf32>
198+ // CHECK: %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
199+ // CHECK: %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
200+ // CHECK: %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
201+ // CHECK: %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
202+ // CHECK: %[[E_OUT_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
203+ // CHECK: %[[S_OUT_0:.*]] = vector.shape_cast %[[E_OUT_0]] : vector<1x2x2xf32> to vector<2x2xf32>
204+ // CHECK: %[[FMA0:.*]] = vector.fma %[[S_LHS_0]], %[[S_RHS_0]], %[[S_OUT_0]] : vector<2x2xf32>
205+ // CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[FMA0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
206+ // CHECK: %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
207+ // CHECK: %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
208+ // CHECK: %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
209+ // CHECK: %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
210+ // CHECK: %[[E_OUT_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
211+ // CHECK: %[[S_OUT_1:.*]] = vector.shape_cast %[[E_OUT_1]] : vector<1x2x2xf32> to vector<2x2xf32>
212+ // CHECK: %[[FMA1:.*]] = vector.fma %[[S_LHS_1]], %[[S_RHS_1]], %[[S_OUT_1]] : vector<2x2xf32>
213+ // CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[FMA1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
214+ // CHECK: %[[E_LHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
215+ // CHECK: %[[S_LHS_2:.*]] = vector.shape_cast %[[E_LHS_2]] : vector<1x2x2xf32> to vector<2x2xf32>
216+ // CHECK: %[[E_RHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
217+ // CHECK: %[[S_RHS_2:.*]] = vector.shape_cast %[[E_RHS_2]] : vector<1x2x2xf32> to vector<2x2xf32>
218+ // CHECK: %[[E_OUT_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
219+ // CHECK: %[[S_OUT_2:.*]] = vector.shape_cast %[[E_OUT_2]] : vector<1x2x2xf32> to vector<2x2xf32>
220+ // CHECK: %[[FMA2:.*]] = vector.fma %[[S_LHS_2]], %[[S_RHS_2]], %[[S_OUT_2]] : vector<2x2xf32>
221+ // CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[FMA2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
222+ // CHECK: return %[[I2]] : vector<3x2x2xf32>
200223
201224func.func @vector_multi_reduction (%v : vector <4 x6 xf32 >, %acc: vector <4 xf32 >) -> vector <4 xf32 > {
202225 %0 = vector.multi_reduction #vector.kind <add >, %v , %acc [1 ] : vector <4 x6 xf32 > to vector <4 xf32 >
@@ -440,3 +463,36 @@ func.func @vector_step() -> vector<32xindex> {
440463// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex>
441464// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
442465// CHECK: return %[[INS3]] : vector<32xindex>
466+
467+
468+ func.func @elementwise_3D_to_2D (%v1: vector <2 x2 x2 xf32 >, %v2: vector <2 x2 x2 xf32 >) -> vector <2 x2 x2 xf32 > {
469+ %0 = arith.addf %v1 , %v2 : vector <2 x2 x2 xf32 >
470+ return %0 : vector <2 x2 x2 xf32 >
471+ }
472+ // CHECK-LABEL: func @elementwise_3D_to_2D
473+ // CHECK-SAME: (%[[ARG0:.*]]: vector<2x2x2xf32>, %[[ARG1:.*]]: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
474+ // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x2xf32>
475+ // CHECK: %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
476+ // CHECK: %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
477+ // CHECK: %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
478+ // CHECK: %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
479+ // CHECK: %[[ADD0:.*]] = arith.addf %[[S_LHS_0]], %[[S_RHS_0]] : vector<2x2xf32>
480+ // CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[ADD0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
481+ // CHECK: %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
482+ // CHECK: %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
483+ // CHECK: %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
484+ // CHECK: %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
485+ // CHECK: %[[ADD1:.*]] = arith.addf %[[S_LHS_1]], %[[S_RHS_1]] : vector<2x2xf32>
486+ // CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
487+ // CHECK: return %[[I1]] : vector<2x2x2xf32>
488+
489+
490+ func.func @elementwise_4D_to_2D (%v1: vector <2 x2 x2 x2 xf32 >, %v2: vector <2 x2 x2 x2 xf32 >) -> vector <2 x2 x2 x2 xf32 > {
491+ %0 = arith.addf %v1 , %v2 : vector <2 x2 x2 x2 xf32 >
492+ return %0 : vector <2 x2 x2 x2 xf32 >
493+ }
494+
495+ // CHECK-LABEL: func @elementwise_4D_to_2D
496+ // CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
497+ // CHECK-NOT: arith.addf
498+ // CHECK: return
0 commit comments