11// ===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
22//
3- // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3+ // Part of the LLVM Project, under the Apache License v2.0 wIth LLVM Exceptions.
44// See https://llvm.org/LICENSE.txt for license information.
5- // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6- //
7- // ===----------------------------------------------------------------------===//
8- //
9- // Implements the base layer of the matcher framework.
10- //
11- // Matchers are methods that return a Matcher which provides a method
12- // match(Operation *op)
13- //
14- // The matcher functions are defined in include/mlir/IR/Matchers.h.
15- // This file contains the wrapper classes needed to construct matchers for
16- // mlir-query.
5+ // SPDX-License-Identifier: Apache-2.0 WItH LLVM-exception
176//
187// ===----------------------------------------------------------------------===//
198
2211
2312#include " mlir/IR/Matchers.h"
2413#include " llvm/ADT/IntrusiveRefCntPtr.h"
14+ #include " llvm/ADT/MapVector.h"
15+ #include < memory>
16+ #include < stack>
17+ #include < unordered_set>
18+ #include < vector>
2519
2620namespace mlir ::query::matcher {
2721
22+ struct BoundOperationNode {
23+ Operation *op;
24+ std::vector<BoundOperationNode *> Parents;
25+ std::vector<BoundOperationNode *> Children;
26+
27+ bool IsRootNode;
28+ bool DetailedPrinting;
29+
30+ BoundOperationNode (Operation *op, bool IsRootNode = false ,
31+ bool DetailedPrinting = false )
32+ : op(op), IsRootNode(IsRootNode), DetailedPrinting(DetailedPrinting) {}
33+ };
34+
35+ class BoundOperationsGraphBuilder {
36+ public:
37+ BoundOperationNode *addNode (Operation *op, bool IsRootNode = false ,
38+ bool DetailedPrinting = false ) {
39+ auto It = Nodes.find (op);
40+ if (It != Nodes.end ()) {
41+ return It->second .get ();
42+ }
43+ auto Node =
44+ std::make_unique<BoundOperationNode>(op, IsRootNode, DetailedPrinting);
45+ BoundOperationNode *NodePtr = Node.get ();
46+ Nodes[op] = std::move (Node);
47+ return NodePtr;
48+ }
49+
50+ void addEdge (Operation *parentOp, Operation *childOp) {
51+ BoundOperationNode *ParentNode = addNode (parentOp, false , false );
52+ BoundOperationNode *ChildNode = addNode (childOp, false , false );
53+
54+ ParentNode->Children .push_back (ChildNode);
55+ ChildNode->Parents .push_back (ParentNode);
56+ }
57+
58+ BoundOperationNode *getNode (Operation *op) const {
59+ auto It = Nodes.find (op);
60+ return It != Nodes.end () ? It->second .get () : nullptr ;
61+ }
62+
63+ const llvm::MapVector<Operation *, std::unique_ptr<BoundOperationNode>> &
64+ getNodes () const {
65+ return Nodes;
66+ }
67+
68+ private:
69+ llvm::MapVector<Operation *, std::unique_ptr<BoundOperationNode>> Nodes;
70+ };
71+
72+ // Type traIt to detect if a matcher has a match(Operation*) method
73+ template <typename T, typename = void >
74+ struct has_simple_match : std::false_type {};
75+
76+ template <typename T>
77+ struct has_simple_match <T, std::void_t <decltype (std::declval<T>().match(
78+ std::declval<Operation *>()))>>
79+ : std::true_type {};
80+
81+ // Type traIt to detect if a matcher has a match(Operation*,
82+ // BoundOperationsGraphBuilder&) method
83+ template <typename T, typename = void >
84+ struct has_bound_match : std::false_type {};
85+
86+ template <typename T>
87+ struct has_bound_match <T, std::void_t <decltype (std::declval<T>().match(
88+ std::declval<Operation *>(),
89+ std::declval<BoundOperationsGraphBuilder &>()))>>
90+ : std::true_type {};
91+
2892// Generic interface for matchers on an MLIR operation.
2993class MatcherInterface
3094 : public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
3195public:
3296 virtual ~MatcherInterface () = default ;
33-
3497 virtual bool match (Operation *op) = 0;
98+ virtual bool match (Operation *op, BoundOperationsGraphBuilder &bound) = 0;
3599};
36100
37101// MatcherFnImpl takes a matcher function object and implements
@@ -40,40 +104,56 @@ template <typename MatcherFn>
40104class MatcherFnImpl : public MatcherInterface {
41105public:
42106 MatcherFnImpl (MatcherFn &matcherFn) : matcherFn(matcherFn) {}
43- bool match (Operation *op) override { return matcherFn.match (op); }
107+
108+ bool match (Operation *op) override {
109+ if constexpr (has_simple_match<MatcherFn>::value)
110+ return matcherFn.match (op);
111+ return false ;
112+ }
113+
114+ bool match (Operation *op, BoundOperationsGraphBuilder &bound) override {
115+ if constexpr (has_bound_match<MatcherFn>::value)
116+ return matcherFn.match (op, bound);
117+ return false ;
118+ }
44119
45120private:
46121 MatcherFn matcherFn;
47122};
48123
49- // Matcher wraps a MatcherInterface implementation and provides a match()
50- // method that redirects calls to the underlying implementation.
124+ // Matcher wraps a MatcherInterface implementation and provides match()
125+ // methods that redirect calls to the underlying implementation.
51126class DynMatcher {
52127public:
53128 // Takes ownership of the provided implementation pointer.
54- DynMatcher (MatcherInterface *implementation)
55- : implementation(implementation) {}
129+ DynMatcher (MatcherInterface *implementation, StringRef matcherName )
130+ : implementation(implementation), matcherName(matcherName.str()) {}
56131
57132 template <typename MatcherFn>
58133 static std::unique_ptr<DynMatcher>
59- constructDynMatcherFromMatcherFn (MatcherFn &matcherFn) {
134+ constructDynMatcherFromMatcherFn (MatcherFn &matcherFn,
135+ StringRef matcherName) {
60136 auto impl = std::make_unique<MatcherFnImpl<MatcherFn>>(matcherFn);
61- return std::make_unique<DynMatcher>(impl.release ());
137+ return std::make_unique<DynMatcher>(impl.release (), matcherName );
62138 }
63139
64140 bool match (Operation *op) const { return implementation->match (op); }
141+ bool match (Operation *op, BoundOperationsGraphBuilder &bound) const {
142+ return implementation->match (op, bound);
143+ }
65144
66- void setFunctionName (StringRef name) { functionName = name.str (); };
67-
68- bool hasFunctionName () const { return !functionName.empty (); };
69-
70- StringRef getFunctionName () const { return functionName ; };
145+ void setFunctionName (StringRef name) { functionName = name.str (); }
146+ void setMatcherName (StringRef name) { matcherName = name. str (); }
147+ bool hasFunctionName () const { return !functionName.empty (); }
148+ StringRef getFunctionName () const { return functionName; }
149+ StringRef getMatcherName () const { return matcherName ; }
71150
72151private:
73152 llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
153+ std::string matcherName;
74154 std::string functionName;
75155};
76156
77157} // namespace mlir::query::matcher
78158
79- #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
159+ #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
0 commit comments