Skip to content

Commit 782fdfa

Browse files
author
Pavel Lipskiy
committed
[mlir][linalg] Add pattern to clean unused results after fusion
In some cases, elementwise fusion can produce ops with multiple results of which only one is used in the IR. This makes the IR less readable and prevents additional fusions from being triggered. This patch adds the `DropRedundantResultsFromGenericOps` pattern, which finds such unused results and converts their tied output operands into inputs.

Signed-off-by: Pavel Lipskiy <[email protected]>
1 parent 471bd17 commit 782fdfa

File tree

2 files changed

+104
-1
lines changed

2 files changed

+104
-1
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2200,6 +2200,63 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
22002200
}
22012201
};
22022202

2203+
/// Drops an unused result from an elementwise `linalg.generic` by
2204+
/// reclassifying its tied `outs` operand as an extra input operand.
2205+
struct DropRedundantResultsFromGenericOps
2206+
: public OpRewritePattern<linalg::GenericOp> {
2207+
using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
2208+
LogicalResult matchAndRewrite(linalg::GenericOp op,
2209+
PatternRewriter &rewriter) const override {
2210+
if (!linalg::isElementwise(op) || op.getNumResults() < 2U)
2211+
return failure();
2212+
2213+
// Given that the op has no reductions, there is no need to preserve an
2214+
// unused result: transform it into an input instead.
2215+
auto maybeUnusedRes = llvm::find_if(
2216+
op.getResults(), [](OpResult res) { return res.use_empty(); });
2217+
if (maybeUnusedRes == op.getResults().end())
2218+
return failure();
2219+
2220+
OpResult unusedRes = *maybeUnusedRes;
2221+
const unsigned resIdx = unusedRes.getResultNumber();
2222+
auto resTypes = llvm::to_vector(op.getResultTypes());
2223+
resTypes.erase(resTypes.begin() + resIdx);
2224+
SmallVector<Value> resValues = llvm::to_vector_of<Value>(op.getResults());
2225+
resValues.erase(resValues.begin() + resIdx);
2226+
const int64_t numInputs = op.getNumDpsInputs();
2227+
OpOperand *resOperand = op.getTiedOpOperand(unusedRes);
2228+
AffineMap map = op.getIndexingMapMatchingResult(unusedRes);
2229+
const unsigned operandIdx = resOperand->getOperandNumber();
2230+
2231+
// Remove the output operand and add it as an input operand with the same
2232+
// map.
2233+
SmallVector<Value> outs(op.getOutputs());
2234+
outs.erase(outs.begin() + resIdx);
2235+
SmallVector<Value> ins(op.getInputs());
2236+
ins.insert(ins.begin() + numInputs, resOperand->get());
2237+
SmallVector<AffineMap> maps = op.getIndexingMapsArray();
2238+
maps.erase(maps.begin() + operandIdx);
2239+
maps.insert(maps.begin() + numInputs, map);
2240+
rewriter.setInsertionPoint(op);
2241+
2242+
auto newGenericOp = rewriter.create<linalg::GenericOp>(
2243+
op.getLoc(), TypeRange(resTypes), ins, outs, maps,
2244+
op.getIteratorTypesArray());
2245+
2246+
op->setDiscardableAttrs(op->getDiscardableAttrDictionary());
2247+
op.getBody()->getTerminator()->eraseOperands(resIdx);
2248+
newGenericOp.getRegion().takeBody(op.getBodyRegion());
2249+
2250+
// Replace the remaining results of the old op with the results of the new
2251+
// op.
2252+
rewriter.replaceAllUsesWith(resValues, newGenericOp.getResults());
2253+
2254+
// Remove the old op.
2255+
rewriter.eraseOp(op);
2256+
return success();
2257+
}
2258+
};
2259+
22032260
/// Fold linalg.fill into linalg.generic
22042261
struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
22052262
using OpRewritePattern<GenericOp>::OpRewritePattern;
@@ -2262,6 +2319,7 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
22622319
RemoveOutsDependency>(context);
22632320
// Add the patterns that clean up dead operands and results.
22642321
populateEraseUnusedOperandsAndResultsPatterns(patterns);
2322+
patterns.add<DropRedundantResultsFromGenericOps>(context);
22652323
}
22662324

22672325
void mlir::linalg::populateCollapseDimensions(

mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1079,4 +1079,49 @@ module {
10791079
// CHECK-NOT: linalg.generic
10801080
// CHECK: tensor.expand_shape
10811081
// CHECK: linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
1082-
// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
1082+
// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
1083+
1084+
// -----
1085+
1086+
// Result #0 of the two-result generic is unused: the pattern drops it and
// reclassifies its tied out operand (%arg1) as an input — see the
// `ins([[ARG0]], [[ARG1]] ...)` check on the surviving single-result op.
// CHECK-LABEL: @drop_unused_results
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]: tensor<64xf32>, [[ARG1:%[a-zA-Z0-9]+]]: tensor<1x56x56x64xf32>
func.func @drop_unused_results(%arg0: tensor<64xf32>, %arg1: tensor<1x56x56x64xf32>) -> tensor<1x56x56x64xf32> {
  %cst = arith.constant 3.40282347E+38 : f32
  %cst_0 = arith.constant 0.000000e+00 : f32
  // CHECK: [[OUT:%[a-zA-Z0-9]+]] = tensor.empty() : tensor<1x56x56x64xf32>
  %0 = tensor.empty() : tensor<1x56x56x64xf32>
  // CHECK: [[RES:%[0-9]+]] = linalg.generic {{.*}} ins([[ARG0]], [[ARG1]] : tensor<64xf32>, tensor<1x56x56x64xf32>) outs([[OUT]] : tensor<1x56x56x64xf32>)
  %1:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<64xf32>) outs(%arg1, %0 : tensor<1x56x56x64xf32>, tensor<1x56x56x64xf32>) {
  ^bb0(%in: f32, %out: f32, %out_1: f32):
    %2 = arith.addf %in, %out : f32
    %3 = arith.minimumf %2, %cst : f32
    %4 = arith.maximumf %3, %cst_0 : f32
    linalg.yield %2, %4 : f32, f32
  } -> (tensor<1x56x56x64xf32>, tensor<1x56x56x64xf32>)
  // CHECK: -> tensor<1x56x56x64xf32>
  // CHECK: return [[RES]] : tensor<1x56x56x64xf32>
  return %1#1 : tensor<1x56x56x64xf32>
}
1105+
1106+
// -----
1107+
1108+
// Same computation, but here the *second* result is the unused one. The
// checks expect the surviving op to read only %arg0 and write into the
// tensor.empty buffer; presumably the out-turned-input operand is itself
// removed by the unused-operand cleanup patterns registered in the same
// pass — confirm against populateElementwiseOpsFusionPatterns.
// CHECK-LABEL: @swap_drop_unused_results
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]: tensor<64xf32>, [[ARG1:%[a-zA-Z0-9]+]]: tensor<1x56x56x64xf32>
func.func @swap_drop_unused_results(%arg0: tensor<64xf32>, %arg1: tensor<1x56x56x64xf32>) -> tensor<1x56x56x64xf32> {
  %cst = arith.constant 3.40282347E+38 : f32
  %cst_0 = arith.constant 0.000000e+00 : f32
  // CHECK: [[OUT:%[a-zA-Z0-9]+]] = tensor.empty() : tensor<1x56x56x64xf32>
  %0 = tensor.empty() : tensor<1x56x56x64xf32>
  // CHECK: [[RES:%[0-9]+]] = linalg.generic {{.*}} ins([[ARG0]] : tensor<64xf32>) outs([[OUT]] : tensor<1x56x56x64xf32>)
  %1:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<64xf32>) outs(%arg1, %0 : tensor<1x56x56x64xf32>, tensor<1x56x56x64xf32>) {
  ^bb0(%in: f32, %out_1: f32, %out: f32):
    %2 = arith.addf %in, %out : f32
    %3 = arith.minimumf %2, %cst : f32
    %4 = arith.maximumf %3, %cst_0 : f32
    linalg.yield %2, %4 : f32, f32
  } -> (tensor<1x56x56x64xf32>, tensor<1x56x56x64xf32>)
  // CHECK: -> tensor<1x56x56x64xf32>
  // CHECK: return [[RES]] : tensor<1x56x56x64xf32>
  return %1#0 : tensor<1x56x56x64xf32>
}
1127+

0 commit comments

Comments
 (0)