@@ -190,6 +190,9 @@ class MLIRContextImpl {
190190 // / and efficient `getRegisteredOperations` implementation.
191191 SmallVector<RegisteredOperationName, 0 > sortedRegisteredOperations;
192192
193+ // / This returns the number of registered operations for a given dialect.
194+ DenseMap<StringRef, size_t > getCountByDialectName;
195+
193196 // / This is a list of dialects that are created referring to this context.
194197 // / The MLIRContext owns the objects. These need to be declared after the
195198 // / registered operations to ensure correct destruction order.
@@ -711,6 +714,21 @@ ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
711714 return impl->sortedRegisteredOperations ;
712715}
713716
717+ // / Return information for registered operations by dialect.
718+ ArrayRef<RegisteredOperationName>
719+ MLIRContext::getRegisteredOperationsByDialect (StringRef dialectName) {
720+ auto lowerBound =
721+ std::lower_bound (impl->sortedRegisteredOperations .begin (),
722+ impl->sortedRegisteredOperations .end (), dialectName,
723+ [](auto &lhs, auto &rhs) {
724+ return lhs.getDialect ().getNamespace ().compare (
725+ rhs.getDialect ().getNamespace ());
726+ });
727+ auto count = impl->getCountByDialectName [dialectName];
728+
729+ return ArrayRef (impl->sortedRegisteredOperations .data (), count);
730+ }
731+
714732bool MLIRContext::isOperationRegistered (StringRef name) {
715733 return RegisteredOperationName::lookup (name, this ).has_value ();
716734}
@@ -976,12 +994,19 @@ void RegisteredOperationName::insert(
976994 " operation name registration must be successful" );
977995
978996 // Add emplaced operation name to the sorted operations container.
997+ StringRef dialectClass = impl->getDialect ()->getNamespace ();
998+ ctxImpl.getCountByDialectName [dialectClass] += 1 ;
999+
9791000 RegisteredOperationName &value = emplaced.first ->second ;
9801001 ctxImpl.sortedRegisteredOperations .insert (
9811002 llvm::upper_bound (ctxImpl.sortedRegisteredOperations , value,
9821003 [](auto &lhs, auto &rhs) {
983- return lhs.getIdentifier ().compare (
984- rhs.getIdentifier ());
1004+ if (lhs.getDialect ().getNamespace () ==
1005+ rhs.getDialect ().getNamespace ())
1006+ return lhs.getIdentifier ().compare (
1007+ rhs.getIdentifier ());
1008+ return lhs.getDialect ().getNamespace ().compare (
1009+ rhs.getDialect ().getNamespace ());
9851010 }),
9861011 value);
9871012}
0 commit comments