Skip to content

Commit f3e3aff

Browse files
committed
fixup! [mlir][vector] Update CombineContractBroadcastMask
* Add tests for scalable vectors * Capitalize all LIT variables used for maps * Fix punctuation
1 parent 568f955 commit f3e3aff

File tree

1 file changed

+117
-36
lines changed

1 file changed

+117
-36
lines changed

mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir

Lines changed: 117 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
// RUN: mlir-opt %s -test-vector-reduction-to-contract-patterns -split-input-file | FileCheck %s
22

3-
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
4-
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
3+
// TODO: Seperate tests for vector.multi_reduction -> vector.contract and
4+
// * pre-op + vector.contract -> vector.contract,
5+
// * vector.contract + post-op -> vector.contract.
6+
7+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
8+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
59

610
// CHECK-LABEL: multidimreduction_contract
711
// CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xf32>, %[[ARG1:.*]]: vector<8x32x16xf32>, %[[ARG2:.*]]: vector<8x16xf32>)
8-
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
12+
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]],
913
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
1014
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
1115
// CHECK-NEXT: return %[[R]] : vector<8x16xf32>
@@ -17,12 +21,12 @@ func.func @multidimreduction_contract(
1721

1822
// -----
1923

20-
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
21-
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
24+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
25+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
2226

2327
// CHECK-LABEL: multidimreduction_contract_int
2428
// CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xi32>, %[[ARG1:.*]]: vector<8x32x16xi32>, %[[ARG2:.*]]: vector<8x16xi32>)
25-
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
29+
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]],
2630
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
2731
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xi32>, vector<8x32x16xi32> into vector<8x16xi32>
2832
// CHECK-NEXT: return %[[R]] : vector<8x16xi32>
@@ -35,17 +39,21 @@ func.func @multidimreduction_contract_int(
3539

3640
// -----
3741

42+
//-----------------------------------------------------------------------------
43+
// [Pattern: CombineContractABTranspose]
44+
//-----------------------------------------------------------------------------
45+
3846
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
3947
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
4048

41-
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
42-
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
43-
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
49+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
50+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
51+
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
4452

4553
// CHECK-LABEL: contract_transpose
4654
// CHECK-SAME: (%[[ARG0:.+]]: vector<32x16x8xf32>,
4755
// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
48-
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
56+
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
4957
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
5058
// CHECK-SAME: %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
5159
// CHECK-NEXT: return %[[R]] : vector<8x32xf32>
@@ -68,14 +76,14 @@ func.func @contract_transpose(
6876
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
6977
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
7078

71-
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
72-
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
73-
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
79+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
80+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
81+
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
7482

7583
// CHECK-LABEL: contract_broadcast
7684
// CHECK-SAME: (%[[ARG0:.+]]: vector<32x16xf32>,
7785
// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
78-
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
86+
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
7987
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
8088
// CHECK-SAME: %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
8189
// CHECK-NEXT: return %[[R]] : vector<8x32xf32>
@@ -127,6 +135,42 @@ func.func @contract_broadcast_masked(
127135

128136
// -----
129137

138+
// Same as above, but with a scalable dim.
139+
140+
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
141+
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
142+
143+
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
144+
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
145+
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
146+
147+
// CHECK-LABEL: contract_broadcast_masked_scalable
148+
// CHECK-SAME: %[[ARG0:.*]]: vector<[32]x16xf32>,
149+
// CHECK-SAME: %[[ARG1:.*]]: vector<8x[32]x16xf32>,
150+
// CHECK-SAME: %[[MASK:.*]]: vector<8x[32]x16xi1>) -> vector<8x32xf32> {
151+
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
152+
// CHECK: %[[R:.*]] = vector.mask %[[MASK]] {
153+
// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
154+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"],
155+
// CHECK-SAME: kind = #vector.kind<add>}
156+
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[C0]] : vector<[32]x16xf32>, vector<8x[32]x16xf32> into vector<8x32xf32>
157+
// CHECK-SAME } : vector<8x[32]x16xi1> -> vector<8x32xf32>
158+
// CHECK: return %[[R]] : vector<8x32xf32>
159+
func.func @contract_broadcast_masked_scalable(
160+
%arg0: vector<[32]x16xf32>, %arg1: vector<8x[32]x16xf32>, %mask: vector<8x[32]x16xi1>) -> vector<8x32xf32> {
161+
%cst = arith.constant dense<0.000000e+00> : vector<8x32xf32>
162+
%0 = vector.broadcast %arg0 : vector<[32]x16xf32> to vector<8x[32]x16xf32>
163+
%1 = vector.mask %mask {
164+
vector.contract {indexing_maps = [#map0, #map0, #map1],
165+
iterator_types = ["parallel", "parallel", "reduction"],
166+
kind = #vector.kind<add>
167+
} %0, %arg1, %cst : vector<8x[32]x16xf32>, vector<8x[32]x16xf32> into vector<8x32xf32>
168+
} : vector<8x[32]x16xi1> -> vector<8x32xf32>
169+
return %1 : vector<8x32xf32>
170+
}
171+
172+
// -----
173+
130174
// Test that CombineContractBroadcast is able to combine a broadcast that
131175
// creates a unit dim that is consumed by a reduction iterator, dropping that
132176
// reduction iterator, as long as there is another reduction iterator left.
@@ -135,14 +179,14 @@ func.func @contract_broadcast_masked(
135179
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
136180
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
137181

138-
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
139-
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
140-
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
182+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
183+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
184+
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
141185

142186
// CHECK-LABEL: contract_broadcast_unit_dim_reduction
143187
// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
144188
// CHECK: vector.contract
145-
// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
189+
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
146190
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
147191
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
148192
func.func @contract_broadcast_unit_dim_reduction(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>) -> vector<8x8xi32> {
@@ -158,22 +202,55 @@ func.func @contract_broadcast_unit_dim_reduction(%arg0 : vector<8x4xi32>, %arg1
158202

159203
// -----
160204

161-
// Same as above, but with a mask
205+
// Same as above, but with a mask.
162206

163207
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
164208
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
165209
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
166210

167-
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
168-
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
169-
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
211+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
212+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
213+
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
214+
215+
// CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked_scalable
216+
// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<[8]x4xi32>, %[[ARG2:.+]]: vector<8x[8]xi32>, %[[MASK:.+]]: vector<1x8x[8]x4xi1>)
217+
// CHECK: %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x[8]x4xi1> to vector<8x[8]x4xi1>
218+
// CHECK: %[[R:.*]] = vector.mask %[[MASK_SC]] {
219+
// CHECK-SAME: vector.contract
220+
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
221+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
222+
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<[8]x4xi32> into vector<8x[8]xi32>
223+
func.func @contract_broadcast_unit_dim_reduction_masked_scalable(%arg0 : vector<8x4xi32>, %arg1 : vector<[8]x4xi32>, %arg2 : vector<8x[8]xi32>, %mask: vector<1x8x[8]x4xi1>) -> vector<8x[8]xi32> {
224+
%0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
225+
%1 = vector.broadcast %arg1 : vector<[8]x4xi32> to vector<1x[8]x4xi32>
226+
%result = vector.mask %mask {
227+
vector.contract {
228+
indexing_maps = [#map0, #map1, #map2],
229+
iterator_types = ["reduction", "parallel", "parallel", "reduction"],
230+
kind = #vector.kind<add>
231+
} %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x[8]x4xi32> into vector<8x[8]xi32>
232+
} : vector<1x8x[8]x4xi1> -> vector<8x[8]xi32>
233+
return %result : vector<8x[8]xi32>
234+
}
235+
236+
// -----
237+
238+
// Same as above, but with a scalable dim.
239+
240+
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
241+
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
242+
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
243+
244+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
245+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
246+
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
170247

171248
// CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked
172249
// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>, %[[MASK:.+]]: vector<1x8x8x4xi1>)
173250
// CHECK: %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x8x4xi1> to vector<8x8x4xi1>
174251
// CHECK: %[[R:.*]] = vector.mask %[[MASK_SC]] {
175252
// CHECK-SAME: vector.contract
176-
// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
253+
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
177254
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
178255
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
179256
func.func @contract_broadcast_unit_dim_reduction_masked(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>, %mask: vector<1x8x8x4xi1>) -> vector<8x8xi32> {
@@ -200,16 +277,16 @@ func.func @contract_broadcast_unit_dim_reduction_masked(%arg0 : vector<8x4xi32>,
200277
#map1 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
201278
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
202279

203-
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
204-
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
205-
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
280+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
281+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
282+
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
206283

207284
// CHECK-LABEL: contract_broadcast_non_unit_dim_reduction_with_permutation
208285
// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
209286
// CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8x4xi32> to vector<2x8x4xi32>
210287
// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<8x4xi32> to vector<2x8x4xi32>
211288
// CHECK: vector.contract
212-
// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
289+
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
213290
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel", "reduction"]
214291
// CHECK-SAME: %[[BROADCAST0]], %[[BROADCAST1]], %[[ARG2]] : vector<2x8x4xi32>, vector<2x8x4xi32> into vector<8x8xi32>
215292
func.func @contract_broadcast_non_unit_dim_reduction_with_permutation(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>) -> vector<8x8xi32> {
@@ -232,16 +309,16 @@ func.func @contract_broadcast_non_unit_dim_reduction_with_permutation(%arg0 : ve
232309
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
233310
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
234311

235-
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
236-
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
237-
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
312+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
313+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
314+
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
238315

239316
// CHECK-LABEL: contract_broadcast_unit_dim_reduction_as_only_reduction
240317
// CHECK-SAME: (%[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
241318
// CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8xi32> to vector<1x8xi32>
242319
// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<8xi32> to vector<1x8xi32>
243320
// CHECK: vector.contract
244-
// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
321+
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
245322
// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel"]
246323
// CHECK-SAME: %[[BROADCAST0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x8xi32>, vector<1x8xi32> into vector<8x8xi32>
247324
func.func @contract_broadcast_unit_dim_reduction_as_only_reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : vector<8x8xi32>) -> vector<8x8xi32> {
@@ -264,15 +341,15 @@ func.func @contract_broadcast_unit_dim_reduction_as_only_reduction(%arg0 : vecto
264341
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
265342
#map2 = affine_map<(d0, d1, d2) -> (d1)>
266343

267-
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
268-
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
269-
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1)>
344+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
345+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
346+
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d1)>
270347

271348
// CHECK-LABEL: contract_broadcast_dimension_would_go_unused_in_lhs_rhs
272349
// CHECK-SAME: (%[[ARG0:.+]]: vector<1x2xi32>, %[[ARG1:.+]]: vector<2xi32>, %[[ARG2:.+]]: vector<1xi32>)
273350
// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<2xi32> to vector<1x1x2xi32>
274351
// CHECK: vector.contract
275-
// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
352+
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
276353
// CHECK-SAME: iterator_types = ["reduction", "parallel", "reduction"]
277354
// CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32>
278355

@@ -303,7 +380,7 @@ func.func @contract_broadcast_dimension_would_go_unused_in_lhs_rhs(%arg0 : vecto
303380
// CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>, %[[ARG2:.+]]: vector<1xf32>)
304381
// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<1xf32> to vector<1x1xf32>
305382
// CHECK: vector.contract
306-
// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
383+
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
307384
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
308385
// CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1xf32>, vector<1x1xf32> into vector<1xf32>
309386

@@ -320,6 +397,10 @@ func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1x
320397

321398
// -----
322399

400+
//-----------------------------------------------------------------------------
401+
// [Pattern: CombineContractResultTranspose]
402+
//-----------------------------------------------------------------------------
403+
323404
// CHECK-DAG: #[[$LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
324405
// CHECK-DAG: #[[$RHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
325406
// CHECK-DAG: #[[$ACC_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>

0 commit comments

Comments
 (0)