Skip to content

Commit 3c94c22

Browse files
committed
[mlir] Method to iterate over registered operations for a given dialect class.
Part of #111591 Currently we have `MLIRContext::getRegisteredOperations` which returns all operations for the given context, with the addition of `MLIRContext::getRegisteredOperationsByDialect` we can now retrieve the same for a given dialect class.
1 parent b2cac3b commit 3c94c22

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

mlir/include/mlir/IR/MLIRContext.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,11 @@ class MLIRContext {
197197
/// operations.
198198
ArrayRef<RegisteredOperationName> getRegisteredOperations();
199199

200+
/// Return a sorted array containing the information for registered operations
201+
/// filtered by dialect name.
202+
ArrayRef<RegisteredOperationName>
203+
getRegisteredOperationsByDialect(StringRef dialectName);
204+
200205
/// Return true if this operation name is registered in this context.
201206
bool isOperationRegistered(StringRef name);
202207

mlir/lib/IR/MLIRContext.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,9 @@ class MLIRContextImpl {
190190
/// and efficient `getRegisteredOperations` implementation.
191191
SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations;
192192

193+
/// This returns the number of registered operations for the given dialect.
194+
size_t operationCount = 0;
195+
193196
/// This is a list of dialects that are created referring to this context.
194197
/// The MLIRContext owns the objects. These need to be declared after the
195198
/// registered operations to ensure correct destruction order.
@@ -711,6 +714,27 @@ ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
711714
return impl->sortedRegisteredOperations;
712715
}
713716

717+
/// Return information for registered operations by dialect.
718+
ArrayRef<RegisteredOperationName>
719+
MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
720+
auto lowerBound =
721+
std::lower_bound(impl->sortedRegisteredOperations.begin(),
722+
impl->sortedRegisteredOperations.end(), dialectName,
723+
[](auto &lhs, auto &rhs) {
724+
return lhs.getDialect().getNamespace().compare(
725+
rhs.getDialect().getNamespace());
726+
});
727+
728+
auto upperBound =
729+
std::lower_bound(lowerBound, impl->sortedRegisteredOperations.end(),
730+
dialectName, [](auto &lhs, auto &rhs) {
731+
return lhs.getDialect().getNamespace().compare(
732+
rhs.getDialect().getNamespace());
733+
});
734+
size_t count = std::distance(lowerBound, upperBound);
735+
return ArrayRef(lowerBound, count);
736+
}
737+
714738
bool MLIRContext::isOperationRegistered(StringRef name) {
715739
return RegisteredOperationName::lookup(name, this).has_value();
716740
}
@@ -976,6 +1000,8 @@ void RegisteredOperationName::insert(
9761000
"operation name registration must be successful");
9771001

9781002
// Add emplaced operation name to the sorted operations container.
1003+
ctxImpl.operationCount += 1;
1004+
9791005
RegisteredOperationName &value = emplaced.first->second;
9801006
ctxImpl.sortedRegisteredOperations.insert(
9811007
llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,

0 commit comments

Comments
 (0)