|
| 1 | +//===-- include/flang/Evaluate/match.h --------------------------*- C++ -*-===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +#ifndef FORTRAN_EVALUATE_MATCH_H_ |
| 9 | +#define FORTRAN_EVALUATE_MATCH_H_ |
| 10 | + |
| 11 | +#include "flang/Common/visit.h" |
| 12 | +#include "flang/Evaluate/expression.h" |
| 13 | +#include "llvm/ADT/STLExtras.h" |
| 14 | + |
| 15 | +#include <tuple> |
| 16 | +#include <type_traits> |
| 17 | +#include <utility> |
| 18 | +#include <variant> |
| 19 | + |
| 20 | +namespace Fortran::evaluate { |
| 21 | +namespace match { |
| 22 | +namespace detail { |
| 23 | +template <typename, typename = void> // |
| 24 | +struct IsOperation { |
| 25 | + static constexpr bool value{false}; |
| 26 | +}; |
| 27 | + |
| 28 | +template <typename T> |
| 29 | +struct IsOperation<T, std::void_t<decltype(T::operands)>> { |
| 30 | + static constexpr bool value{true}; |
| 31 | +}; |
| 32 | +} // namespace detail |
| 33 | + |
| 34 | +template <typename T> |
| 35 | +constexpr bool is_operation_v{detail::IsOperation<T>::value}; |
| 36 | + |
| 37 | +template <typename T> |
| 38 | +const evaluate::Expr<T> &deparen(const evaluate::Expr<T> &x) { |
| 39 | + if (auto *parens{std::get_if<evaluate::Parentheses<T>>(&x.u)}) { |
| 40 | + return deparen(parens->template operand<0>()); |
| 41 | + } else { |
| 42 | + return x; |
| 43 | + } |
| 44 | +} |
| 45 | + |
| 46 | +// Expr<T> matchers (patterns) |
| 47 | +// |
| 48 | +// Each pattern should implement |
| 49 | +// bool match(const U &input) const |
| 50 | +// member function that returns `true` when the match was successful, |
| 51 | +// and `false` otherwise. |
| 52 | +// |
| 53 | +// Patterns are intended to be composable, i.e. a pattern can take operands |
| 54 | +// which themselves are patterns. This composition is expected to match if |
| 55 | +// the root pattern and all its operands match given input. |
| 56 | + |
| 57 | +/// Matches any input as long as it has the expected type `MatchType`. |
| 58 | +/// Additionally, it sets the member `ref` to the matched input. |
| 59 | +template <typename T> struct TypePattern { |
| 60 | + using MatchType = llvm::remove_cvref_t<T>; |
| 61 | + |
| 62 | + template <typename U> bool match(const U &input) const { |
| 63 | + if constexpr (std::is_same_v<MatchType, U>) { |
| 64 | + ref = &input; |
| 65 | + return true; |
| 66 | + } else { |
| 67 | + return false; |
| 68 | + } |
| 69 | + } |
| 70 | + |
| 71 | + mutable const MatchType *ref{nullptr}; |
| 72 | +}; |
| 73 | + |
| 74 | +/// Matches one of the patterns provided as template arguments. All of these |
| 75 | +/// patterns should have the same number of operands, i.e. they all should |
| 76 | +/// try to match input expression with the same number of children, i.e. |
| 77 | +/// AnyOfPattern<SomeBinaryOp, OtherBinaryOp> is ok, whereas |
| 78 | +/// AnyOfPattern<SomeBinaryOp, SomeTernaryOp> is not. |
| 79 | +template <typename... Patterns> struct AnyOfPattern { |
| 80 | + static_assert(sizeof...(Patterns) != 0); |
| 81 | + |
| 82 | +private: |
| 83 | + using PatternTuple = std::tuple<Patterns...>; |
| 84 | + |
| 85 | + template <size_t I> |
| 86 | + using Pattern = typename std::tuple_element<I, PatternTuple>::type; |
| 87 | + |
| 88 | + template <size_t... Is, typename... Ops> |
| 89 | + AnyOfPattern(std::index_sequence<Is...>, const Ops &...ops) |
| 90 | + : patterns(std::make_tuple(Pattern<Is>(ops...)...)) {} |
| 91 | + |
| 92 | + template <typename P, typename U> |
| 93 | + bool matchOne(const P &pattern, const U &input) const { |
| 94 | + if (pattern.match(input)) { |
| 95 | + ref = &pattern; |
| 96 | + return true; |
| 97 | + } |
| 98 | + return false; |
| 99 | + } |
| 100 | + |
| 101 | + template <typename U, size_t... Is> |
| 102 | + bool matchImpl(const U &input, std::index_sequence<Is...>) const { |
| 103 | + return (matchOne(std::get<Is>(patterns), input) || ...); |
| 104 | + } |
| 105 | + |
| 106 | + PatternTuple patterns; |
| 107 | + |
| 108 | +public: |
| 109 | + using Indexes = std::index_sequence_for<Patterns...>; |
| 110 | + using MatchTypes = std::tuple<typename Patterns::MatchType...>; |
| 111 | + |
| 112 | + template <typename... Ops> |
| 113 | + AnyOfPattern(const Ops &...ops) : AnyOfPattern(Indexes{}, ops...) {} |
| 114 | + |
| 115 | + template <typename U> bool match(const U &input) const { |
| 116 | + return matchImpl(input, Indexes{}); |
| 117 | + } |
| 118 | + |
| 119 | + mutable std::variant<const Patterns *..., std::monostate> ref{ |
| 120 | + std::monostate{}}; |
| 121 | +}; |
| 122 | + |
| 123 | +/// Matches any input of type Expr<T> |
| 124 | +/// The indent if this pattern is to be a leaf in multi-operand patterns. |
| 125 | +template <typename T> // |
| 126 | +struct ExprPattern : public TypePattern<evaluate::Expr<T>> {}; |
| 127 | + |
| 128 | +/// Matches evaluate::Expr<T> that contains evaluate::Opreration<OpType>. |
| 129 | +template <typename OpType, typename... Ops> |
| 130 | +struct OperationPattern : public TypePattern<OpType> { |
| 131 | +private: |
| 132 | + using Indexes = std::index_sequence_for<Ops...>; |
| 133 | + |
| 134 | + template <typename S, size_t... Is> |
| 135 | + bool matchImpl(const S &op, std::index_sequence<Is...>) const { |
| 136 | + using TypeS = llvm::remove_cvref_t<S>; |
| 137 | + if constexpr (is_operation_v<TypeS>) { |
| 138 | + if constexpr (TypeS::operands == Indexes::size()) { |
| 139 | + return TypePattern<OpType>::match(op) && |
| 140 | + (std::get<Is>(operands).match(op.template operand<Is>()) && ...); |
| 141 | + } |
| 142 | + } |
| 143 | + return false; |
| 144 | + } |
| 145 | + |
| 146 | + std::tuple<const Ops &...> operands; |
| 147 | + |
| 148 | +public: |
| 149 | + using MatchType = OpType; |
| 150 | + |
| 151 | + OperationPattern(const Ops &...ops, llvm::type_identity<OpType> = {}) |
| 152 | + : operands(ops...) {} |
| 153 | + |
| 154 | + template <typename T> bool match(const evaluate::Expr<T> &input) const { |
| 155 | + return common::visit( |
| 156 | + [&](auto &&s) { return matchImpl(s, Indexes{}); }, deparen(input).u); |
| 157 | + } |
| 158 | + |
| 159 | + template <typename U> bool match(const U &input) const { |
| 160 | + // Only match Expr<T> |
| 161 | + return false; |
| 162 | + } |
| 163 | +}; |
| 164 | + |
| 165 | +template <typename OpType, typename... Ops> |
| 166 | +OperationPattern(const Ops &...ops, llvm::type_identity<OpType>) |
| 167 | + -> OperationPattern<OpType, Ops...>; |
| 168 | + |
| 169 | +// Namespace-level definitions |
| 170 | + |
| 171 | +template <typename T> using Expr = ExprPattern<T>; |
| 172 | + |
| 173 | +template <typename OpType, typename... Ops> |
| 174 | +using Op = OperationPattern<OpType, Ops...>; |
| 175 | + |
| 176 | +template <typename Pattern, typename Input> |
| 177 | +bool match(const Pattern &pattern, const Input &input) { |
| 178 | + return pattern.match(input); |
| 179 | +} |
| 180 | + |
| 181 | +// Specific operation patterns |
| 182 | + |
| 183 | +// -- Add |
| 184 | +template <typename Type, typename Op0, typename Op1> |
| 185 | +struct Add : public Op<evaluate::Add<Type>, Op0, Op1> { |
| 186 | + using Base = Op<evaluate::Add<Type>, Op0, Op1>; |
| 187 | + |
| 188 | + Add(const Op0 &op0, const Op1 &op1) : Base(op0, op1) {} |
| 189 | +}; |
| 190 | + |
| 191 | +template <typename Type, typename Op0, typename Op1> |
| 192 | +Add<Type, Op0, Op1> add(const Op0 &op0, const Op1 &op1) { |
| 193 | + return Add<Type, Op0, Op1>(op0, op1); |
| 194 | +} |
| 195 | + |
| 196 | +// -- Mul |
| 197 | +template <typename Type, typename Op0, typename Op1> |
| 198 | +struct Mul : public Op<evaluate::Multiply<Type>, Op0, Op1> { |
| 199 | + using Base = Op<evaluate::Multiply<Type>, Op0, Op1>; |
| 200 | + |
| 201 | + Mul(const Op0 &op0, const Op1 &op1) : Base(op0, op1) {} |
| 202 | +}; |
| 203 | + |
| 204 | +template <typename Type, typename Op0, typename Op1> |
| 205 | +Mul<Type, Op0, Op1> mul(const Op0 &op0, const Op1 &op1) { |
| 206 | + return Mul<Type, Op0, Op1>(op0, op1); |
| 207 | +} |
| 208 | +} // namespace match |
| 209 | +} // namespace Fortran::evaluate |
| 210 | + |
| 211 | +#endif // FORTRAN_EVALUATE_MATCH_H_ |
0 commit comments