diff --git a/flang/include/flang/Evaluate/match.h b/flang/include/flang/Evaluate/match.h new file mode 100644 index 0000000000000..79da40f7c1338 --- /dev/null +++ b/flang/include/flang/Evaluate/match.h @@ -0,0 +1,211 @@ +//===-- include/flang/Evaluate/match.h --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef FORTRAN_EVALUATE_MATCH_H_ +#define FORTRAN_EVALUATE_MATCH_H_ + +#include "flang/Common/visit.h" +#include "flang/Evaluate/expression.h" +#include "llvm/ADT/STLExtras.h" + +#include +#include +#include +#include + +namespace Fortran::evaluate { +namespace match { +namespace detail { +template // +struct IsOperation { + static constexpr bool value{false}; +}; + +template +struct IsOperation> { + static constexpr bool value{true}; +}; +} // namespace detail + +template +constexpr bool is_operation_v{detail::IsOperation::value}; + +template +const evaluate::Expr &deparen(const evaluate::Expr &x) { + if (auto *parens{std::get_if>(&x.u)}) { + return deparen(parens->template operand<0>()); + } else { + return x; + } +} + +// Expr matchers (patterns) +// +// Each pattern should implement +// bool match(const U &input) const +// member function that returns `true` when the match was successful, +// and `false` otherwise. +// +// Patterns are intended to be composable, i.e. a pattern can take operands +// which themselves are patterns. This composition is expected to match if +// the root pattern and all its operands match given input. + +/// Matches any input as long as it has the expected type `MatchType`. +/// Additionally, it sets the member `ref` to the matched input. +template struct TypePattern { + using MatchType = llvm::remove_cvref_t; + + template bool match(const U &input) const { + if constexpr (std::is_same_v) { + ref = &input; + return true; + } else { + return false; + } + } + + mutable const MatchType *ref{nullptr}; +}; + +/// Matches one of the patterns provided as template arguments. All of these +/// patterns should have the same number of operands, i.e. they all should +/// try to match input expression with the same number of children, i.e. +/// AnyOfPattern is ok, whereas +/// AnyOfPattern is not. +template struct AnyOfPattern { + static_assert(sizeof...(Patterns) != 0); + +private: + using PatternTuple = std::tuple; + + template + using Pattern = typename std::tuple_element::type; + + template + AnyOfPattern(std::index_sequence, const Ops &...ops) + : patterns(std::make_tuple(Pattern(ops...)...)) {} + + template + bool matchOne(const P &pattern, const U &input) const { + if (pattern.match(input)) { + ref = &pattern; + return true; + } + return false; + } + + template + bool matchImpl(const U &input, std::index_sequence) const { + return (matchOne(std::get(patterns), input) || ...); + } + + PatternTuple patterns; + +public: + using Indexes = std::index_sequence_for; + using MatchTypes = std::tuple; + + template + AnyOfPattern(const Ops &...ops) : AnyOfPattern(Indexes{}, ops...) {} + + template bool match(const U &input) const { + return matchImpl(input, Indexes{}); + } + + mutable std::variant ref{ + std::monostate{}}; +}; + +/// Matches any input of type Expr +/// The indent if this pattern is to be a leaf in multi-operand patterns. +template // +struct ExprPattern : public TypePattern> {}; + +/// Matches evaluate::Expr that contains evaluate::Opreration. +template +struct OperationPattern : public TypePattern { +private: + using Indexes = std::index_sequence_for; + + template + bool matchImpl(const S &op, std::index_sequence) const { + using TypeS = llvm::remove_cvref_t; + if constexpr (is_operation_v) { + if constexpr (TypeS::operands == Indexes::size()) { + return TypePattern::match(op) && + (std::get(operands).match(op.template operand()) && ...); + } + } + return false; + } + + std::tuple operands; + +public: + using MatchType = OpType; + + OperationPattern(const Ops &...ops, llvm::type_identity = {}) + : operands(ops...) {} + + template bool match(const evaluate::Expr &input) const { + return common::visit( + [&](auto &&s) { return matchImpl(s, Indexes{}); }, deparen(input).u); + } + + template bool match(const U &input) const { + // Only match Expr + return false; + } +}; + +template +OperationPattern(const Ops &...ops, llvm::type_identity) + -> OperationPattern; + +// Namespace-level definitions + +template using Expr = ExprPattern; + +template +using Op = OperationPattern; + +template +bool match(const Pattern &pattern, const Input &input) { + return pattern.match(input); +} + +// Specific operation patterns + +// -- Add +template +struct Add : public Op, Op0, Op1> { + using Base = Op, Op0, Op1>; + + Add(const Op0 &op0, const Op1 &op1) : Base(op0, op1) {} +}; + +template +Add add(const Op0 &op0, const Op1 &op1) { + return Add(op0, op1); +} + +// -- Mul +template +struct Mul : public Op, Op0, Op1> { + using Base = Op, Op0, Op1>; + + Mul(const Op0 &op0, const Op1 &op1) : Base(op0, op1) {} +}; + +template +Mul mul(const Op0 &op0, const Op1 &op1) { + return Mul(op0, op1); +} +} // namespace match +} // namespace Fortran::evaluate + +#endif // FORTRAN_EVALUATE_MATCH_H_