11// RUN: mlir-opt %s -test-vector-reduction-to-contract-patterns -split-input-file | FileCheck %s
22
3- // CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
4- // CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
3+ // TODO: Seperate tests for vector.multi_reduction -> vector.contract and
4+ // * pre-op + vector.contract -> vector.contract,
5+ // * vector.contract + post-op -> vector.contract.
6+
7+ // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
8+ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
59
610// CHECK-LABEL: multidimreduction_contract
711// CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xf32>, %[[ARG1:.*]]: vector<8x32x16xf32>, %[[ARG2:.*]]: vector<8x16xf32>)
8- // CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0 ]], #[[$map0 ]], #[[$map1 ]]],
12+ // CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0 ]], #[[$MAP0 ]], #[[$MAP1 ]]],
913// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
1014// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
1115// CHECK-NEXT: return %[[R]] : vector<8x16xf32>
@@ -17,12 +21,12 @@ func.func @multidimreduction_contract(
1721
1822// -----
1923
20- // CHECK-DAG: #[[$map0 :.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
21- // CHECK-DAG: #[[$map1 :.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
24+ // CHECK-DAG: #[[$MAP0 :.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
25+ // CHECK-DAG: #[[$MAP1 :.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
2226
2327// CHECK-LABEL: multidimreduction_contract_int
2428// CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xi32>, %[[ARG1:.*]]: vector<8x32x16xi32>, %[[ARG2:.*]]: vector<8x16xi32>)
25- // CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0 ]], #[[$map0 ]], #[[$map1 ]]],
29+ // CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0 ]], #[[$MAP0 ]], #[[$MAP1 ]]],
2630// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
2731// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xi32>, vector<8x32x16xi32> into vector<8x16xi32>
2832// CHECK-NEXT: return %[[R]] : vector<8x16xi32>
@@ -35,17 +39,21 @@ func.func @multidimreduction_contract_int(
3539
3640// -----
3741
42+ //-----------------------------------------------------------------------------
43+ // [Pattern: CombineContractABTranspose]
44+ //-----------------------------------------------------------------------------
45+
3846#map0 = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>
3947#map1 = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 )>
4048
41- // CHECK-DAG: #[[$map0 :.*]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
42- // CHECK-DAG: #[[$map1 :.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
43- // CHECK-DAG: #[[$map2 :.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
49+ // CHECK-DAG: #[[$MAP0 :.*]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
50+ // CHECK-DAG: #[[$MAP1 :.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
51+ // CHECK-DAG: #[[$MAP2 :.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
4452
4553// CHECK-LABEL: contract_transpose
4654// CHECK-SAME: (%[[ARG0:.+]]: vector<32x16x8xf32>,
4755// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
48- // CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0 ]], #[[$map1 ]], #[[$map2 ]]],
56+ // CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0 ]], #[[$MAP1 ]], #[[$MAP2 ]]],
4957// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
5058// CHECK-SAME: %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
5159// CHECK-NEXT: return %[[R]] : vector<8x32xf32>
@@ -68,14 +76,14 @@ func.func @contract_transpose(
6876#map0 = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>
6977#map1 = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 )>
7078
71- // CHECK-DAG: #[[$map0 :.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
72- // CHECK-DAG: #[[$map1 :.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
73- // CHECK-DAG: #[[$map2 :.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
79+ // CHECK-DAG: #[[$MAP0 :.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
80+ // CHECK-DAG: #[[$MAP1 :.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
81+ // CHECK-DAG: #[[$MAP2 :.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
7482
7583// CHECK-LABEL: contract_broadcast
7684// CHECK-SAME: (%[[ARG0:.+]]: vector<32x16xf32>,
7785// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
78- // CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0 ]], #[[$map1 ]], #[[$map2 ]]],
86+ // CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0 ]], #[[$MAP1 ]], #[[$MAP2 ]]],
7987// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
8088// CHECK-SAME: %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
8189// CHECK-NEXT: return %[[R]] : vector<8x32xf32>
@@ -127,6 +135,42 @@ func.func @contract_broadcast_masked(
127135
128136// -----
129137
138+ // Same as above, but with a scalable dim.
139+
140+ #map0 = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>
141+ #map1 = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 )>
142+
143+ // CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
144+ // CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
145+ // CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
146+
147+ // CHECK-LABEL: contract_broadcast_masked_scalable
148+ // CHECK-SAME: %[[ARG0:.*]]: vector<[32]x16xf32>,
149+ // CHECK-SAME: %[[ARG1:.*]]: vector<8x[32]x16xf32>,
150+ // CHECK-SAME: %[[MASK:.*]]: vector<8x[32]x16xi1>) -> vector<8x32xf32> {
151+ // CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
152+ // CHECK: %[[R:.*]] = vector.mask %[[MASK]] {
153+ // CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
154+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"],
155+ // CHECK-SAME: kind = #vector.kind<add>}
156+ // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[C0]] : vector<[32]x16xf32>, vector<8x[32]x16xf32> into vector<8x32xf32>
157+ // CHECK-SAME } : vector<8x[32]x16xi1> -> vector<8x32xf32>
158+ // CHECK: return %[[R]] : vector<8x32xf32>
159+ func.func @contract_broadcast_masked_scalable (
160+ %arg0: vector <[32 ]x16 xf32 >, %arg1: vector <8 x[32 ]x16 xf32 >, %mask: vector <8 x[32 ]x16 xi1 >) -> vector <8 x32 xf32 > {
161+ %cst = arith.constant dense <0.000000e+00 > : vector <8 x32 xf32 >
162+ %0 = vector.broadcast %arg0 : vector <[32 ]x16 xf32 > to vector <8 x[32 ]x16 xf32 >
163+ %1 = vector.mask %mask {
164+ vector.contract {index ing_maps = [#map0 , #map0 , #map1 ],
165+ iterator_types = [" parallel" , " parallel" , " reduction" ],
166+ kind = #vector.kind <add >
167+ } %0 , %arg1 , %cst : vector <8 x[32 ]x16 xf32 >, vector <8 x[32 ]x16 xf32 > into vector <8 x32 xf32 >
168+ } : vector <8 x[32 ]x16 xi1 > -> vector <8 x32 xf32 >
169+ return %1 : vector <8 x32 xf32 >
170+ }
171+
172+ // -----
173+
130174// Test that CombineContractBroadcast is able to combine a broadcast that
131175// creates a unit dim that is consumed by a reduction iterator, dropping that
132176// reduction iterator, as long as there is another reduction iterator left.
@@ -135,14 +179,14 @@ func.func @contract_broadcast_masked(
135179#map1 = affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d2 , d3 )>
136180#map2 = affine_map <(d0 , d1 , d2 , d3 ) -> (d1 , d2 )>
137181
138- // CHECK-DAG: #[[$map0 :.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
139- // CHECK-DAG: #[[$map1 :.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
140- // CHECK-DAG: #[[$map2 :.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
182+ // CHECK-DAG: #[[$MAP0 :.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
183+ // CHECK-DAG: #[[$MAP1 :.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
184+ // CHECK-DAG: #[[$MAP2 :.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
141185
142186// CHECK-LABEL: contract_broadcast_unit_dim_reduction
143187// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
144188// CHECK: vector.contract
145- // CHECK-SAME: indexing_maps = [#[[$map0 ]], #[[$map1 ]], #[[$map2 ]]]
189+ // CHECK-SAME: indexing_maps = [#[[$MAP0 ]], #[[$MAP1 ]], #[[$MAP2 ]]]
146190// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
147191// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
148192func.func @contract_broadcast_unit_dim_reduction (%arg0 : vector <8 x4 xi32 >, %arg1 : vector <8 x4 xi32 >, %arg2 : vector <8 x8 xi32 >) -> vector <8 x8 xi32 > {
@@ -158,22 +202,55 @@ func.func @contract_broadcast_unit_dim_reduction(%arg0 : vector<8x4xi32>, %arg1
158202
159203// -----
160204
161- // Same as above, but with a mask
205+ // Same as above, but with a mask.
162206
163207#map0 = affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d1 , d3 )>
164208#map1 = affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d2 , d3 )>
165209#map2 = affine_map <(d0 , d1 , d2 , d3 ) -> (d1 , d2 )>
166210
167- // CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
168- // CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
169- // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
211+ // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
212+ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
213+ // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
214+
215+ // CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked_scalable
216+ // CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<[8]x4xi32>, %[[ARG2:.+]]: vector<8x[8]xi32>, %[[MASK:.+]]: vector<1x8x[8]x4xi1>)
217+ // CHECK: %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x[8]x4xi1> to vector<8x[8]x4xi1>
218+ // CHECK: %[[R:.*]] = vector.mask %[[MASK_SC]] {
219+ // CHECK-SAME: vector.contract
220+ // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
221+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
222+ // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<[8]x4xi32> into vector<8x[8]xi32>
223+ func.func @contract_broadcast_unit_dim_reduction_masked_scalable (%arg0 : vector <8 x4 xi32 >, %arg1 : vector <[8 ]x4 xi32 >, %arg2 : vector <8 x[8 ]xi32 >, %mask: vector <1 x8 x[8 ]x4 xi1 >) -> vector <8 x[8 ]xi32 > {
224+ %0 = vector.broadcast %arg0 : vector <8 x4 xi32 > to vector <1 x8 x4 xi32 >
225+ %1 = vector.broadcast %arg1 : vector <[8 ]x4 xi32 > to vector <1 x[8 ]x4 xi32 >
226+ %result = vector.mask %mask {
227+ vector.contract {
228+ indexing_maps = [#map0 , #map1 , #map2 ],
229+ iterator_types = [" reduction" , " parallel" , " parallel" , " reduction" ],
230+ kind = #vector.kind <add >
231+ } %0 , %1 , %arg2 : vector <1 x8 x4 xi32 >, vector <1 x[8 ]x4 xi32 > into vector <8 x[8 ]xi32 >
232+ } : vector <1 x8 x[8 ]x4 xi1 > -> vector <8 x[8 ]xi32 >
233+ return %result : vector <8 x[8 ]xi32 >
234+ }
235+
236+ // -----
237+
238+ // Same as above, but with a scalable dim.
239+
240+ #map0 = affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d1 , d3 )>
241+ #map1 = affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d2 , d3 )>
242+ #map2 = affine_map <(d0 , d1 , d2 , d3 ) -> (d1 , d2 )>
243+
244+ // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
245+ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
246+ // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
170247
171248// CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked
172249// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>, %[[MASK:.+]]: vector<1x8x8x4xi1>)
173250// CHECK: %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x8x4xi1> to vector<8x8x4xi1>
174251// CHECK: %[[R:.*]] = vector.mask %[[MASK_SC]] {
175252// CHECK-SAME: vector.contract
176- // CHECK-SAME: indexing_maps = [#[[$map0 ]], #[[$map1 ]], #[[$map2 ]]]
253+ // CHECK-SAME: indexing_maps = [#[[$MAP0 ]], #[[$MAP1 ]], #[[$MAP2 ]]]
177254// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
178255// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
179256func.func @contract_broadcast_unit_dim_reduction_masked (%arg0 : vector <8 x4 xi32 >, %arg1 : vector <8 x4 xi32 >, %arg2 : vector <8 x8 xi32 >, %mask: vector <1 x8 x8 x4 xi1 >) -> vector <8 x8 xi32 > {
@@ -200,16 +277,16 @@ func.func @contract_broadcast_unit_dim_reduction_masked(%arg0 : vector<8x4xi32>,
200277#map1 = affine_map <(d0 , d1 , d2 , d3 ) -> (d1 , d2 , d3 )>
201278#map2 = affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d2 )>
202279
203- // CHECK-DAG: #[[$map0 :.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
204- // CHECK-DAG: #[[$map1 :.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
205- // CHECK-DAG: #[[$map2 :.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
280+ // CHECK-DAG: #[[$MAP0 :.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
281+ // CHECK-DAG: #[[$MAP1 :.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
282+ // CHECK-DAG: #[[$MAP2 :.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
206283
207284// CHECK-LABEL: contract_broadcast_non_unit_dim_reduction_with_permutation
208285// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
209286// CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8x4xi32> to vector<2x8x4xi32>
210287// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<8x4xi32> to vector<2x8x4xi32>
211288// CHECK: vector.contract
212- // CHECK-SAME: indexing_maps = [#[[$map0 ]], #[[$map1 ]], #[[$map2 ]]]
289+ // CHECK-SAME: indexing_maps = [#[[$MAP0 ]], #[[$MAP1 ]], #[[$MAP2 ]]]
213290// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel", "reduction"]
214291// CHECK-SAME: %[[BROADCAST0]], %[[BROADCAST1]], %[[ARG2]] : vector<2x8x4xi32>, vector<2x8x4xi32> into vector<8x8xi32>
215292func.func @contract_broadcast_non_unit_dim_reduction_with_permutation (%arg0 : vector <8 x4 xi32 >, %arg1 : vector <8 x4 xi32 >, %arg2 : vector <8 x8 xi32 >) -> vector <8 x8 xi32 > {
@@ -232,16 +309,16 @@ func.func @contract_broadcast_non_unit_dim_reduction_with_permutation(%arg0 : ve
232309#map1 = affine_map <(d0 , d1 , d2 ) -> (d0 , d2 )>
233310#map2 = affine_map <(d0 , d1 , d2 ) -> (d1 , d2 )>
234311
235- // CHECK-DAG: #[[$map0 :.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
236- // CHECK-DAG: #[[$map1 :.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
237- // CHECK-DAG: #[[$map2 :.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
312+ // CHECK-DAG: #[[$MAP0 :.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
313+ // CHECK-DAG: #[[$MAP1 :.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
314+ // CHECK-DAG: #[[$MAP2 :.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
238315
239316// CHECK-LABEL: contract_broadcast_unit_dim_reduction_as_only_reduction
240317// CHECK-SAME: (%[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
241318// CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8xi32> to vector<1x8xi32>
242319// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<8xi32> to vector<1x8xi32>
243320// CHECK: vector.contract
244- // CHECK-SAME: indexing_maps = [#[[$map0 ]], #[[$map1 ]], #[[$map2 ]]]
321+ // CHECK-SAME: indexing_maps = [#[[$MAP0 ]], #[[$MAP1 ]], #[[$MAP2 ]]]
245322// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel"]
246323// CHECK-SAME: %[[BROADCAST0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x8xi32>, vector<1x8xi32> into vector<8x8xi32>
247324func.func @contract_broadcast_unit_dim_reduction_as_only_reduction (%arg0 : vector <8 xi32 >, %arg1 : vector <8 xi32 >, %arg2 : vector <8 x8 xi32 >) -> vector <8 x8 xi32 > {
@@ -264,15 +341,15 @@ func.func @contract_broadcast_unit_dim_reduction_as_only_reduction(%arg0 : vecto
264341#map1 = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>
265342#map2 = affine_map <(d0 , d1 , d2 ) -> (d1 )>
266343
267- // CHECK-DAG: #[[$map0 :.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
268- // CHECK-DAG: #[[$map1 :.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
269- // CHECK-DAG: #[[$map2 :.*]] = affine_map<(d0, d1, d2) -> (d1)>
344+ // CHECK-DAG: #[[$MAP0 :.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
345+ // CHECK-DAG: #[[$MAP1 :.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
346+ // CHECK-DAG: #[[$MAP2 :.*]] = affine_map<(d0, d1, d2) -> (d1)>
270347
271348// CHECK-LABEL: contract_broadcast_dimension_would_go_unused_in_lhs_rhs
272349// CHECK-SAME: (%[[ARG0:.+]]: vector<1x2xi32>, %[[ARG1:.+]]: vector<2xi32>, %[[ARG2:.+]]: vector<1xi32>)
273350// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<2xi32> to vector<1x1x2xi32>
274351// CHECK: vector.contract
275- // CHECK-SAME: indexing_maps = [#[[$map0 ]], #[[$map1 ]], #[[$map2 ]]]
352+ // CHECK-SAME: indexing_maps = [#[[$MAP0 ]], #[[$MAP1 ]], #[[$MAP2 ]]]
276353// CHECK-SAME: iterator_types = ["reduction", "parallel", "reduction"]
277354// CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32>
278355
@@ -303,7 +380,7 @@ func.func @contract_broadcast_dimension_would_go_unused_in_lhs_rhs(%arg0 : vecto
303380// CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>, %[[ARG2:.+]]: vector<1xf32>)
304381// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<1xf32> to vector<1x1xf32>
305382// CHECK: vector.contract
306- // CHECK-SAME: indexing_maps = [#[[$map0 ]], #[[$map1 ]], #[[$map2 ]]]
383+ // CHECK-SAME: indexing_maps = [#[[$MAP0 ]], #[[$MAP1 ]], #[[$MAP2 ]]]
307384// CHECK-SAME: iterator_types = ["parallel", "reduction"]
308385// CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1xf32>, vector<1x1xf32> into vector<1xf32>
309386
@@ -320,6 +397,10 @@ func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1x
320397
321398// -----
322399
400+ //-----------------------------------------------------------------------------
401+ // [Pattern: CombineContractResultTranspose]
402+ //-----------------------------------------------------------------------------
403+
323404// CHECK-DAG: #[[$LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
324405// CHECK-DAG: #[[$RHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
325406// CHECK-DAG: #[[$ACC_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
0 commit comments