Skip to content

Commit 30677fb

Browse files
committed
[mlir] Method to iterate over registered operations for a given dialect class.
Part of #111591 Currently we have `MLIRContext::getRegisteredOperations` which returns all operations for the given context, with the addition of `MLIRContext::getRegisteredOperationsByDialect` we can now retrieve the same for a given dialect class.
1 parent b2cac3b commit 30677fb

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

mlir/include/mlir/IR/MLIRContext.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,11 @@ class MLIRContext {
197197
/// operations.
198198
ArrayRef<RegisteredOperationName> getRegisteredOperations();
199199

200+
/// Return a sorted array containing the information for registered operations
201+
/// filtered by dialect name.
202+
ArrayRef<RegisteredOperationName>
203+
getRegisteredOperationsByDialect(StringRef dialectName);
204+
200205
/// Return true if this operation name is registered in this context.
201206
bool isOperationRegistered(StringRef name);
202207

mlir/lib/IR/MLIRContext.cpp

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,9 @@ class MLIRContextImpl {
190190
/// and efficient `getRegisteredOperations` implementation.
191191
SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations;
192192

193+
/// This returns the number of registered operations for a given dialect.
194+
DenseMap<StringRef, size_t> getCountByDialectName;
195+
193196
/// This is a list of dialects that are created referring to this context.
194197
/// The MLIRContext owns the objects. These need to be declared after the
195198
/// registered operations to ensure correct destruction order.
@@ -711,6 +714,21 @@ ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
711714
return impl->sortedRegisteredOperations;
712715
}
713716

717+
/// Return information for registered operations by dialect.
718+
ArrayRef<RegisteredOperationName>
719+
MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
720+
auto lowerBound =
721+
std::lower_bound(impl->sortedRegisteredOperations.begin(),
722+
impl->sortedRegisteredOperations.end(), dialectName,
723+
[](auto &lhs, auto &rhs) {
724+
return lhs.getDialect().getNamespace().compare(
725+
rhs.getDialect().getNamespace());
726+
});
727+
auto count = impl->getCountByDialectName[dialectName];
728+
729+
return ArrayRef(impl->sortedRegisteredOperations.data(), count);
730+
}
731+
714732
bool MLIRContext::isOperationRegistered(StringRef name) {
715733
return RegisteredOperationName::lookup(name, this).has_value();
716734
}
@@ -976,12 +994,19 @@ void RegisteredOperationName::insert(
976994
"operation name registration must be successful");
977995

978996
// Add emplaced operation name to the sorted operations container.
997+
StringRef dialectClass = impl->getDialect()->getNamespace();
998+
ctxImpl.getCountByDialectName[dialectClass] += 1;
999+
9791000
RegisteredOperationName &value = emplaced.first->second;
9801001
ctxImpl.sortedRegisteredOperations.insert(
9811002
llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
9821003
[](auto &lhs, auto &rhs) {
983-
return lhs.getIdentifier().compare(
984-
rhs.getIdentifier());
1004+
if (lhs.getDialect().getNamespace() ==
1005+
rhs.getDialect().getNamespace())
1006+
return lhs.getIdentifier().compare(
1007+
rhs.getIdentifier());
1008+
return lhs.getDialect().getNamespace().compare(
1009+
rhs.getDialect().getNamespace());
9851010
}),
9861011
value);
9871012
}

0 commit comments

Comments
 (0)