Skip to content

Commit 77f4c91

Browse files
committed
[mlir][vector] Allow transposing multi_reduction when the parallel dim is in the middle
The check for the outer lowering wasn't quite right. Differential Revision: https://reviews.llvm.org/D142483
1 parent 76790cf commit 77f4c91

File tree

3 files changed

+22
-2
lines changed

3 files changed

+22
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,9 @@ class InnerOuterDimReductionConversion
7777
return failure();
7878

7979
if (!useInnerDimsForReduction &&
80-
(parallelDims !=
81-
llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
80+
(parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
81+
reductionDims.size(),
82+
parallelDims.size() + reductionDims.size()))))
8283
return failure();
8384

8485
SmallVector<int64_t, 4> indices;

mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,13 @@ func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>, %arg1
234234
// CHECK: %[[VAL_159:.*]] = vector.mask %[[VAL_158]] { vector.reduction <add>
235235
// CHECK: %[[VAL_160:.*]] = vector.insertelement %[[VAL_159]]
236236

237+
// -----
238+
239+
func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
240+
%0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
241+
return %0 : vector<4xf32>
242+
}
243+
244+
// CHECK-LABEL: func @vector_multi_reduction_parallel_middle
245+
// CHECK-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
246+
// CHECK: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>

mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,15 @@ func.func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi
162162
// CHECK: %[[RESULT_VEC:.+]] = vector.shape_cast %[[R18]] : vector<6xi32> to vector<2x3xi32>
163163
// CHECK: return %[[RESULT_VEC]] : vector<2x3xi32>
164164

165+
func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
166+
%0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
167+
return %0 : vector<4xf32>
168+
}
169+
170+
// CHECK-LABEL: func @vector_multi_reduction_parallel_middle
171+
// CHECK-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
172+
// CHECK: vector.transpose %[[INPUT]], [0, 2, 1] : vector<3x4x5xf32> to vector<3x5x4xf32>
173+
165174
// This test is mainly to catch a bug that running
166175
// `InnerOuterDimReductionConversion` on this function results in an
167176
// infinite loop. So just check that some value is returned.

0 commit comments

Comments
 (0)