Conversation

@Rajveer100
Member

Part of #111591

Currently we have MLIRContext::getRegisteredOperations, which returns all operations registered in the given context. With the addition of MLIRContext::getRegisteredOperationsByDialect, we can now retrieve the same information for a single dialect, selected by its name.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Oct 15, 2024
@llvmbot
Member

llvmbot commented Oct 15, 2024

@llvm/pr-subscribers-mlir-core

Author: Rajveer Singh Bharadwaj (Rajveer100)

Full diff: https://github.com/llvm/llvm-project/pull/112344.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/MLIRContext.h (+6-1)
  • (modified) mlir/lib/IR/MLIRContext.cpp (+44-11)
  • (modified) mlir/lib/Rewrite/FrozenRewritePatternSet.cpp (+1-1)
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index d17bbac81655b5..cfad6874b8f4a9 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -195,7 +195,12 @@ class MLIRContext {
 
   /// Return a sorted array containing the information about all registered
   /// operations.
-  ArrayRef<RegisteredOperationName> getRegisteredOperations();
+  SmallVector<RegisteredOperationName, 0> getRegisteredOperations();
+
+  /// Return a sorted array containing the information for registered operations
+  /// filtered by dialect name.
+  SmallVector<RegisteredOperationName, 0>
+  getRegisteredOperationsByDialect(StringRef dialectName);
 
   /// Return true if this operation name is registered in this context.
   bool isOperationRegistered(StringRef name);
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index f05666fcde207b..bb0da94e985517 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -188,7 +188,11 @@ class MLIRContextImpl {
 
   /// This is a sorted container of registered operations for a deterministic
   /// and efficient `getRegisteredOperations` implementation.
-  SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations;
+  SmallVector<std::pair<StringRef, RegisteredOperationName>, 0>
+      sortedRegisteredOperations;
+
+  /// This returns the number of registered operations for a given dialect.
+  llvm::DenseMap<StringRef, size_t> getCountByDialectName;
 
   /// This is a list of dialects that are created referring to this context.
   /// The MLIRContext owns the objects. These need to be declared after the
@@ -707,8 +711,31 @@ void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
 }
 
 /// Return information about all registered operations.
-ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
-  return impl->sortedRegisteredOperations;
+SmallVector<RegisteredOperationName, 0> MLIRContext::getRegisteredOperations() {
+  SmallVector<RegisteredOperationName, 0> operations;
+  std::transform(impl->sortedRegisteredOperations.begin(),
+                 impl->sortedRegisteredOperations.end(),
+                 std::back_inserter(operations),
+                 [](const auto &t) { return t.second; });
+
+  return operations;
+}
+
+/// Return information for registered operations by dialect.
+SmallVector<RegisteredOperationName, 0>
+MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
+  SmallVector<RegisteredOperationName, 0> operations;
+
+  auto lowerBound = std::lower_bound(
+      impl->sortedRegisteredOperations.begin(),
+      impl->sortedRegisteredOperations.end(), std::make_pair(dialectName, ""),
+      [](auto &lhs, auto &rhs) { return lhs.first.compare(rhs.first); });
+  auto count = impl->getCountByDialectName[dialectName];
+
+  std::transform(lowerBound, lowerBound + count, std::back_inserter(operations),
+                 [](const auto &t) { return t.second; });
+
+  return operations;
 }
 
 bool MLIRContext::isOperationRegistered(StringRef name) {
@@ -976,14 +1003,20 @@ void RegisteredOperationName::insert(
          "operation name registration must be successful");
 
   // Add emplaced operation name to the sorted operations container.
-  RegisteredOperationName &value = emplaced.first->second;
-  ctxImpl.sortedRegisteredOperations.insert(
-      llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
-                        [](auto &lhs, auto &rhs) {
-                          return lhs.getIdentifier().compare(
-                              rhs.getIdentifier());
-                        }),
-      value);
+  StringRef dialectClass = impl->getDialect()->getNamespace();
+  ctxImpl.getCountByDialectName[dialectClass] += 1;
+
+  std::pair<StringRef, RegisteredOperationName> value = {
+      dialectClass, emplaced.first->second};
+
+  auto upperBound = llvm::upper_bound(
+      ctxImpl.sortedRegisteredOperations, value, [](auto &lhs, auto &rhs) {
+        if (lhs.first == rhs.first)
+          return lhs.second.getIdentifier().compare(rhs.second.getIdentifier());
+        return lhs.first.compare(rhs.first);
+      });
+
+  ctxImpl.sortedRegisteredOperations.insert(upperBound, value);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
index 17fe02df9f66cd..d3317fc6d4fe30 100644
--- a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
+++ b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
@@ -73,7 +73,7 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(
   // Functor used to walk all of the operations registered in the context. This
   // is useful for patterns that get applied to multiple operations, such as
   // interface and trait based patterns.
-  std::vector<RegisteredOperationName> opInfos;
+  SmallVector<RegisteredOperationName> opInfos;
   auto addToOpsWhen =
       [&](std::unique_ptr<RewritePattern> &pattern,
           function_ref<bool(RegisteredOperationName)> callbackFn) {
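
The lookup in the diff above depends on the container being sorted by dialect namespace first and by operation name second, so that all operations of one dialect form a contiguous range. The following is a minimal standalone sketch of that idea, not the MLIR implementation: `Entry`, `entryLess`, `insertSorted`, and `opsForDialect` are hypothetical names, and plain `std::string` pairs stand in for `StringRef` and `RegisteredOperationName`. One detail worth checking in review: the comparators shown in the diff return the result of `compare` directly, whereas binary-search helpers expect a strict less-than predicate. A three-way `compare` is non-zero for "greater" as well as "less", so the sketch converts it with `< 0`.

```cpp
#include <algorithm>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

// Stand-in for a (dialect namespace, operation name) entry. These are
// illustrative types, not the MLIR ones.
using Entry = std::pair<std::string, std::string>;

// Strict less-than on (dialect, name). The three-way compare() result is
// converted to a bool with `< 0` because it is also non-zero for "greater".
static bool entryLess(const Entry &lhs, const Entry &rhs) {
  int byDialect = lhs.first.compare(rhs.first);
  if (byDialect != 0)
    return byDialect < 0;
  return lhs.second.compare(rhs.second) < 0;
}

// Insert an entry while keeping the vector sorted by dialect, then by name.
void insertSorted(std::vector<Entry> &ops, Entry value) {
  auto pos = std::upper_bound(ops.begin(), ops.end(), value, entryLess);
  ops.insert(pos, std::move(value));
}

// Copy out the operation names registered under one dialect. The range is
// found with two binary searches, so no per-dialect counter is needed.
std::vector<std::string> opsForDialect(const std::vector<Entry> &ops,
                                       std::string_view dialect) {
  auto first = std::lower_bound(
      ops.begin(), ops.end(), dialect,
      [](const Entry &e, std::string_view d) {
        return std::string_view(e.first) < d;
      });
  auto last = std::upper_bound(
      first, ops.end(), dialect,
      [](std::string_view d, const Entry &e) {
        return d < std::string_view(e.first);
      });
  std::vector<std::string> names;
  for (auto it = first; it != last; ++it)
    names.push_back(it->second);
  return names;
}
```

In this sketch an unknown dialect name simply yields an empty range, since both searches land on the same position.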

@Rajveer100
Member Author

@stellaraccident

@Rajveer100
Member Author

Rajveer100 commented Oct 15, 2024

Before proceeding with additional test cases, let me know if this works well. I am sure there are many logical ways to approach this.

@stellaraccident
Contributor

Per usual, Mehdi beat me to the review. Thank you for the PR.

@Rajveer100 Rajveer100 force-pushed the subrange-dialect-ops branch from a63e19f to 41ec7af Compare October 16, 2024 10:32
@Rajveer100 Rajveer100 force-pushed the subrange-dialect-ops branch from 41ec7af to 859a8ac Compare October 16, 2024 18:20
@Rajveer100
Member Author

Let me know if the updated logic works well.

@Rajveer100 Rajveer100 force-pushed the subrange-dialect-ops branch from 859a8ac to 0ca27ec Compare October 16, 2024 18:22
@Rajveer100 Rajveer100 force-pushed the subrange-dialect-ops branch from 30677fb to 469021b Compare October 16, 2024 19:20
@Rajveer100 Rajveer100 force-pushed the subrange-dialect-ops branch 2 times, most recently from a0849d7 to 3c94c22 Compare October 16, 2024 19:31
@Rajveer100
Member Author

Rajveer100 commented Oct 16, 2024

I have made the changes; I think it should make sense now.

PS: This could have been way smoother :\

@Rajveer100 Rajveer100 force-pushed the subrange-dialect-ops branch from 3c94c22 to 644ab88 Compare October 16, 2024 20:05
@Rajveer100 Rajveer100 requested a review from River707 October 16, 2024 20:05
@Rajveer100 Rajveer100 force-pushed the subrange-dialect-ops branch from 644ab88 to d123ad6 Compare October 16, 2024 20:13
@Rajveer100 Rajveer100 requested a review from joker-eph October 16, 2024 20:14
@Rajveer100 Rajveer100 force-pushed the subrange-dialect-ops branch from d123ad6 to 70b35e9 Compare October 17, 2024 06:34
@Rajveer100 Rajveer100 requested a review from joker-eph October 17, 2024 06:35
@Rajveer100 Rajveer100 force-pushed the subrange-dialect-ops branch 2 times, most recently from fe978d4 to 5ffba9a Compare October 17, 2024 07:35
…ct class.

Part of llvm#111591

Currently we have `MLIRContext::getRegisteredOperations` which returns all operations
for the given context, with the addition of `MLIRContext::getRegisteredOperationsByDialect`
we can now retrieve the same for a given dialect class.
Collaborator

@joker-eph joker-eph left a comment

LG, thanks.

@Rajveer100
Member Author

Thanks for the approval, could you land this for me?!

@joker-eph joker-eph merged commit b091701 into llvm:main Oct 17, 2024
4 checks passed

Labels

mlir:core MLIR Core Infrastructure mlir

5 participants