@@ -151,43 +151,56 @@ func.func @extract_contract3(%arg0: vector<3xf32>,
151151 iterator_types = [" parallel" , " parallel" , " reduction" ]
152152}
153153
154- // CHECK-LABEL: func @extract_contract4
155- // CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>,
156- // CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
157- // CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
158- // CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
159- // CHECK: %[[Bt:.*]] = vector.transpose %arg1, [1, 0] : vector<2x2xf32> to vector<2x2xf32>
160- // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
161- // CHECK: %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2xf32> from vector<2x2xf32>
162- // CHECK: %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32>
163- // CHECK: %[[T10:.*]] = vector.reduction <add>, %[[T9]] : vector<2xf32> into f32
164- // CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32>
154+ // CHECK-LABEL: func @contract_to_dot_matmat
155+ // CHECK-SAME: %[[LHS:.*0]]: vector<2x2xf32>,
156+ // CHECK-SAME: %[[RHS:.*1]]: vector<2x2xf32>,
157+ // CHECK-SAME: %[[OUT:.*2]]: vector<2x2xf32>
165158//
166- // CHECK: %[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2xf32> from vector<2x2xf32>
167- // CHECK: %[[T19:.*]] = arith.mulf %[[T0]], %[[T12]] : vector<2xf32>
168- // CHECK: %[[T20:.*]] = vector.reduction <add>, %[[T19]] : vector<2xf32> into f32
169- // CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32>
159+ // The `vector.contract` to dot lowering will 'unroll' a matrix-matrix
160+ // multiplication into individual dot products between rows of the LHS and columns
161+ // of the RHS. In the following test we expect 4 extract-dotproduct-insert sequences of
162+ // ops that correspond to the 4 dot products resulting from unrolling a matmul between
163+ // two matrices of size (2, 2).
170164//
171- // CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
172- // CHECK: %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2xf32> from vector<2x2xf32>
173- // CHECK: %[[T32:.*]] = arith.mulf %[[T23]], %[[T24]] : vector<2xf32>
174- // CHECK: %[[T33:.*]] = vector.reduction <add>, %[[T32]] : vector<2xf32> into f32
175- // CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32>
165+ // CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
176166//
177- // CHECK: %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2xf32> from vector<2x2xf32>
178- // CHECK: %[[T41:.*]] = arith.mulf %[[T23]], %[[T40]] : vector<2xf32>
179- // CHECK: %[[T42:.*]] = vector.reduction <add>, %[[T41]] : vector<2xf32> into f32
180- // CHECK: %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32>
167+ // First, the RHS will be transposed to make it easier to extract individual columns
168+ // using vector.extract.
181169//
182- // CHECK: %[[T52:.*]] = arith.addf %[[T43]], %[[C]] : vector<2x2xf32>
183- // CHECK: return %[[T52]] : vector<2x2xf32>
170+ // CHECK: %[[RHS_T:.*]] = vector.transpose %[[RHS]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
171+ //
172+ // Next, we expect 4 sequences of extracting rows of the LHS and of the transposed RHS
173+ // (reused once extracted), performing a dot product and inserting it into the result.
174+ //
175+ // CHECK: %[[LHS0:.*]] = vector.extract %[[LHS]][0] : vector<2xf32> from vector<2x2xf32>
176+ // CHECK: %[[RHS_T0:.*]] = vector.extract %[[RHS_T]][0] : vector<2xf32> from vector<2x2xf32>
177+ // CHECK: %[[PROD0:.*]] = arith.mulf %[[LHS0]], %[[RHS_T0]] : vector<2xf32>
178+ // CHECK: %[[SUM0:.*]] = vector.reduction <add>, %[[PROD0]] : vector<2xf32> into f32
179+ // CHECK: %[[RES0:.*]] = vector.insert %[[SUM0]], %[[INIT]] [0, 0] : f32 into vector<2x2xf32>
180+ //
181+ // CHECK: %[[RHS_T1:.*]] = vector.extract %[[RHS_T]][1] : vector<2xf32> from vector<2x2xf32>
182+ // CHECK: %[[PROD1:.*]] = arith.mulf %[[LHS0]], %[[RHS_T1]] : vector<2xf32>
183+ // CHECK: %[[SUM1:.*]] = vector.reduction <add>, %[[PROD1]] : vector<2xf32> into f32
184+ // CHECK: %[[RES1:.*]] = vector.insert %[[SUM1]], %[[RES0]] [0, 1] : f32 into vector<2x2xf32>
185+ //
186+ // CHECK: %[[LHS1:.*]] = vector.extract %[[LHS]][1] : vector<2xf32> from vector<2x2xf32>
187+ // CHECK: %[[PROD2:.*]] = arith.mulf %[[LHS1]], %[[RHS_T0]] : vector<2xf32>
188+ // CHECK: %[[SUM2:.*]] = vector.reduction <add>, %[[PROD2]] : vector<2xf32> into f32
189+ // CHECK: %[[RES2:.*]] = vector.insert %[[SUM2]], %[[RES1]] [1, 0] : f32 into vector<2x2xf32>
190+ //
191+ // CHECK: %[[PROD3:.*]] = arith.mulf %[[LHS1]], %[[RHS_T1]] : vector<2xf32>
192+ // CHECK: %[[SUM3:.*]] = vector.reduction <add>, %[[PROD3]] : vector<2xf32> into f32
193+ // CHECK: %[[RES3:.*]] = vector.insert %[[SUM3]], %[[RES2]] [1, 1] : f32 into vector<2x2xf32>
194+ //
195+ // CHECK: %[[RES:.*]] = arith.addf %[[RES3]], %[[OUT]] : vector<2x2xf32>
196+ // CHECK: return %[[RES]] : vector<2x2xf32>
184197
185- func.func @extract_contract4 ( %arg0 : vector <2 x2 xf32 >,
186- %arg1 : vector <2 x2 xf32 >,
187- %arg2 : vector <2 x2 xf32 >) -> vector <2 x2 xf32 > {
188- %0 = vector.contract #matmat_trait %arg0 , %arg1 , %arg2
198+ func.func @contract_to_dot_matmat ( %lhs : vector <2 x2 xf32 >,
199+ %rhs : vector <2 x2 xf32 >,
200+ %init : vector <2 x2 xf32 >) -> vector <2 x2 xf32 > {
201+ %res = vector.contract #matmat_trait %lhs , %rhs , %init
189202 : vector <2 x2 xf32 >, vector <2 x2 xf32 > into vector <2 x2 xf32 >
190- return %0 : vector <2 x2 xf32 >
203+ return %res : vector <2 x2 xf32 >
191204}
192205
193206
0 commit comments