Skip to content

Commit 41ec7af

Browse files
committed
[mlir] Method to iterate over registered operations for a given dialect class.
Part of #111591 Currently we have `MLIRContext::getRegisteredOperations` which returns all operations for the given context, with the addition of `MLIRContext::getRegisteredOperationsByDialect` we can now retrieve the same for a given dialect class.
1 parent b2cac3b commit 41ec7af

File tree

2 files changed

+55
-10
lines changed

2 files changed

+55
-10
lines changed

mlir/include/mlir/IR/MLIRContext.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,11 @@ class MLIRContext {
197197
/// operations.
198198
ArrayRef<RegisteredOperationName> getRegisteredOperations();
199199

200+
/// Return a sorted array containing the information for registered operations
201+
/// filtered by dialect name.
202+
ArrayRef<RegisteredOperationName>
203+
getRegisteredOperationsByDialect(StringRef dialectName);
204+
200205
/// Return true if this operation name is registered in this context.
201206
bool isOperationRegistered(StringRef name);
202207

mlir/lib/IR/MLIRContext.cpp

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,15 @@ class MLIRContextImpl {
188188

189189
/// This is a sorted container of registered operations for a deterministic
190190
/// and efficient `getRegisteredOperations` implementation.
191-
SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations;
191+
SmallVector<std::pair<StringRef, RegisteredOperationName>, 0>
192+
sortedRegisteredOperations;
193+
194+
/// This stores the transformed operations when calling
195+
/// `getRegisteredOperations`.
196+
SmallVector<RegisteredOperationName, 0> transformedOperations;
197+
198+
/// This returns the number of registered operations for a given dialect.
199+
llvm::DenseMap<StringRef, size_t> getCountByDialectName;
192200

193201
/// This is a list of dialects that are created referring to this context.
194202
/// The MLIRContext owns the objects. These need to be declared after the
@@ -708,7 +716,33 @@ void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
708716

709717
/// Return information about all registered operations.
710718
ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
711-
return impl->sortedRegisteredOperations;
719+
impl->transformedOperations.clear();
720+
721+
SmallVector<RegisteredOperationName, 0> operations;
722+
std::transform(impl->sortedRegisteredOperations.begin(),
723+
impl->sortedRegisteredOperations.end(),
724+
std::back_inserter(impl->transformedOperations),
725+
[](const auto &t) { return t.second; });
726+
727+
return impl->transformedOperations;
728+
}
729+
730+
/// Return information for registered operations by dialect.
731+
ArrayRef<RegisteredOperationName>
732+
MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
733+
impl->transformedOperations.clear();
734+
735+
auto lowerBound = std::lower_bound(
736+
impl->sortedRegisteredOperations.begin(),
737+
impl->sortedRegisteredOperations.end(), std::make_pair(dialectName, ""),
738+
[](auto &lhs, auto &rhs) { return lhs.first.compare(rhs.first); });
739+
auto count = impl->getCountByDialectName[dialectName];
740+
741+
std::transform(lowerBound, lowerBound + count,
742+
std::back_inserter(impl->transformedOperations),
743+
[](const auto &t) { return t.second; });
744+
745+
return impl->transformedOperations;
712746
}
713747

714748
bool MLIRContext::isOperationRegistered(StringRef name) {
@@ -976,14 +1010,20 @@ void RegisteredOperationName::insert(
9761010
"operation name registration must be successful");
9771011

9781012
// Add emplaced operation name to the sorted operations container.
979-
RegisteredOperationName &value = emplaced.first->second;
980-
ctxImpl.sortedRegisteredOperations.insert(
981-
llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
982-
[](auto &lhs, auto &rhs) {
983-
return lhs.getIdentifier().compare(
984-
rhs.getIdentifier());
985-
}),
986-
value);
1013+
StringRef dialectClass = impl->getDialect()->getNamespace();
1014+
ctxImpl.getCountByDialectName[dialectClass] += 1;
1015+
1016+
std::pair<StringRef, RegisteredOperationName> value = {
1017+
dialectClass, emplaced.first->second};
1018+
1019+
auto upperBound = llvm::upper_bound(
1020+
ctxImpl.sortedRegisteredOperations, value, [](auto &lhs, auto &rhs) {
1021+
if (lhs.first == rhs.first)
1022+
return lhs.second.getIdentifier().compare(rhs.second.getIdentifier());
1023+
return lhs.first.compare(rhs.first);
1024+
});
1025+
1026+
ctxImpl.sortedRegisteredOperations.insert(upperBound, value);
9871027
}
9881028

9891029
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)