Skip to content

Commit d123ad6

Browse files
committed
[mlir] Method to iterate over registered operations for a given dialect class.
Part of #111591 Currently we have `MLIRContext::getRegisteredOperations` which returns all operations for the given context, with the addition of `MLIRContext::getRegisteredOperationsByDialect` we can now retrieve the same for a given dialect class.
1 parent b2cac3b commit d123ad6

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

mlir/include/mlir/IR/MLIRContext.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,11 @@ class MLIRContext {
197197
/// operations.
198198
ArrayRef<RegisteredOperationName> getRegisteredOperations();
199199

200+
/// Return a sorted array containing the information for registered operations
201+
/// filtered by dialect name.
202+
ArrayRef<RegisteredOperationName>
203+
getRegisteredOperationsByDialect(StringRef dialectName);
204+
200205
/// Return true if this operation name is registered in this context.
201206
bool isOperationRegistered(StringRef name);
202207

mlir/lib/IR/MLIRContext.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,9 @@ class MLIRContextImpl {
190190
/// and efficient `getRegisteredOperations` implementation.
191191
SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations;
192192

193+
/// This returns the number of registered operations for the given dialect.
194+
size_t operationCount = 0;
195+
193196
/// This is a list of dialects that are created referring to this context.
194197
/// The MLIRContext owns the objects. These need to be declared after the
195198
/// registered operations to ensure correct destruction order.
@@ -711,6 +714,26 @@ ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
711714
return impl->sortedRegisteredOperations;
712715
}
713716

717+
/// Return information for registered operations by dialect.
718+
ArrayRef<RegisteredOperationName>
719+
MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
720+
auto lowerBound =
721+
std::lower_bound(impl->sortedRegisteredOperations.begin(),
722+
impl->sortedRegisteredOperations.end(), dialectName,
723+
[](auto &lhs, auto &rhs) {
724+
return lhs.getDialect().getNamespace().compare(
725+
rhs.getDialect().getNamespace());
726+
});
727+
728+
if (lowerBound == impl->sortedRegisteredOperations.end() ||
729+
lowerBound->getDialect().getNamespace() != dialectName)
730+
return ArrayRef<RegisteredOperationName>();
731+
732+
size_t count =
733+
lowerBound->getDialect().getContext()->getImpl().operationCount;
734+
return ArrayRef(lowerBound, count);
735+
}
736+
714737
bool MLIRContext::isOperationRegistered(StringRef name) {
715738
return RegisteredOperationName::lookup(name, this).has_value();
716739
}
@@ -976,6 +999,8 @@ void RegisteredOperationName::insert(
976999
"operation name registration must be successful");
9771000

9781001
// Add emplaced operation name to the sorted operations container.
1002+
ctxImpl.operationCount += 1;
1003+
9791004
RegisteredOperationName &value = emplaced.first->second;
9801005
ctxImpl.sortedRegisteredOperations.insert(
9811006
llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,

0 commit comments

Comments
 (0)