diff --git a/flang/include/flang/Evaluate/rewrite.h b/flang/include/flang/Evaluate/rewrite.h new file mode 100644 index 0000000000000..50259cc0959f4 --- /dev/null +++ b/flang/include/flang/Evaluate/rewrite.h @@ -0,0 +1,160 @@ +//===-- include/flang/Evaluate/rewrite.h ------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef FORTRAN_EVALUATE_REWRITE_H_ +#define FORTRAN_EVALUATE_REWRITE_H_ + +#include "flang/Common/visit.h" +#include "flang/Evaluate/expression.h" +#include "flang/Support/Fortran.h" +#include "llvm/ADT/STLExtras.h" + +#include +#include +#include +#include + +namespace Fortran::evaluate { +namespace rewrite { +namespace detail { +template // +struct IsOperation { + static constexpr bool value{false}; +}; + +template +struct IsOperation> { + static constexpr bool value{true}; +}; +} // namespace detail + +template +constexpr bool is_operation_v{detail::IsOperation::value}; + +/// Individual Expr rewriter that simply constructs an expression that is +/// identical to the input. This is a suitable base class for all user-defined +/// rewriters. +struct Identity { + template + Expr operator()(Expr &&x, const U &op) { + return std::move(x); + } +}; + +/// Bottom-up Expr rewriter. +/// +/// The Mutator traverses and reconstructs given Expr. Going bottom-up, +/// whenever the traversal visits a sub-node of type Expr (for some U), +/// it will invoke the user-provided rewriter via the () operator. +/// +/// If x is of type Expr, it will call (in pseudo-code): +/// rewriter_(x, active_member_of(x.u)) +/// The second parameter is there to make it easier to overload the () operator +/// for specific operations in Expr<...>. +/// +/// The user rewriter is only invoked for Expr, not for Operation, nor any +/// other subobject. +template struct Mutator { + Mutator(Rewriter &rewriter) : rewriter_(rewriter) {} + + template > + U operator()(T &&x) { + if constexpr (std::is_lvalue_reference_v) { + return Mutate(U(x)); + } else { + return Mutate(std::move(x)); + } + } + +private: + template struct LambdaWithRvalueCapture { + LambdaWithRvalueCapture(Rewriter &r, Expr &&c) + : rewriter_(r), capture_(std::move(c)) {} + template Expr operator()(const S &s) { + return rewriter_(std::move(capture_), s); + } + + private: + Rewriter &rewriter_; + Expr &&capture_; + }; + + template >> + T Mutate(T &&x) const { + return std::move(x); + } + + template >> + D Mutate(D &&op, std::make_index_sequence t = {}) const { + return MutateOp(std::move(op), t); + } + + template // + Expr Mutate(Expr &&x) const { + // First construct the new expression with the rewritten op. + Expr n{common::visit( + [&](auto &&s) { // + return Expr(Mutate(std::move(s))); + }, + std::move(x.u))}; + // Return the rewritten expression. The second visit is to make sure + // that the second argument in the call to the rewriter is a part of + // the Expr passed to it. + return common::visit( + LambdaWithRvalueCapture(rewriter_, std::move(n)), std::move(n.u)); + } + + template + std::variant Mutate(std::variant &&u) const { + return common::visit( + [this](auto &&s) { return Mutate(std::move(s)); }, std::move(u)); + } + + template + std::tuple Mutate(std::tuple &&t) const { + return MutateTuple(std::move(t), std::index_sequence_for{}); + } + + template + std::tuple MutateTuple( + std::tuple &&t, std::index_sequence) const { + return std::make_tuple(Mutate(std::move(std::get(t))...)); + } + + template + D MutateOp(D &&op, std::index_sequence) const { + return D(Mutate(std::move(op.template operand()))...); + } + + template + Extremum MutateOp(Extremum &&op, std::index_sequence) const { + return Extremum( + op.ordering, Mutate(std::move(op.template operand()))...); + } + + template + ComplexComponent MutateOp( + ComplexComponent &&op, std::index_sequence) const { + return ComplexComponent( + op.isImaginaryPart, Mutate(std::move(op.template operand()))...); + } + + template + LogicalOperation MutateOp( + LogicalOperation &&op, std::index_sequence) const { + return LogicalOperation( + op.logicalOperator, Mutate(std::move(op.template operand()))...); + } + + Rewriter &rewriter_; +}; + +template Mutator(Rewriter &) -> Mutator; +} // namespace rewrite +} // namespace Fortran::evaluate + +#endif // FORTRAN_EVALUATE_REWRITE_H_