Skip to content

Commit 2f09841

Browse files
IanWood1Groverkss
authored andcommitted
[Util][NFC] OptimizeIntArithmetic: reduce calls to eraseState (iree-org#19130)
This pass is causing long compilation times for llama3 405b (even when cherry-picking llvm/llvm-project#115399). The majority of the time is spent in this one pass. The compilation times improve when calling `eraseState` only when ops are deleted. This is similar to the upstream listeners in `UnsignedWhenEquivalent.cpp` and `IntRangeOptimizations.cpp`. It appears this function loops over all `LatticeAnchors` on each invocation to find the one to delete, causing it to be slow. My (nonrigorous) experiment showed a decrease from 18 min to 3 min compile time. My main concern here would be this affecting correctness, as I don't know if this has unaccounted for side effects. Signed-off-by: Ian Wood <[email protected]>
1 parent e2f9bb2 commit 2f09841

File tree

1 file changed

+5
-38
lines changed

1 file changed

+5
-38
lines changed

compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp

Lines changed: 5 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#include "mlir/Pass/PassRegistry.h"
2424
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2525

26-
#define DEBUG_TYPE "iree-util-optimize-arithmetic"
26+
#define DEBUG_TYPE "iree-util-optimize-int-arithmetic"
2727
using llvm::dbgs;
2828

2929
using namespace mlir::dataflow;
@@ -289,43 +289,7 @@ class DataFlowListener : public RewriterBase::Listener {
289289
void notifyOperationErased(Operation *op) override {
290290
s.eraseState(s.getProgramPointAfter(op));
291291
for (Value res : op->getResults())
292-
flushValue(res);
293-
}
294-
void notifyOperationModified(Operation *op) override {
295-
for (Value res : op->getResults())
296-
flushValue(res);
297-
}
298-
void notifyOperationReplaced(Operation *op, Operation *replacement) override {
299-
for (Value res : op->getResults())
300-
flushValue(res);
301-
}
302-
303-
void notifyOperationReplaced(Operation *op, ValueRange replacement) override {
304-
for (Value res : op->getResults())
305-
flushValue(res);
306-
}
307-
308-
void flushValue(Value value) {
309-
SmallVector<Value> worklist;
310-
SmallVector<Value> process;
311-
worklist.push_back(value);
312-
313-
while (!worklist.empty()) {
314-
process.clear();
315-
process.swap(worklist);
316-
for (Value childValue : process) {
317-
auto *state = s.lookupState<IntegerValueRangeLattice>(childValue);
318-
if (!state) {
319-
continue;
320-
}
321-
s.eraseState(childValue);
322-
for (auto user : childValue.getUsers()) {
323-
for (Value result : user->getResults()) {
324-
worklist.push_back(result);
325-
}
326-
}
327-
}
328-
}
292+
s.eraseState(res);
329293
}
330294

331295
DataFlowSolver &s;
@@ -386,11 +350,14 @@ class OptimizeIntArithmeticPass
386350

387351
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
388352
for (int i = 0;; ++i) {
353+
LLVM_DEBUG(dbgs() << " * Starting iteration: " << i << "\n");
389354
if (failed(solver.initializeAndRun(op))) {
390355
emitError(op->getLoc()) << "failed to perform int range analysis";
391356
return signalPassFailure();
392357
}
393358

359+
LLVM_DEBUG(
360+
dbgs() << " * Finished Running Solver -- Applying Patterns\n");
394361
bool changed = false;
395362
if (failed(applyPatternsAndFoldGreedily(op, frozenPatterns, config,
396363
&changed))) {

0 commit comments

Comments
 (0)