@@ -188,7 +188,11 @@ class MLIRContextImpl {
188188
189189 // / This is a sorted container of registered operations for a deterministic
190190 // / and efficient `getRegisteredOperations` implementation.
191- SmallVector<RegisteredOperationName, 0 > sortedRegisteredOperations;
191+ SmallVector<std::pair<StringRef, RegisteredOperationName>, 0 >
192+ sortedRegisteredOperations;
193+
194+ // / This returns the number of registered operations for a given dialect.
195+ llvm::DenseMap<StringRef, size_t > getCountByDialectName;
192196
193197 // / This is a list of dialects that are created referring to this context.
194198 // / The MLIRContext owns the objects. These need to be declared after the
@@ -707,8 +711,31 @@ void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
707711}
708712
709713// / Return information about all registered operations.
710- ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations () {
711- return impl->sortedRegisteredOperations ;
714+ SmallVector<RegisteredOperationName, 0 > MLIRContext::getRegisteredOperations () {
715+ SmallVector<RegisteredOperationName, 0 > operations;
716+ std::transform (impl->sortedRegisteredOperations .begin (),
717+ impl->sortedRegisteredOperations .end (),
718+ std::back_inserter (operations),
719+ [](const auto &t) { return t.second ; });
720+
721+ return operations;
722+ }
723+
724+ // / Return information for registered operations by dialect.
725+ SmallVector<RegisteredOperationName, 0 >
726+ MLIRContext::getRegisteredOperationsByDialect (StringRef dialectName) {
727+ SmallVector<RegisteredOperationName, 0 > operations;
728+
729+ auto lowerBound = std::lower_bound (
730+ impl->sortedRegisteredOperations .begin (),
731+ impl->sortedRegisteredOperations .end (), std::make_pair (dialectName, " " ),
732+ [](auto &lhs, auto &rhs) { return lhs.first .compare (rhs.first ); });
733+ auto count = impl->getCountByDialectName [dialectName];
734+
735+ std::transform (lowerBound, lowerBound + count, std::back_inserter (operations),
736+ [](const auto &t) { return t.second ; });
737+
738+ return operations;
712739}
713740
714741bool MLIRContext::isOperationRegistered (StringRef name) {
@@ -976,14 +1003,20 @@ void RegisteredOperationName::insert(
9761003 " operation name registration must be successful" );
9771004
9781005 // Add emplaced operation name to the sorted operations container.
979- RegisteredOperationName &value = emplaced.first ->second ;
980- ctxImpl.sortedRegisteredOperations .insert (
981- llvm::upper_bound (ctxImpl.sortedRegisteredOperations , value,
982- [](auto &lhs, auto &rhs) {
983- return lhs.getIdentifier ().compare (
984- rhs.getIdentifier ());
985- }),
986- value);
1006+ StringRef dialectClass = impl->getDialect ()->getNamespace ();
1007+ ctxImpl.getCountByDialectName [dialectClass] += 1 ;
1008+
1009+ std::pair<StringRef, RegisteredOperationName> value = {
1010+ dialectClass, emplaced.first ->second };
1011+
1012+ auto upperBound = llvm::upper_bound (
1013+ ctxImpl.sortedRegisteredOperations , value, [](auto &lhs, auto &rhs) {
1014+ if (lhs.first == rhs.first )
1015+ return lhs.second .getIdentifier ().compare (rhs.second .getIdentifier ());
1016+ return lhs.first .compare (rhs.first );
1017+ });
1018+
1019+ ctxImpl.sortedRegisteredOperations .insert (upperBound, value);
9871020}
9881021
9891022// ===----------------------------------------------------------------------===//
0 commit comments