Skip to content

Commit 70b35e9

Browse files
committed
[mlir] Method to iterate over registered operations for a given dialect class.
Part of #111591 Currently we have `MLIRContext::getRegisteredOperations` which returns all operations for the given context, with the addition of `MLIRContext::getRegisteredOperationsByDialect` we can now retrieve the same for a given dialect class.
1 parent b2cac3b commit 70b35e9

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

mlir/include/mlir/IR/MLIRContext.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,11 @@ class MLIRContext {
197197
/// operations.
198198
ArrayRef<RegisteredOperationName> getRegisteredOperations();
199199

200+
/// Return a sorted array containing the information for registered operations
201+
/// filtered by dialect name.
202+
ArrayRef<RegisteredOperationName>
203+
getRegisteredOperationsByDialect(StringRef dialectName);
204+
200205
/// Return true if this operation name is registered in this context.
201206
bool isOperationRegistered(StringRef name);
202207

mlir/lib/IR/MLIRContext.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,32 @@ ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
711711
return impl->sortedRegisteredOperations;
712712
}
713713

714+
/// Return information for registered operations by dialect.
715+
ArrayRef<RegisteredOperationName>
716+
MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
717+
auto lowerBound =
718+
std::lower_bound(impl->sortedRegisteredOperations.begin(),
719+
impl->sortedRegisteredOperations.end(), dialectName,
720+
[](auto &lhs, auto &rhs) {
721+
return lhs.getDialect().getNamespace().compare(
722+
rhs.getDialect().getNamespace());
723+
});
724+
725+
if (lowerBound == impl->sortedRegisteredOperations.end() ||
726+
lowerBound->getDialect().getNamespace() != dialectName)
727+
return ArrayRef<RegisteredOperationName>();
728+
729+
auto upperBound =
730+
std::upper_bound(lowerBound, impl->sortedRegisteredOperations.end(),
731+
dialectName, [](auto &lhs, auto &rhs) {
732+
return lhs.getDialect().getNamespace().compare(
733+
rhs.getDialect().getNamespace());
734+
});
735+
736+
size_t count = std::distance(lowerBound, upperBound) - 1;
737+
return ArrayRef(lowerBound, count);
738+
}
739+
714740
bool MLIRContext::isOperationRegistered(StringRef name) {
715741
return RegisteredOperationName::lookup(name, this).has_value();
716742
}

0 commit comments

Comments
 (0)