-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir] Method to iterate over registered operations for a given dialect class. #112344
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
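@llvm/pr-subscribers-mlir-core Author: Rajveer Singh Bharadwaj (Rajveer100) Changes: Part of #111591. Currently we have `MLIRContext::getRegisteredOperations`, which returns all operations for the given context. With the addition of `MLIRContext::getRegisteredOperationsByDialect`, we can now retrieve the same for a given dialect class. Full diff: https://github.com/llvm/llvm-project/pull/112344.diff 3 Files Affected: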
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index d17bbac81655b5..cfad6874b8f4a9 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -195,7 +195,12 @@ class MLIRContext {
/// Return a sorted array containing the information about all registered
/// operations.
- ArrayRef<RegisteredOperationName> getRegisteredOperations();
+ SmallVector<RegisteredOperationName, 0> getRegisteredOperations();
+
+ /// Return a sorted array containing the information for registered operations
+ /// filtered by dialect name.
+ SmallVector<RegisteredOperationName, 0>
+ getRegisteredOperationsByDialect(StringRef dialectName);
/// Return true if this operation name is registered in this context.
bool isOperationRegistered(StringRef name);
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index f05666fcde207b..bb0da94e985517 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -188,7 +188,11 @@ class MLIRContextImpl {
/// This is a sorted container of registered operations for a deterministic
/// and efficient `getRegisteredOperations` implementation.
- SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations;
+ SmallVector<std::pair<StringRef, RegisteredOperationName>, 0>
+ sortedRegisteredOperations;
+
+ /// This returns the number of registered operations for a given dialect.
+ llvm::DenseMap<StringRef, size_t> getCountByDialectName;
/// This is a list of dialects that are created referring to this context.
/// The MLIRContext owns the objects. These need to be declared after the
@@ -707,8 +711,31 @@ void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
}
/// Return information about all registered operations.
-ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
- return impl->sortedRegisteredOperations;
+SmallVector<RegisteredOperationName, 0> MLIRContext::getRegisteredOperations() {
+ SmallVector<RegisteredOperationName, 0> operations;
+ std::transform(impl->sortedRegisteredOperations.begin(),
+ impl->sortedRegisteredOperations.end(),
+ std::back_inserter(operations),
+ [](const auto &t) { return t.second; });
+
+ return operations;
+}
+
+/// Return information for registered operations by dialect.
+SmallVector<RegisteredOperationName, 0>
+MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
+ SmallVector<RegisteredOperationName, 0> operations;
+
+ auto lowerBound = std::lower_bound(
+ impl->sortedRegisteredOperations.begin(),
+ impl->sortedRegisteredOperations.end(), std::make_pair(dialectName, ""),
+ [](auto &lhs, auto &rhs) { return lhs.first.compare(rhs.first); });
+ auto count = impl->getCountByDialectName[dialectName];
+
+ std::transform(lowerBound, lowerBound + count, std::back_inserter(operations),
+ [](const auto &t) { return t.second; });
+
+ return operations;
}
bool MLIRContext::isOperationRegistered(StringRef name) {
@@ -976,14 +1003,20 @@ void RegisteredOperationName::insert(
"operation name registration must be successful");
// Add emplaced operation name to the sorted operations container.
- RegisteredOperationName &value = emplaced.first->second;
- ctxImpl.sortedRegisteredOperations.insert(
- llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
- [](auto &lhs, auto &rhs) {
- return lhs.getIdentifier().compare(
- rhs.getIdentifier());
- }),
- value);
+ StringRef dialectClass = impl->getDialect()->getNamespace();
+ ctxImpl.getCountByDialectName[dialectClass] += 1;
+
+ std::pair<StringRef, RegisteredOperationName> value = {
+ dialectClass, emplaced.first->second};
+
+ auto upperBound = llvm::upper_bound(
+ ctxImpl.sortedRegisteredOperations, value, [](auto &lhs, auto &rhs) {
+ if (lhs.first == rhs.first)
+ return lhs.second.getIdentifier().compare(rhs.second.getIdentifier());
+ return lhs.first.compare(rhs.first);
+ });
+
+ ctxImpl.sortedRegisteredOperations.insert(upperBound, value);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
index 17fe02df9f66cd..d3317fc6d4fe30 100644
--- a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
+++ b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
@@ -73,7 +73,7 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(
// Functor used to walk all of the operations registered in the context. This
// is useful for patterns that get applied to multiple operations, such as
// interface and trait based patterns.
- std::vector<RegisteredOperationName> opInfos;
+ SmallVector<RegisteredOperationName> opInfos;
auto addToOpsWhen =
[&](std::unique_ptr<RewritePattern> &pattern,
function_ref<bool(RegisteredOperationName)> callbackFn) {
|
|
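@llvm/pr-subscribers-mlir Author: Rajveer Singh Bharadwaj (Rajveer100) Changes: Part of #111591. Currently we have `MLIRContext::getRegisteredOperations`, which returns all operations for the given context. With the addition of `MLIRContext::getRegisteredOperationsByDialect`, we can now retrieve the same for a given dialect class. Full diff: https://github.com/llvm/llvm-project/pull/112344.diff 3 Files Affected: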
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index d17bbac81655b5..cfad6874b8f4a9 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -195,7 +195,12 @@ class MLIRContext {
/// Return a sorted array containing the information about all registered
/// operations.
- ArrayRef<RegisteredOperationName> getRegisteredOperations();
+ SmallVector<RegisteredOperationName, 0> getRegisteredOperations();
+
+ /// Return a sorted array containing the information for registered operations
+ /// filtered by dialect name.
+ SmallVector<RegisteredOperationName, 0>
+ getRegisteredOperationsByDialect(StringRef dialectName);
/// Return true if this operation name is registered in this context.
bool isOperationRegistered(StringRef name);
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index f05666fcde207b..bb0da94e985517 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -188,7 +188,11 @@ class MLIRContextImpl {
/// This is a sorted container of registered operations for a deterministic
/// and efficient `getRegisteredOperations` implementation.
- SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations;
+ SmallVector<std::pair<StringRef, RegisteredOperationName>, 0>
+ sortedRegisteredOperations;
+
+ /// This returns the number of registered operations for a given dialect.
+ llvm::DenseMap<StringRef, size_t> getCountByDialectName;
/// This is a list of dialects that are created referring to this context.
/// The MLIRContext owns the objects. These need to be declared after the
@@ -707,8 +711,31 @@ void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
}
/// Return information about all registered operations.
-ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
- return impl->sortedRegisteredOperations;
+SmallVector<RegisteredOperationName, 0> MLIRContext::getRegisteredOperations() {
+ SmallVector<RegisteredOperationName, 0> operations;
+ std::transform(impl->sortedRegisteredOperations.begin(),
+ impl->sortedRegisteredOperations.end(),
+ std::back_inserter(operations),
+ [](const auto &t) { return t.second; });
+
+ return operations;
+}
+
+/// Return information for registered operations by dialect.
+SmallVector<RegisteredOperationName, 0>
+MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
+ SmallVector<RegisteredOperationName, 0> operations;
+
+ auto lowerBound = std::lower_bound(
+ impl->sortedRegisteredOperations.begin(),
+ impl->sortedRegisteredOperations.end(), std::make_pair(dialectName, ""),
+ [](auto &lhs, auto &rhs) { return lhs.first.compare(rhs.first); });
+ auto count = impl->getCountByDialectName[dialectName];
+
+ std::transform(lowerBound, lowerBound + count, std::back_inserter(operations),
+ [](const auto &t) { return t.second; });
+
+ return operations;
}
bool MLIRContext::isOperationRegistered(StringRef name) {
@@ -976,14 +1003,20 @@ void RegisteredOperationName::insert(
"operation name registration must be successful");
// Add emplaced operation name to the sorted operations container.
- RegisteredOperationName &value = emplaced.first->second;
- ctxImpl.sortedRegisteredOperations.insert(
- llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
- [](auto &lhs, auto &rhs) {
- return lhs.getIdentifier().compare(
- rhs.getIdentifier());
- }),
- value);
+ StringRef dialectClass = impl->getDialect()->getNamespace();
+ ctxImpl.getCountByDialectName[dialectClass] += 1;
+
+ std::pair<StringRef, RegisteredOperationName> value = {
+ dialectClass, emplaced.first->second};
+
+ auto upperBound = llvm::upper_bound(
+ ctxImpl.sortedRegisteredOperations, value, [](auto &lhs, auto &rhs) {
+ if (lhs.first == rhs.first)
+ return lhs.second.getIdentifier().compare(rhs.second.getIdentifier());
+ return lhs.first.compare(rhs.first);
+ });
+
+ ctxImpl.sortedRegisteredOperations.insert(upperBound, value);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
index 17fe02df9f66cd..d3317fc6d4fe30 100644
--- a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
+++ b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
@@ -73,7 +73,7 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(
// Functor used to walk all of the operations registered in the context. This
// is useful for patterns that get applied to multiple operations, such as
// interface and trait based patterns.
- std::vector<RegisteredOperationName> opInfos;
+ SmallVector<RegisteredOperationName> opInfos;
auto addToOpsWhen =
[&](std::unique_ptr<RewritePattern> &pattern,
function_ref<bool(RegisteredOperationName)> callbackFn) {
|
|
Before proceeding with additional test cases, let me know if this works well. I am sure there are many logical ways to approach this. |
|
Per usual, Mehdi beat me to the review. Thank you for the PR. |
a63e19f to
41ec7af
Compare
41ec7af to
859a8ac
Compare
|
Let me know if the updated logic works well. |
859a8ac to
0ca27ec
Compare
0ca27ec to
30677fb
Compare
30677fb to
469021b
Compare
a0849d7 to
3c94c22
Compare
|
I have made the changes; I think it should make sense now. PS: This could have been way smoother :\ |
3c94c22 to
644ab88
Compare
644ab88 to
d123ad6
Compare
d123ad6 to
70b35e9
Compare
fe978d4 to
5ffba9a
Compare
[mlir] Method to iterate over registered operations for a given dialect class. Part of llvm#111591. Currently we have `MLIRContext::getRegisteredOperations`, which returns all operations for the given context. With the addition of `MLIRContext::getRegisteredOperationsByDialect`, we can now retrieve the same for a given dialect class.
5ffba9a to
db49a12
Compare
joker-eph
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG, thanks.
|
Thanks for the approval! Could you land this for me? |
Part of #111591
Currently we have
`MLIRContext::getRegisteredOperations`, which returns all operations for the given context. With the addition of `MLIRContext::getRegisteredOperationsByDialect`, we can now retrieve the same for a given dialect class.