Skip to content

Commit e60d973

Browse files
committed
Feedback
1 parent 3e1c7e4 commit e60d973

File tree

2 files changed

+63
-65
lines changed

2 files changed

+63
-65
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ class VectorShuffleTreeBuilder {
146146
// Shuffle tree configuration.
147147
unsigned numLevels;
148148
SmallVector<unsigned> vectorSizePerLevel;
149-
/// Holds the range of positions each vector in the tree contributes to the
149+
/// Holds the range of positions each vector in the tree contributes to in the
150150
/// final output vector.
151151
SmallVector<SmallVector<Interval>> intervalsPerLevel;
152152

@@ -180,8 +180,8 @@ static void duplicateLastIfOdd(SmallVectorImpl<T> &values) {
180180
// ===---------------------------------------------------------------------===//
181181

182182
/// Compute the intervals for all the vectors in the shuffle tree. The interval
183-
/// interval of a vector is the range of positions that the vector contributes
184-
/// to the final output vector.
183+
/// of a vector is the range of positions that the vector contributes to in the
184+
/// final output vector.
185185
///
186186
/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
187187
///
@@ -245,8 +245,6 @@ void VectorShuffleTreeBuilder::computeShuffleTreeIntervals() {
245245
intervalsPerLevel.push_back(std::move(firstLevelIntervals));
246246

247247
// Compute intervals for the remaining levels.
248-
unsigned outputNumElements =
249-
cast<VectorType>(fromElemsOp.getResult().getType()).getNumElements();
250248
for (unsigned level = 1; level < numLevels; ++level) {
251249
bool isLastLevel = level == numLevels - 1;
252250
const auto &prevLevelIntervals = intervalsPerLevel[level - 1];
@@ -569,12 +567,11 @@ static SmallVector<int64_t> computePropagationShuffleMask(
569567
///
570568
/// The code generation consists of combining pairs of vectors at each level of
571569
/// the tree, using the pre-computed tree intervals and vector sizes. The
572-
/// algorithm generates two kinds of shuffle masks: permutation masks and
573-
/// permutation masks and propagation masks:
574-
/// * Permutation masks are computed for the first level of the tree and
575-
/// permute the input vector elements to their relative position in the
576-
/// final output.
577-
/// * Propagation masks are computed for subsequent levels and propagate the
570+
/// algorithm generates two kinds of shuffle masks:
571+
/// * Permutation masks: computed for the first level of the tree and permute
572+
/// the input vector elements to their relative position in the final
573+
/// output.
574+
/// * Propagation masks: computed for subsequent levels and propagate the
578575
/// elements to the next level without permutation.
579576
///
580577
/// For further details on the shuffle mask computation, please take a look at

mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir

Lines changed: 55 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
// where L# refers to the level of the tree the shuffle belongs to, and SH# refers to
55
// the shuffle index within that level.
66

7-
func.func @trivial_forwarding(%a: vector<8xf32>) -> vector<8xf32> {
7+
func.func @unsupported_trivial_forwarding(%a: vector<8xf32>) -> vector<8xf32> {
88
%0:8 = vector.to_elements %a : vector<8xf32>
99
%1 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7 : vector<8xf32>
1010
return %1 : vector<8xf32>
1111
}
1212

1313
// No shuffle tree needed for the trivial forwarding case.
1414

15-
// CHECK-LABEL: func @trivial_forwarding(
15+
// CHECK-LABEL: func @unsupported_trivial_forwarding(
1616
// CHECK-SAME: %[[A:.*]]: vector<8xf32>
1717
// CHECK: return %[[A]] : vector<8xf32>
1818

@@ -26,6 +26,10 @@ func.func @unsupported_multi_dim_vector_inputs(%a: vector<2x4xf32>, %b: vector<2
2626
return %2 : vector<4xf32>
2727
}
2828

29+
// CHECK-LABEL: func @unsupported_multi_dim_vector_inputs(
30+
// CHECK-COUNT-2: vector.to_elements
31+
// CHECK: vector.from_elements
32+
2933
// -----
3034

3135
func.func @unsupported_multi_dim_vector_output(%a: vector<8xf32>, %b: vector<8xf32>) -> vector<2x2xf32> {
@@ -36,40 +40,44 @@ func.func @unsupported_multi_dim_vector_output(%a: vector<8xf32>, %b: vector<8xf
3640
return %2 : vector<2x2xf32>
3741
}
3842

43+
// CHECK-LABEL: func @unsupported_multi_dim_vector_output(
44+
// CHECK-COUNT-2: vector.to_elements
45+
// CHECK: vector.from_elements
46+
3947
// -----
4048

41-
func.func @single_input_shuffle(%a: vector<8xf32>) -> vector<8xf32> {
49+
func.func @shuffle_tree_single_input_shuffle(%a: vector<8xf32>) -> vector<8xf32> {
4250
%0:8 = vector.to_elements %a : vector<8xf32>
4351
%1 = vector.from_elements %0#7, %0#0, %0#6, %0#1, %0#5, %0#2, %0#4, %0#3 : vector<8xf32>
4452
return %1 : vector<8xf32>
4553
}
4654

47-
// CHECK-LABEL: func @single_input_shuffle(
55+
// CHECK-LABEL: func @shuffle_tree_single_input_shuffle(
4856
// CHECK-SAME: %[[A:.*]]: vector<8xf32>
4957
// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[A]] [7, 0, 6, 1, 5, 2, 4, 3] : vector<8xf32>, vector<8xf32>
5058
// CHECK: return %[[L0SH0]]
5159

5260
// -----
5361

54-
func.func @from_elements_to_elements_single_shuffle(%a: vector<8xf32>,
55-
%b: vector<8xf32>) -> vector<8xf32> {
62+
func.func @shuffle_tree_single_shuffle(%a: vector<8xf32>,
63+
%b: vector<8xf32>) -> vector<8xf32> {
5664
%0:8 = vector.to_elements %a : vector<8xf32>
5765
%1:8 = vector.to_elements %b : vector<8xf32>
5866
%2 = vector.from_elements %0#7, %1#0, %0#6, %1#1, %0#5, %1#2, %0#4, %1#3 : vector<8xf32>
5967
return %2 : vector<8xf32>
6068
}
6169

62-
// CHECK-LABEL: func @from_elements_to_elements_single_shuffle(
70+
// CHECK-LABEL: func @shuffle_tree_single_shuffle(
6371
// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
6472
// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [7, 8, 6, 9, 5, 10, 4, 11] : vector<8xf32>
6573
// CHECK: return %[[L0SH0]]
6674

6775
// -----
6876

6977
func.func @shuffle_tree_concat_4x8_to_32(%a: vector<8xf32>,
70-
%b: vector<8xf32>,
71-
%c: vector<8xf32>,
72-
%d: vector<8xf32>) -> vector<32xf32> {
78+
%b: vector<8xf32>,
79+
%c: vector<8xf32>,
80+
%d: vector<8xf32>) -> vector<32xf32> {
7381
%0:8 = vector.to_elements %a : vector<8xf32>
7482
%1:8 = vector.to_elements %b : vector<8xf32>
7583
%2:8 = vector.to_elements %c : vector<8xf32>
@@ -109,23 +117,22 @@ func.func @shuffle_tree_concat_3x4_to_12(%a: vector<4xf32>,
109117

110118
// -----
111119

112-
func.func @shuffle_tree_concat_64x4_256(
113-
%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>, %d: vector<4xf32>,
114-
%e: vector<4xf32>, %f: vector<4xf32>, %g: vector<4xf32>, %h: vector<4xf32>,
115-
%i: vector<4xf32>, %j: vector<4xf32>, %k: vector<4xf32>, %l: vector<4xf32>,
116-
%m: vector<4xf32>, %n: vector<4xf32>, %o: vector<4xf32>, %p: vector<4xf32>,
117-
%q: vector<4xf32>, %r: vector<4xf32>, %s: vector<4xf32>, %t: vector<4xf32>,
118-
%u: vector<4xf32>, %v: vector<4xf32>, %w: vector<4xf32>, %x: vector<4xf32>,
119-
%y: vector<4xf32>, %z: vector<4xf32>, %aa: vector<4xf32>, %ab: vector<4xf32>,
120-
%ac: vector<4xf32>, %ad: vector<4xf32>, %ae: vector<4xf32>, %af: vector<4xf32>,
121-
%ag: vector<4xf32>, %ah: vector<4xf32>, %ai: vector<4xf32>, %aj: vector<4xf32>,
122-
%ak: vector<4xf32>, %al: vector<4xf32>, %am: vector<4xf32>, %an: vector<4xf32>,
123-
%ao: vector<4xf32>, %ap: vector<4xf32>, %aq: vector<4xf32>, %ar: vector<4xf32>,
124-
%as: vector<4xf32>, %at: vector<4xf32>, %au: vector<4xf32>, %av: vector<4xf32>,
125-
%aw: vector<4xf32>, %ax: vector<4xf32>, %ay: vector<4xf32>, %az: vector<4xf32>,
126-
%ba: vector<4xf32>, %bb: vector<4xf32>, %bc: vector<4xf32>, %bd: vector<4xf32>,
127-
%be: vector<4xf32>, %bf: vector<4xf32>, %bg: vector<4xf32>, %bh: vector<4xf32>,
128-
%bi: vector<4xf32>, %bj: vector<4xf32>, %bk: vector<4xf32>, %bl: vector<4xf32>) -> vector<256xf32> {
120+
func.func @shuffle_tree_concat_64x4_256(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>, %d: vector<4xf32>,
121+
%e: vector<4xf32>, %f: vector<4xf32>, %g: vector<4xf32>, %h: vector<4xf32>,
122+
%i: vector<4xf32>, %j: vector<4xf32>, %k: vector<4xf32>, %l: vector<4xf32>,
123+
%m: vector<4xf32>, %n: vector<4xf32>, %o: vector<4xf32>, %p: vector<4xf32>,
124+
%q: vector<4xf32>, %r: vector<4xf32>, %s: vector<4xf32>, %t: vector<4xf32>,
125+
%u: vector<4xf32>, %v: vector<4xf32>, %w: vector<4xf32>, %x: vector<4xf32>,
126+
%y: vector<4xf32>, %z: vector<4xf32>, %aa: vector<4xf32>, %ab: vector<4xf32>,
127+
%ac: vector<4xf32>, %ad: vector<4xf32>, %ae: vector<4xf32>, %af: vector<4xf32>,
128+
%ag: vector<4xf32>, %ah: vector<4xf32>, %ai: vector<4xf32>, %aj: vector<4xf32>,
129+
%ak: vector<4xf32>, %al: vector<4xf32>, %am: vector<4xf32>, %an: vector<4xf32>,
130+
%ao: vector<4xf32>, %ap: vector<4xf32>, %aq: vector<4xf32>, %ar: vector<4xf32>,
131+
%as: vector<4xf32>, %at: vector<4xf32>, %au: vector<4xf32>, %av: vector<4xf32>,
132+
%aw: vector<4xf32>, %ax: vector<4xf32>, %ay: vector<4xf32>, %az: vector<4xf32>,
133+
%ba: vector<4xf32>, %bb: vector<4xf32>, %bc: vector<4xf32>, %bd: vector<4xf32>,
134+
%be: vector<4xf32>, %bf: vector<4xf32>, %bg: vector<4xf32>, %bh: vector<4xf32>,
135+
%bi: vector<4xf32>, %bj: vector<4xf32>, %bk: vector<4xf32>, %bl: vector<4xf32>) -> vector<256xf32> {
129136
%0:4 = vector.to_elements %a : vector<4xf32>
130137
%1:4 = vector.to_elements %b : vector<4xf32>
131138
%2:4 = vector.to_elements %c : vector<4xf32>
@@ -276,9 +283,9 @@ func.func @shuffle_tree_concat_64x4_256(
276283
// -----
277284

278285
func.func @shuffle_tree_arbitrary_4x4_to_16(%a: vector<4xf32>,
279-
%b: vector<4xf32>,
280-
%c: vector<4xf32>,
281-
%d: vector<4xf32>) -> vector<16xf32> {
286+
%b: vector<4xf32>,
287+
%c: vector<4xf32>,
288+
%d: vector<4xf32>) -> vector<16xf32> {
282289
%0:4 = vector.to_elements %a : vector<4xf32>
283290
%1:4 = vector.to_elements %b : vector<4xf32>
284291
%2:4 = vector.to_elements %c : vector<4xf32>
@@ -299,8 +306,8 @@ func.func @shuffle_tree_arbitrary_4x4_to_16(%a: vector<4xf32>,
299306
// -----
300307

301308
func.func @shuffle_tree_arbitrary_3x4_to_12(%a: vector<4xf32>,
302-
%b: vector<4xf32>,
303-
%c: vector<4xf32>) -> vector<12xf32> {
309+
%b: vector<4xf32>,
310+
%c: vector<4xf32>) -> vector<12xf32> {
304311
%0:4 = vector.to_elements %a : vector<4xf32>
305312
%1:4 = vector.to_elements %b : vector<4xf32>
306313
%2:4 = vector.to_elements %c : vector<4xf32>
@@ -320,8 +327,8 @@ func.func @shuffle_tree_arbitrary_3x4_to_12(%a: vector<4xf32>,
320327
// -----
321328

322329
func.func @shuffle_tree_arbitrary_3x5_to_9(%a: vector<5xf32>,
323-
%b: vector<5xf32>,
324-
%c: vector<5xf32>) -> vector<9xf32> {
330+
%b: vector<5xf32>,
331+
%c: vector<5xf32>) -> vector<9xf32> {
325332
%0:5 = vector.to_elements %a : vector<5xf32>
326333
%1:5 = vector.to_elements %b : vector<5xf32>
327334
%2:5 = vector.to_elements %c : vector<5xf32>
@@ -341,9 +348,9 @@ func.func @shuffle_tree_arbitrary_3x5_to_9(%a: vector<5xf32>,
341348
// -----
342349

343350
func.func @shuffle_tree_broadcast_4x2_to_32(%a: vector<2xf32>,
344-
%b: vector<2xf32>,
345-
%c: vector<2xf32>,
346-
%d: vector<2xf32>) -> vector<32xf32> {
351+
%b: vector<2xf32>,
352+
%c: vector<2xf32>,
353+
%d: vector<2xf32>) -> vector<32xf32> {
347354
%0:2 = vector.to_elements %a : vector<2xf32>
348355
%1:2 = vector.to_elements %b : vector<2xf32>
349356
%2:2 = vector.to_elements %c : vector<2xf32>
@@ -364,12 +371,11 @@ func.func @shuffle_tree_broadcast_4x2_to_32(%a: vector<2xf32>,
364371

365372
// -----
366373

367-
func.func @shuffle_tree_arbitrary_mixed_sizes(
368-
%a : vector<2xf32>,
369-
%b : vector<1xf32>,
370-
%c : vector<3xf32>,
371-
%d : vector<1xf32>,
372-
%e : vector<5xf32>) -> vector<6xf32> {
374+
func.func @shuffle_tree_arbitrary_mixed_sizes(%a : vector<2xf32>,
375+
%b : vector<1xf32>,
376+
%c : vector<3xf32>,
377+
%d : vector<1xf32>,
378+
%e : vector<5xf32>) -> vector<6xf32> {
373379
%0:2 = vector.to_elements %a : vector<2xf32>
374380
%1 = vector.to_elements %b : vector<1xf32>
375381
%2:3 = vector.to_elements %c : vector<3xf32>
@@ -391,13 +397,12 @@ func.func @shuffle_tree_arbitrary_mixed_sizes(
391397

392398
// -----
393399

394-
func.func @shuffle_tree_odd_intermediate_vectors(
395-
%a : vector<2xf32>,
396-
%b : vector<2xf32>,
397-
%c : vector<2xf32>,
398-
%d : vector<2xf32>,
399-
%e : vector<2xf32>,
400-
%f : vector<2xf32>) -> vector<6xf32> {
400+
func.func @shuffle_tree_odd_intermediate_vectors(%a : vector<2xf32>,
401+
%b : vector<2xf32>,
402+
%c : vector<2xf32>,
403+
%d : vector<2xf32>,
404+
%e : vector<2xf32>,
405+
%f : vector<2xf32>) -> vector<6xf32> {
401406
%0:2 = vector.to_elements %a : vector<2xf32>
402407
%1:2 = vector.to_elements %b : vector<2xf32>
403408
%2:2 = vector.to_elements %c : vector<2xf32>
@@ -417,7 +422,3 @@ func.func @shuffle_tree_odd_intermediate_vectors(
417422
// CHECK: %[[L2SH0:.*]] = vector.shuffle %[[L0SH2]], %[[L0SH2]] [0, 1, -1, -1] : vector<2xf32>, vector<2xf32>
418423
// CHECK: %[[L3SH0:.*]] = vector.shuffle %[[L1SH0]], %[[L2SH0]] [0, 1, 2, 3, 4, 5] : vector<4xf32>, vector<4xf32>
419424
// CHECK: return %[[L3SH0]] : vector<6xf32>
420-
421-
422-
423-

0 commit comments

Comments
 (0)