@@ -188,7 +188,15 @@ class MLIRContextImpl {
188188
189189 // / This is a sorted container of registered operations for a deterministic
190190 // / and efficient `getRegisteredOperations` implementation.
191- SmallVector<RegisteredOperationName, 0 > sortedRegisteredOperations;
191+ SmallVector<std::pair<StringRef, RegisteredOperationName>, 0 >
192+ sortedRegisteredOperations;
193+
194+ // / This stores the transformed operations when calling
195+ // / `getRegisteredOperations`.
196+ SmallVector<RegisteredOperationName, 0 > transformedOperations;
197+
198+ // / This returns the number of registered operations for a given dialect.
199+ llvm::DenseMap<StringRef, size_t > getCountByDialectName;
192200
193201 // / This is a list of dialects that are created referring to this context.
194202 // / The MLIRContext owns the objects. These need to be declared after the
@@ -708,7 +716,33 @@ void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
708716
709717// / Return information about all registered operations.
710718ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations () {
711- return impl->sortedRegisteredOperations ;
719+ impl->transformedOperations .clear ();
720+
721+ SmallVector<RegisteredOperationName, 0 > operations;
722+ std::transform (impl->sortedRegisteredOperations .begin (),
723+ impl->sortedRegisteredOperations .end (),
724+ std::back_inserter (impl->transformedOperations ),
725+ [](const auto &t) { return t.second ; });
726+
727+ return impl->transformedOperations ;
728+ }
729+
730+ // / Return information for registered operations by dialect.
731+ ArrayRef<RegisteredOperationName>
732+ MLIRContext::getRegisteredOperationsByDialect (StringRef dialectName) {
733+ impl->transformedOperations .clear ();
734+
735+ auto lowerBound = std::lower_bound (
736+ impl->sortedRegisteredOperations .begin (),
737+ impl->sortedRegisteredOperations .end (), std::make_pair (dialectName, " " ),
738+ [](auto &lhs, auto &rhs) { return lhs.first .compare (rhs.first ); });
739+ auto count = impl->getCountByDialectName [dialectName];
740+
741+ std::transform (lowerBound, lowerBound + count,
742+ std::back_inserter (impl->transformedOperations ),
743+ [](const auto &t) { return t.second ; });
744+
745+ return impl->transformedOperations ;
712746}
713747
714748bool MLIRContext::isOperationRegistered (StringRef name) {
@@ -976,14 +1010,20 @@ void RegisteredOperationName::insert(
9761010 " operation name registration must be successful" );
9771011
9781012 // Add emplaced operation name to the sorted operations container.
979- RegisteredOperationName &value = emplaced.first ->second ;
980- ctxImpl.sortedRegisteredOperations .insert (
981- llvm::upper_bound (ctxImpl.sortedRegisteredOperations , value,
982- [](auto &lhs, auto &rhs) {
983- return lhs.getIdentifier ().compare (
984- rhs.getIdentifier ());
985- }),
986- value);
1013+ StringRef dialectClass = impl->getDialect ()->getNamespace ();
1014+ ctxImpl.getCountByDialectName [dialectClass] += 1 ;
1015+
1016+ std::pair<StringRef, RegisteredOperationName> value = {
1017+ dialectClass, emplaced.first ->second };
1018+
1019+ auto upperBound = llvm::upper_bound (
1020+ ctxImpl.sortedRegisteredOperations , value, [](auto &lhs, auto &rhs) {
1021+ if (lhs.first == rhs.first )
1022+ return lhs.second .getIdentifier ().compare (rhs.second .getIdentifier ());
1023+ return lhs.first .compare (rhs.first );
1024+ });
1025+
1026+ ctxImpl.sortedRegisteredOperations .insert (upperBound, value);
9871027}
9881028
9891029// ===----------------------------------------------------------------------===//
0 commit comments