Skip to content

Commit dddf6ab

Browse files
committed
Simplifying the SplitReduction logic that uses the control to get the
dimension where the extra parallel dimension is inserted. Currently, the innerParallel and non-innerParallel strategies use two different ways to determine where the extra loop is inserted and where the extra dimension for the intermediate result is inserted: innerParallel adds the extra (parallel) loop right after the pre-existing reduction loop, whereas non-innerParallel adds the reduction loop at the successor of the index supplied by the control, and the parallel loop at the index supplied by the control. The semantics of the index supplied by the control is supposed to control only where the extra tensor dimension is inserted in the intermediate tensor. Conflating this index with where the reduction (and parallel) loops are inserted leads to more complex (and confusing) logic overall. This differential removes the conflation of the two uses of the index: it keeps the reduction and parallel loops in the same vicinity, and uses the supplied index only to determine the position of the extra tensor dimension. It also simplifies the code by merging the two strategies in many more places. Differential Revision: https://reviews.llvm.org/D137478
1 parent 7665369 commit dddf6ab

File tree

2 files changed

+24
-38
lines changed

2 files changed

+24
-38
lines changed

mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
3434

3535
SplitReductionOptions control = controlSplitReductionFn(op);
3636
int64_t ratio = control.ratio;
37-
unsigned insertSplitDimension = control.index;
37+
unsigned insertSplitIndex = control.index;
3838
if (ratio <= 1)
3939
return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
4040

@@ -45,10 +45,14 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
4545
SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
4646
int64_t reductionDimSize = loopRanges[reductionDim];
4747
if (reductionDimSize == ShapedType::kDynamicSize ||
48-
reductionDimSize % ratio != 0 ||
49-
insertSplitDimension >= loopRanges.size())
48+
reductionDimSize % ratio != 0)
5049
return b.notifyMatchFailure(
5150
op, "Reduction dimension not divisible by split ratio");
51+
if (op.getNumDpsInits() != 1)
52+
return b.notifyMatchFailure(op, "More than one output in split reduction");
53+
if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size())
54+
return b.notifyMatchFailure(op, "Insert dimension position too large "
55+
"compared to intermediate tensor size");
5256

5357
SmallVector<Operation *, 4> combinerOps;
5458
if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
@@ -80,25 +84,13 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
8084
newShape.push_back(ratio);
8185
newShape.push_back(op.getShape(operand)[idx] / ratio);
8286
}
87+
exprs.push_back(b.getAffineDimExpr(reductionDim));
88+
exprs.push_back(b.getAffineDimExpr(reductionDim + 1));
8389
reassociation.push_back({index++, index++});
84-
if (control.innerParallel) {
85-
exprs.push_back(b.getAffineDimExpr(reductionDim));
86-
exprs.push_back(b.getAffineDimExpr(reductionDim + 1));
87-
} else {
88-
exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
89-
exprs.push_back(
90-
b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
91-
}
9290
continue;
9391
}
9492
newShape.push_back(op.getShape(operand)[idx]);
95-
if (control.innerParallel) {
96-
exprs.push_back(
97-
b.getAffineDimExpr(dim <= reductionDim ? dim : dim + 1));
98-
} else {
99-
exprs.push_back(
100-
b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
101-
}
93+
exprs.push_back(b.getAffineDimExpr(dim < reductionDim ? dim : dim + 1));
10294
reassociation.push_back({index++});
10395
}
10496
newMaps.push_back(
@@ -122,26 +114,20 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
122114
AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0));
123115
ArrayRef<int64_t> oldShape = op.getShape(op.getDpsInitOperand(0));
124116
SmallVector<AffineExpr> outputExpr;
125-
for (unsigned idx :
126-
llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
127-
if (idx == insertSplitDimension) {
117+
for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
118+
if (insertSplitIndex == idx) {
128119
newOutputShape.push_back(ratio);
129120
if (control.innerParallel) {
130121
outputExpr.push_back(b.getAffineDimExpr(reductionDim + 1));
131122
} else {
132-
outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
123+
outputExpr.push_back(b.getAffineDimExpr(reductionDim));
133124
}
134-
continue;
135125
}
136-
unsigned oldIdx = idx < insertSplitDimension ? idx : idx - 1;
137-
newOutputShape.push_back(oldShape[oldIdx]);
138-
unsigned dim = oldOutputMap.getDimPosition(oldIdx);
139-
if (control.innerParallel) {
140-
outputExpr.push_back(
141-
b.getAffineDimExpr(dim <= reductionDim ? dim : dim + 1));
142-
} else {
126+
if (idx < oldShape.size()) {
127+
newOutputShape.push_back(oldShape[idx]);
128+
unsigned dim = oldOutputMap.getDimPosition(idx);
143129
outputExpr.push_back(
144-
b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
130+
b.getAffineDimExpr(dim < reductionDim ? dim : dim + 1));
145131
}
146132
}
147133
Value emptyOrAllocTensor;
@@ -164,10 +150,10 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
164150
op.getContext()));
165151
SmallVector<utils::IteratorType> newIteratorTypes;
166152
for (auto &it : llvm::enumerate(op.getIteratorTypesArray())) {
167-
if (insertSplitDimension == it.index() && !control.innerParallel)
153+
if (reductionDim == it.index() && !control.innerParallel)
168154
newIteratorTypes.push_back(utils::IteratorType::parallel);
169155
newIteratorTypes.push_back(it.value());
170-
if (insertSplitDimension == it.index() && control.innerParallel)
156+
if (reductionDim == it.index() && control.innerParallel)
171157
newIteratorTypes.push_back(utils::IteratorType::parallel);
172158
}
173159
// Create the new op matching the original op with an extra parallel
@@ -185,7 +171,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
185171
SmallVector<utils::IteratorType> reductionIteratorTypes;
186172
SmallVector<AffineExpr> exprs;
187173
for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
188-
if (insertSplitDimension == i) {
174+
if (insertSplitIndex == i) {
189175
reductionIteratorTypes.push_back(utils::IteratorType::reduction);
190176
} else {
191177
exprs.push_back(b.getAffineDimExpr(i));

mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,9 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>
106106
return %0 : tensor<5x2xf32>
107107
}
108108

109-
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)>
110-
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
111-
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
109+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d0)>
110+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>
111+
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1)>
112112
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
113113
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
114114
// CHECK-LABEL: func @generic_split_3d
@@ -117,7 +117,7 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>
117117
// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
118118
// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
119119
// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
120-
// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
120+
// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
121121
// CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
122122
// CHECK: arith.addf
123123
// CHECK: arith.maxf

0 commit comments

Comments
 (0)