@@ -193,30 +193,31 @@ func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
193193 return %0 : vector<3x2x2xf32>
194194}
195195// CHECK-LABEL: func @vector_fma_3d
196+ // CHECK-SAME: (%[[SRC:.*]]: vector<3x2x2xf32>) -> vector<3x2x2xf32> {
196197// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x2xf32>
197- // CHECK: %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
198- // CHECK: %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
199- // CHECK: %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
200- // CHECK: %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
201- // CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
202- // CHECK: %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
203- // CHECK: %[[FMA0:.*]] = vector.fma %[[S0]], %[[S1]], %[[S2]] : vector<2x2xf32>
198+ // CHECK: %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
199+ // CHECK: %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
200+ // CHECK: %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
201+ // CHECK: %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
202+ // CHECK: %[[E_OUT_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
203+ // CHECK: %[[S_OUT_0:.*]] = vector.shape_cast %[[E_OUT_0]] : vector<1x2x2xf32> to vector<2x2xf32>
204+ // CHECK: %[[FMA0:.*]] = vector.fma %[[S_LHS_0]], %[[S_RHS_0]], %[[S_OUT_0]] : vector<2x2xf32>
204205// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[FMA0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
205- // CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
206- // CHECK: %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
207- // CHECK: %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
208- // CHECK: %[[S4:.*]] = vector.shape_cast %[[E4]] : vector<1x2x2xf32> to vector<2x2xf32>
209- // CHECK: %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
210- // CHECK: %[[S5:.*]] = vector.shape_cast %[[E5]] : vector<1x2x2xf32> to vector<2x2xf32>
211- // CHECK: %[[FMA1:.*]] = vector.fma %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>
206+ // CHECK: %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
207+ // CHECK: %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
208+ // CHECK: %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
209+ // CHECK: %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
210+ // CHECK: %[[E_OUT_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
211+ // CHECK: %[[S_OUT_1:.*]] = vector.shape_cast %[[E_OUT_1]] : vector<1x2x2xf32> to vector<2x2xf32>
212+ // CHECK: %[[FMA1:.*]] = vector.fma %[[S_LHS_1]], %[[S_RHS_1]], %[[S_OUT_1]] : vector<2x2xf32>
212213// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[FMA1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
213- // CHECK: %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
214- // CHECK: %[[S6:.*]] = vector.shape_cast %[[E6]] : vector<1x2x2xf32> to vector<2x2xf32>
215- // CHECK: %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
216- // CHECK: %[[S7:.*]] = vector.shape_cast %[[E7]] : vector<1x2x2xf32> to vector<2x2xf32>
217- // CHECK: %[[E8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
218- // CHECK: %[[S8:.*]] = vector.shape_cast %[[E8]] : vector<1x2x2xf32> to vector<2x2xf32>
219- // CHECK: %[[FMA2:.*]] = vector.fma %[[S6]], %[[S7]], %[[S8]] : vector<2x2xf32>
214+ // CHECK: %[[E_LHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
215+ // CHECK: %[[S_LHS_2:.*]] = vector.shape_cast %[[E_LHS_2]] : vector<1x2x2xf32> to vector<2x2xf32>
216+ // CHECK: %[[E_RHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
217+ // CHECK: %[[S_RHS_2:.*]] = vector.shape_cast %[[E_RHS_2]] : vector<1x2x2xf32> to vector<2x2xf32>
218+ // CHECK: %[[E_OUT_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
219+ // CHECK: %[[S_OUT_2:.*]] = vector.shape_cast %[[E_OUT_2]] : vector<1x2x2xf32> to vector<2x2xf32>
220+ // CHECK: %[[FMA2:.*]] = vector.fma %[[S_LHS_2]], %[[S_RHS_2]], %[[S_OUT_2]] : vector<2x2xf32>
220221// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[FMA2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
221222// CHECK: return %[[I2]] : vector<3x2x2xf32>
222223
0 commit comments