diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h index d17bbac81655b..ef8dab87f131a 100644 --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -197,6 +197,11 @@ class MLIRContext { /// operations. ArrayRef getRegisteredOperations(); + /// Return a sorted array containing the information for registered operations + /// filtered by dialect name. + ArrayRef + getRegisteredOperationsByDialect(StringRef dialectName); + /// Return true if this operation name is registered in this context. bool isOperationRegistered(StringRef name); diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index f05666fcde207..d33340f4aefc8 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -711,6 +711,30 @@ ArrayRef MLIRContext::getRegisteredOperations() { return impl->sortedRegisteredOperations; } +/// Return information for registered operations by dialect. +ArrayRef +MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) { + auto lowerBound = + std::lower_bound(impl->sortedRegisteredOperations.begin(), + impl->sortedRegisteredOperations.end(), dialectName, + [](auto &lhs, auto &rhs) { + return lhs.getDialect().getNamespace().compare(rhs); + }); + + if (lowerBound == impl->sortedRegisteredOperations.end() || + lowerBound->getDialect().getNamespace() != dialectName) + return ArrayRef(); + + auto upperBound = + std::upper_bound(lowerBound, impl->sortedRegisteredOperations.end(), + dialectName, [](auto &lhs, auto &rhs) { + return lhs.compare(rhs.getDialect().getNamespace()); + }); + + size_t count = std::distance(lowerBound, upperBound); + return ArrayRef(&*lowerBound, count); +} + bool MLIRContext::isOperationRegistered(StringRef name) { return RegisteredOperationName::lookup(name, this).has_value(); }