Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion mlir/include/mlir/Transforms/CSE.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,44 @@
#ifndef MLIR_TRANSFORMS_CSE_H_
#define MLIR_TRANSFORMS_CSE_H_

#include <functional>

namespace mlir {

class DominanceInfo;
class Operation;
class RewriterBase;

/// Configuration for CSE.
struct CSEConfig {
/// If set, matching ops act as a CSE'ing barrier: ops are not CSE'd across
/// matching ops.
///
/// Note: IsolatedFromAbove ops are always a CSE'ing barrier, regardless of
/// this filter.
///
/// Example:
/// %0 = arith.constant 0 : index
/// scf.for ... {
/// %1 = arith.constant 0 : index
/// ...
/// }
/// If "scf.for" is marked as a CSE'ing barrier, %0 and %1 are *not* CSE'd.
std::function<bool(Operation *)> barrierOpFilter = nullptr;

/// If set, matching ops are not eliminated (neither CSE'd nor DCE'd). All
/// non-matching ops are subject to elimination.
std::function<bool(Operation *)> eliminateOpFilter = nullptr;
};

/// Eliminate common subexpressions within the given operation. This transform
/// looks for and deduplicates equivalent operations.
///
/// `changed` indicates whether the IR was modified or not.
void eliminateCommonSubExpressions(RewriterBase &rewriter,
DominanceInfo &domInfo, Operation *op,
bool *changed = nullptr);
bool *changed = nullptr,
CSEConfig config = CSEConfig());

} // namespace mlir

Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class GreedyRewriteConfig;

#define GEN_PASS_DECL_CANONICALIZER
#define GEN_PASS_DECL_CONTROLFLOWSINK
#define GEN_PASS_DECL_CSEPASS
#define GEN_PASS_DECL_CSE
#define GEN_PASS_DECL_INLINER
#define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION
#define GEN_PASS_DECL_MEM2REG
Expand Down
17 changes: 15 additions & 2 deletions mlir/include/mlir/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,25 @@ def CSE : Pass<"cse"> {
let summary = "Eliminate common sub-expressions";
let description = [{
This pass implements a generalized algorithm for common sub-expression
elimination. This pass relies on information provided by the
`Memory SideEffect` interface to identify when it is safe to eliminate
elimination. The pass also eliminates dead operation (DCE). The pass
relies on information provided by the `MemoryEffectOpInterface`
interface and on `DominanceInfo` to identify when it is safe to eliminate
operations. See [Common subexpression elimination](https://en.wikipedia.org/wiki/Common_subexpression_elimination)
for more general details on this optimization.

The types of ops that are subject to elimination can be configured with
`eliminate-op-filter`. If set, only those ops are CSE'd or DCE'd.

Ops are never CSE'd across IsolatedFromAbove ops. Additional CSE'ing
barrier ops can be specified with `barrier-op-filter`.
}];
let constructor = "mlir::createCSEPass()";
let options = [
ListOption<"barrierOpFilter", "barrier-op-filter", "std::string",
"Names of ops that act as CSE'ing barriers">,
ListOption<"eliminateOpFilter", "eliminate-op-filter", "std::string",
"If non-empty, list of ops that are subject to elimination">,
];
let statistics = [
Statistic<"numCSE", "num-cse'd", "Number of operations CSE'd">,
Statistic<"numDCE", "num-dce'd", "Number of operations DCE'd">
Expand Down
47 changes: 39 additions & 8 deletions mlir/lib/Transforms/CSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/RecyclingAllocator.h"
#include <deque>

#include <deque>
#include <unordered_set>
namespace mlir {
#define GEN_PASS_DEF_CSE
#include "mlir/Transforms/Passes.h.inc"
Expand Down Expand Up @@ -60,8 +61,9 @@ namespace {
/// Simple common sub-expression elimination.
class CSEDriver {
public:
CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
: rewriter(rewriter), domInfo(domInfo) {}
CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo,
const CSEConfig &config)
: rewriter(rewriter), domInfo(domInfo), config(config) {}

/// Simplify all operations within the given op.
void simplify(Operation *op, bool *changed = nullptr);
Expand Down Expand Up @@ -125,6 +127,9 @@ class CSEDriver {
// Various statistics.
int64_t numCSE = 0;
int64_t numDCE = 0;

/// CSE configuration.
CSEConfig config;
};
} // namespace

Expand Down Expand Up @@ -226,6 +231,10 @@ bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
Operation *op,
bool hasSSADominance) {
// Don't simplify operations that are filtered out.
if (config.eliminateOpFilter && !config.eliminateOpFilter(op))
return failure();

// Don't simplify terminator operations.
if (op->hasTrait<OpTrait::IsTerminator>())
return failure();
Expand Down Expand Up @@ -288,8 +297,11 @@ void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
if (op.getNumRegions() != 0) {
// If this operation is isolated above, we can't process nested regions
// with the given 'knownValues' map. This would cause the insertion of
// implicit captures in explicit capture only regions.
if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
// implicit captures in explicit capture only regions. Additional barrier
// ops can be specified by the user.
bool isBarrier = op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
(config.barrierOpFilter && config.barrierOpFilter(&op));
if (isBarrier) {
ScopedMapTy nestedKnownValues;
for (auto &region : op.getRegions())
simplifyRegion(nestedKnownValues, region);
Expand Down Expand Up @@ -381,8 +393,8 @@ void CSEDriver::simplify(Operation *op, bool *changed) {

void mlir::eliminateCommonSubExpressions(RewriterBase &rewriter,
DominanceInfo &domInfo, Operation *op,
bool *changed) {
CSEDriver driver(rewriter, &domInfo);
bool *changed, CSEConfig config) {
CSEDriver driver(rewriter, &domInfo, config);
driver.simplify(op, changed);
}

Expand All @@ -394,9 +406,28 @@ struct CSE : public impl::CSEBase<CSE> {
} // namespace

void CSE::runOnOperation() {
// Set up CSE configuration from pass options.
CSEConfig config;
std::unordered_set<std::string> barrierOpNames;
for (std::string opName : barrierOpFilter)
barrierOpNames.insert(opName);
std::unordered_set<std::string> eliminateOpNames;
for (std::string opName : eliminateOpFilter)
eliminateOpNames.insert(opName);
if (!barrierOpNames.empty()) {
config.barrierOpFilter = [&](Operation *op) -> bool {
return barrierOpNames.count(op->getName().getStringRef().str());
};
}
if (!eliminateOpNames.empty()) {
config.eliminateOpFilter = [&](Operation *op) -> bool {
return eliminateOpNames.count(op->getName().getStringRef().str());
};
}

// Simplify the IR.
IRRewriter rewriter(&getContext());
CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>());
CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>(), config);
bool changed = false;
driver.simplify(getOperation(), &changed);

Expand Down
Loading
Loading