Skip to content

Commit 02bc4c9

Browse files
committed
[mlir][PassManager] Only reinitialize the pass manager if the context registry changes
This prevents needless reinitialization for clients that want to reuse a pass manager multiple times. A new `getRegisryHash` function is exposed by the context to give a rough indicator of when the context registry has changed. Differential Revision: https://reviews.llvm.org/D95493
1 parent c3df9d5 commit 02bc4c9

File tree

4 files changed

+25
-1
lines changed

4 files changed

+25
-1
lines changed

mlir/include/mlir/IR/MLIRContext.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,12 @@ class MLIRContext {
166166
Dialect *getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
167167
function_ref<std::unique_ptr<Dialect>()> ctor);
168168

169+
/// Returns a hash of the registry of the context that may be used to give
170+
/// a rough indicator of if the state of the context registry has changed. The
171+
/// context registry correlates to loaded dialects and their entities
172+
/// (attributes, operations, types, etc.).
173+
llvm::hash_code getRegistryHash();
174+
169175
private:
170176
const std::unique_ptr<MLIRContextImpl> impl;
171177

mlir/include/mlir/Pass/PassManager.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,9 @@ class PassManager : public OpPassManager {
375375
/// An optional factory to use when generating a crash reproducer if valid.
376376
ReproducerStreamFactory crashReproducerStreamFactory;
377377

378+
/// A hash key used to detect when reinitialization is necessary.
379+
llvm::hash_code initializationKey;
380+
378381
/// Flag that specifies if pass timing is enabled.
379382
bool passTiming : 1;
380383

mlir/lib/IR/MLIRContext.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,16 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
492492
return dialect.get();
493493
}
494494

495+
llvm::hash_code MLIRContext::getRegistryHash() {
496+
llvm::hash_code hash(0);
497+
// Factor in number of loaded dialects, attributes, operations, types.
498+
hash = llvm::hash_combine(hash, impl->loadedDialects.size());
499+
hash = llvm::hash_combine(hash, impl->registeredAttributes.size());
500+
hash = llvm::hash_combine(hash, impl->registeredOperations.size());
501+
hash = llvm::hash_combine(hash, impl->registeredTypes.size());
502+
return hash;
503+
}
504+
495505
bool MLIRContext::allowsUnregisteredDialects() {
496506
return impl->allowUnregisteredDialects;
497507
}

mlir/lib/Pass/Pass.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,7 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
846846
PassManager::PassManager(MLIRContext *ctx, Nesting nesting,
847847
StringRef operationName)
848848
: OpPassManager(Identifier::get(operationName, ctx), nesting), context(ctx),
849+
initializationKey(DenseMapInfo<llvm::hash_code>::getTombstoneKey()),
849850
passTiming(false), localReproducer(false), verifyPasses(true) {}
850851

851852
PassManager::~PassManager() {}
@@ -868,7 +869,11 @@ LogicalResult PassManager::run(Operation *op) {
868869
dependentDialects.loadAll(context);
869870

870871
// Initialize all of the passes within the pass manager with a new generation.
871-
initialize(context, impl->initializationGeneration + 1);
872+
llvm::hash_code newInitKey = context->getRegistryHash();
873+
if (newInitKey != initializationKey) {
874+
initialize(context, impl->initializationGeneration + 1);
875+
initializationKey = newInitKey;
876+
}
872877

873878
// Construct a top level analysis manager for the pipeline.
874879
ModuleAnalysisManager am(op, instrumentor.get());

0 commit comments

Comments
 (0)