Skip to content

Commit a63e19f

Browse files
committed
[mlir] Method to iterate over registered operations for a given dialect class.
Part of #111591 Currently we have `MLIRContext::getRegisteredOperations` which returns all operations for the given context, with the addition of `MLIRContext::getRegisteredOperationsByDialect` we can now retrieve the same for a given dialect class.
1 parent b2cac3b commit a63e19f

File tree

3 files changed

+51
-13
lines changed

3 files changed

+51
-13
lines changed

mlir/include/mlir/IR/MLIRContext.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,12 @@ class MLIRContext {
195195

196196
/// Return a sorted array containing the information about all registered
197197
/// operations.
198-
ArrayRef<RegisteredOperationName> getRegisteredOperations();
198+
SmallVector<RegisteredOperationName, 0> getRegisteredOperations();
199+
200+
/// Return a sorted array containing the information for registered operations
201+
/// filtered by dialect name.
202+
SmallVector<RegisteredOperationName, 0>
203+
getRegisteredOperationsByDialect(StringRef dialectName);
199204

200205
/// Return true if this operation name is registered in this context.
201206
bool isOperationRegistered(StringRef name);

mlir/lib/IR/MLIRContext.cpp

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,11 @@ class MLIRContextImpl {
188188

189189
/// This is a sorted container of registered operations for a deterministic
190190
/// and efficient `getRegisteredOperations` implementation.
191-
SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations;
191+
SmallVector<std::pair<StringRef, RegisteredOperationName>, 0>
192+
sortedRegisteredOperations;
193+
194+
/// This returns the number of registered operations for a given dialect.
195+
llvm::DenseMap<StringRef, size_t> getCountByDialectName;
192196

193197
/// This is a list of dialects that are created referring to this context.
194198
/// The MLIRContext owns the objects. These need to be declared after the
@@ -707,8 +711,31 @@ void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
707711
}
708712

709713
/// Return information about all registered operations.
710-
ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
711-
return impl->sortedRegisteredOperations;
714+
SmallVector<RegisteredOperationName, 0> MLIRContext::getRegisteredOperations() {
715+
SmallVector<RegisteredOperationName, 0> operations;
716+
std::transform(impl->sortedRegisteredOperations.begin(),
717+
impl->sortedRegisteredOperations.end(),
718+
std::back_inserter(operations),
719+
[](const auto &t) { return t.second; });
720+
721+
return operations;
722+
}
723+
724+
/// Return information for registered operations by dialect.
725+
SmallVector<RegisteredOperationName, 0>
726+
MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
727+
SmallVector<RegisteredOperationName, 0> operations;
728+
729+
auto lowerBound = std::lower_bound(
730+
impl->sortedRegisteredOperations.begin(),
731+
impl->sortedRegisteredOperations.end(), std::make_pair(dialectName, ""),
732+
[](auto &lhs, auto &rhs) { return lhs.first.compare(rhs.first); });
733+
auto count = impl->getCountByDialectName[dialectName];
734+
735+
std::transform(lowerBound, lowerBound + count, std::back_inserter(operations),
736+
[](const auto &t) { return t.second; });
737+
738+
return operations;
712739
}
713740

714741
bool MLIRContext::isOperationRegistered(StringRef name) {
@@ -976,14 +1003,20 @@ void RegisteredOperationName::insert(
9761003
"operation name registration must be successful");
9771004

9781005
// Add emplaced operation name to the sorted operations container.
979-
RegisteredOperationName &value = emplaced.first->second;
980-
ctxImpl.sortedRegisteredOperations.insert(
981-
llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
982-
[](auto &lhs, auto &rhs) {
983-
return lhs.getIdentifier().compare(
984-
rhs.getIdentifier());
985-
}),
986-
value);
1006+
StringRef dialectClass = impl->getDialect()->getNamespace();
1007+
ctxImpl.getCountByDialectName[dialectClass] += 1;
1008+
1009+
std::pair<StringRef, RegisteredOperationName> value = {
1010+
dialectClass, emplaced.first->second};
1011+
1012+
auto upperBound = llvm::upper_bound(
1013+
ctxImpl.sortedRegisteredOperations, value, [](auto &lhs, auto &rhs) {
1014+
if (lhs.first == rhs.first)
1015+
return lhs.second.getIdentifier().compare(rhs.second.getIdentifier());
1016+
return lhs.first.compare(rhs.first);
1017+
});
1018+
1019+
ctxImpl.sortedRegisteredOperations.insert(upperBound, value);
9871020
}
9881021

9891022
//===----------------------------------------------------------------------===//

mlir/lib/Rewrite/FrozenRewritePatternSet.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(
7373
// Functor used to walk all of the operations registered in the context. This
7474
// is useful for patterns that get applied to multiple operations, such as
7575
// interface and trait based patterns.
76-
std::vector<RegisteredOperationName> opInfos;
76+
SmallVector<RegisteredOperationName> opInfos;
7777
auto addToOpsWhen =
7878
[&](std::unique_ptr<RewritePattern> &pattern,
7979
function_ref<bool(RegisteredOperationName)> callbackFn) {

0 commit comments

Comments
 (0)