Skip to content

Commit dc1c9d3

Browse files
authored
[flang][Evaluate] Pattern matching framework for evaluate::Expr (#153042)
Implement a framework to make it easier to detect if evaluate::Expr<T> has certain structure.
1 parent b75896b commit dc1c9d3

File tree

1 file changed

+211
-0
lines changed
  • flang/include/flang/Evaluate

1 file changed

+211
-0
lines changed

flang/include/flang/Evaluate/match.h

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
//===-- include/flang/Evaluate/match.h --------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef FORTRAN_EVALUATE_MATCH_H_
9+
#define FORTRAN_EVALUATE_MATCH_H_
10+
11+
#include "flang/Common/visit.h"
12+
#include "flang/Evaluate/expression.h"
13+
#include "llvm/ADT/STLExtras.h"
14+
15+
#include <tuple>
16+
#include <type_traits>
17+
#include <utility>
18+
#include <variant>
19+
20+
namespace Fortran::evaluate {
21+
namespace match {
22+
namespace detail {
23+
template <typename, typename = void> //
24+
struct IsOperation {
25+
static constexpr bool value{false};
26+
};
27+
28+
template <typename T>
29+
struct IsOperation<T, std::void_t<decltype(T::operands)>> {
30+
static constexpr bool value{true};
31+
};
32+
} // namespace detail
33+
34+
template <typename T>
35+
constexpr bool is_operation_v{detail::IsOperation<T>::value};
36+
37+
template <typename T>
38+
const evaluate::Expr<T> &deparen(const evaluate::Expr<T> &x) {
39+
if (auto *parens{std::get_if<evaluate::Parentheses<T>>(&x.u)}) {
40+
return deparen(parens->template operand<0>());
41+
} else {
42+
return x;
43+
}
44+
}
45+
46+
// Expr<T> matchers (patterns)
47+
//
48+
// Each pattern should implement
49+
// bool match(const U &input) const
50+
// member function that returns `true` when the match was successful,
51+
// and `false` otherwise.
52+
//
53+
// Patterns are intended to be composable, i.e. a pattern can take operands
54+
// which themselves are patterns. This composition is expected to match if
55+
// the root pattern and all its operands match given input.
56+
57+
/// Matches any input as long as it has the expected type `MatchType`.
58+
/// Additionally, it sets the member `ref` to the matched input.
59+
template <typename T> struct TypePattern {
60+
using MatchType = llvm::remove_cvref_t<T>;
61+
62+
template <typename U> bool match(const U &input) const {
63+
if constexpr (std::is_same_v<MatchType, U>) {
64+
ref = &input;
65+
return true;
66+
} else {
67+
return false;
68+
}
69+
}
70+
71+
mutable const MatchType *ref{nullptr};
72+
};
73+
74+
/// Matches one of the patterns provided as template arguments. All of these
75+
/// patterns should have the same number of operands, i.e. they all should
76+
/// try to match input expression with the same number of children, i.e.
77+
/// AnyOfPattern<SomeBinaryOp, OtherBinaryOp> is ok, whereas
78+
/// AnyOfPattern<SomeBinaryOp, SomeTernaryOp> is not.
79+
template <typename... Patterns> struct AnyOfPattern {
80+
static_assert(sizeof...(Patterns) != 0);
81+
82+
private:
83+
using PatternTuple = std::tuple<Patterns...>;
84+
85+
template <size_t I>
86+
using Pattern = typename std::tuple_element<I, PatternTuple>::type;
87+
88+
template <size_t... Is, typename... Ops>
89+
AnyOfPattern(std::index_sequence<Is...>, const Ops &...ops)
90+
: patterns(std::make_tuple(Pattern<Is>(ops...)...)) {}
91+
92+
template <typename P, typename U>
93+
bool matchOne(const P &pattern, const U &input) const {
94+
if (pattern.match(input)) {
95+
ref = &pattern;
96+
return true;
97+
}
98+
return false;
99+
}
100+
101+
template <typename U, size_t... Is>
102+
bool matchImpl(const U &input, std::index_sequence<Is...>) const {
103+
return (matchOne(std::get<Is>(patterns), input) || ...);
104+
}
105+
106+
PatternTuple patterns;
107+
108+
public:
109+
using Indexes = std::index_sequence_for<Patterns...>;
110+
using MatchTypes = std::tuple<typename Patterns::MatchType...>;
111+
112+
template <typename... Ops>
113+
AnyOfPattern(const Ops &...ops) : AnyOfPattern(Indexes{}, ops...) {}
114+
115+
template <typename U> bool match(const U &input) const {
116+
return matchImpl(input, Indexes{});
117+
}
118+
119+
mutable std::variant<const Patterns *..., std::monostate> ref{
120+
std::monostate{}};
121+
};
122+
123+
/// Matches any input of type Expr<T>
124+
/// The indent if this pattern is to be a leaf in multi-operand patterns.
125+
template <typename T> //
126+
struct ExprPattern : public TypePattern<evaluate::Expr<T>> {};
127+
128+
/// Matches evaluate::Expr<T> that contains evaluate::Opreration<OpType>.
129+
template <typename OpType, typename... Ops>
130+
struct OperationPattern : public TypePattern<OpType> {
131+
private:
132+
using Indexes = std::index_sequence_for<Ops...>;
133+
134+
template <typename S, size_t... Is>
135+
bool matchImpl(const S &op, std::index_sequence<Is...>) const {
136+
using TypeS = llvm::remove_cvref_t<S>;
137+
if constexpr (is_operation_v<TypeS>) {
138+
if constexpr (TypeS::operands == Indexes::size()) {
139+
return TypePattern<OpType>::match(op) &&
140+
(std::get<Is>(operands).match(op.template operand<Is>()) && ...);
141+
}
142+
}
143+
return false;
144+
}
145+
146+
std::tuple<const Ops &...> operands;
147+
148+
public:
149+
using MatchType = OpType;
150+
151+
OperationPattern(const Ops &...ops, llvm::type_identity<OpType> = {})
152+
: operands(ops...) {}
153+
154+
template <typename T> bool match(const evaluate::Expr<T> &input) const {
155+
return common::visit(
156+
[&](auto &&s) { return matchImpl(s, Indexes{}); }, deparen(input).u);
157+
}
158+
159+
template <typename U> bool match(const U &input) const {
160+
// Only match Expr<T>
161+
return false;
162+
}
163+
};
164+
165+
template <typename OpType, typename... Ops>
166+
OperationPattern(const Ops &...ops, llvm::type_identity<OpType>)
167+
-> OperationPattern<OpType, Ops...>;
168+
169+
// Namespace-level definitions
170+
171+
template <typename T> using Expr = ExprPattern<T>;
172+
173+
template <typename OpType, typename... Ops>
174+
using Op = OperationPattern<OpType, Ops...>;
175+
176+
template <typename Pattern, typename Input>
177+
bool match(const Pattern &pattern, const Input &input) {
178+
return pattern.match(input);
179+
}
180+
181+
// Specific operation patterns
182+
183+
// -- Add
184+
template <typename Type, typename Op0, typename Op1>
185+
struct Add : public Op<evaluate::Add<Type>, Op0, Op1> {
186+
using Base = Op<evaluate::Add<Type>, Op0, Op1>;
187+
188+
Add(const Op0 &op0, const Op1 &op1) : Base(op0, op1) {}
189+
};
190+
191+
template <typename Type, typename Op0, typename Op1>
192+
Add<Type, Op0, Op1> add(const Op0 &op0, const Op1 &op1) {
193+
return Add<Type, Op0, Op1>(op0, op1);
194+
}
195+
196+
// -- Mul
197+
template <typename Type, typename Op0, typename Op1>
198+
struct Mul : public Op<evaluate::Multiply<Type>, Op0, Op1> {
199+
using Base = Op<evaluate::Multiply<Type>, Op0, Op1>;
200+
201+
Mul(const Op0 &op0, const Op1 &op1) : Base(op0, op1) {}
202+
};
203+
204+
template <typename Type, typename Op0, typename Op1>
205+
Mul<Type, Op0, Op1> mul(const Op0 &op0, const Op1 &op1) {
206+
return Mul<Type, Op0, Op1>(op0, op1);
207+
}
208+
} // namespace match
209+
} // namespace Fortran::evaluate
210+
211+
#endif // FORTRAN_EVALUATE_MATCH_H_

0 commit comments

Comments
 (0)