Skip to content
211 changes: 211 additions & 0 deletions flang/include/flang/Evaluate/match.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
//===-- include/flang/Evaluate/match.h --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef FORTRAN_EVALUATE_MATCH_H_
#define FORTRAN_EVALUATE_MATCH_H_

#include "flang/Common/visit.h"
#include "flang/Evaluate/expression.h"
#include "llvm/ADT/STLExtras.h"

#include <tuple>
#include <type_traits>
#include <utility>
#include <variant>

namespace Fortran::evaluate {
namespace match {
namespace detail {
template <typename, typename = void> //
struct IsOperation {
static constexpr bool value{false};
};

template <typename T>
struct IsOperation<T, std::void_t<decltype(T::operands)>> {
static constexpr bool value{true};
};
} // namespace detail

template <typename T>
constexpr bool is_operation_v{detail::IsOperation<T>::value};

template <typename T>
const evaluate::Expr<T> &deparen(const evaluate::Expr<T> &x) {
if (auto *parens{std::get_if<evaluate::Parentheses<T>>(&x.u)}) {
return deparen(parens->template operand<0>());
} else {
return x;
}
}

// Expr<T> matchers (patterns)
//
// Each pattern should implement
// bool match(const U &input) const
// member function that returns `true` when the match was successful,
// and `false` otherwise.
//
// Patterns are intended to be composable, i.e. a pattern can take operands
// which themselves are patterns. This composition is expected to match if
// the root pattern and all its operands match given input.

/// Matches any input as long as it has the expected type `MatchType`.
/// Additionally, it sets the member `ref` to the matched input.
template <typename T> struct TypePattern {
using MatchType = llvm::remove_cvref_t<T>;

template <typename U> bool match(const U &input) const {
if constexpr (std::is_same_v<MatchType, U>) {
ref = &input;
return true;
} else {
return false;
}
}

mutable const MatchType *ref{nullptr};
};

/// Matches one of the patterns provided as template arguments. All of these
/// patterns should have the same number of operands, i.e. they all should
/// try to match input expression with the same number of children, i.e.
/// AnyOfPattern<SomeBinaryOp, OtherBinaryOp> is ok, whereas
/// AnyOfPattern<SomeBinaryOp, SomeTernaryOp> is not.
template <typename... Patterns> struct AnyOfPattern {
static_assert(sizeof...(Patterns) != 0);

private:
using PatternTuple = std::tuple<Patterns...>;

template <size_t I>
using Pattern = typename std::tuple_element<I, PatternTuple>::type;

template <size_t... Is, typename... Ops>
AnyOfPattern(std::index_sequence<Is...>, const Ops &...ops)
: patterns(std::make_tuple(Pattern<Is>(ops...)...)) {}

template <typename P, typename U>
bool matchOne(const P &pattern, const U &input) const {
if (pattern.match(input)) {
ref = &pattern;
return true;
}
return false;
}

template <typename U, size_t... Is>
bool matchImpl(const U &input, std::index_sequence<Is...>) const {
return (matchOne(std::get<Is>(patterns), input) || ...);
}

PatternTuple patterns;

public:
using Indexes = std::index_sequence_for<Patterns...>;
using MatchTypes = std::tuple<typename Patterns::MatchType...>;

template <typename... Ops>
AnyOfPattern(const Ops &...ops) : AnyOfPattern(Indexes{}, ops...) {}

template <typename U> bool match(const U &input) const {
return matchImpl(input, Indexes{});
}

mutable std::variant<const Patterns *..., std::monostate> ref{
std::monostate{}};
};

/// Matches any input of type Expr<T>
/// The indent if this pattern is to be a leaf in multi-operand patterns.
template <typename T> //
struct ExprPattern : public TypePattern<evaluate::Expr<T>> {};

/// Matches evaluate::Expr<T> that contains evaluate::Opreration<OpType>.
template <typename OpType, typename... Ops>
struct OperationPattern : public TypePattern<OpType> {
private:
using Indexes = std::index_sequence_for<Ops...>;

template <typename S, size_t... Is>
bool matchImpl(const S &op, std::index_sequence<Is...>) const {
using TypeS = llvm::remove_cvref_t<S>;
if constexpr (is_operation_v<TypeS>) {
if constexpr (TypeS::operands == Indexes::size()) {
return TypePattern<OpType>::match(op) &&
(std::get<Is>(operands).match(op.template operand<Is>()) && ...);
}
}
return false;
}

std::tuple<const Ops &...> operands;

public:
using MatchType = OpType;

OperationPattern(const Ops &...ops, llvm::type_identity<OpType> = {})
: operands(ops...) {}

template <typename T> bool match(const evaluate::Expr<T> &input) const {
return common::visit(
[&](auto &&s) { return matchImpl(s, Indexes{}); }, deparen(input).u);
}

template <typename U> bool match(const U &input) const {
// Only match Expr<T>
return false;
}
};

template <typename OpType, typename... Ops>
OperationPattern(const Ops &...ops, llvm::type_identity<OpType>)
-> OperationPattern<OpType, Ops...>;

// Namespace-level definitions

template <typename T> using Expr = ExprPattern<T>;

template <typename OpType, typename... Ops>
using Op = OperationPattern<OpType, Ops...>;

template <typename Pattern, typename Input>
bool match(const Pattern &pattern, const Input &input) {
return pattern.match(input);
}

// Specific operation patterns

// -- Add
template <typename Type, typename Op0, typename Op1>
struct Add : public Op<evaluate::Add<Type>, Op0, Op1> {
using Base = Op<evaluate::Add<Type>, Op0, Op1>;

Add(const Op0 &op0, const Op1 &op1) : Base(op0, op1) {}
};

template <typename Type, typename Op0, typename Op1>
Add<Type, Op0, Op1> add(const Op0 &op0, const Op1 &op1) {
return Add<Type, Op0, Op1>(op0, op1);
}

// -- Mul
template <typename Type, typename Op0, typename Op1>
struct Mul : public Op<evaluate::Multiply<Type>, Op0, Op1> {
using Base = Op<evaluate::Multiply<Type>, Op0, Op1>;

Mul(const Op0 &op0, const Op1 &op1) : Base(op0, op1) {}
};

template <typename Type, typename Op0, typename Op1>
Mul<Type, Op0, Op1> mul(const Op0 &op0, const Op1 &op1) {
return Mul<Type, Op0, Op1>(op0, op1);
}
} // namespace match
} // namespace Fortran::evaluate

#endif // FORTRAN_EVALUATE_MATCH_H_
160 changes: 160 additions & 0 deletions flang/include/flang/Evaluate/rewrite.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
//===-- include/flang/Evaluate/rewrite.h ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef FORTRAN_EVALUATE_REWRITE_H_
#define FORTRAN_EVALUATE_REWRITE_H_

#include "flang/Common/visit.h"
#include "flang/Evaluate/expression.h"
#include "flang/Support/Fortran.h"
#include "llvm/ADT/STLExtras.h"

#include <tuple>
#include <type_traits>
#include <utility>
#include <variant>

namespace Fortran::evaluate {
namespace rewrite {
namespace detail {
template <typename, typename = void> //
struct IsOperation {
static constexpr bool value{false};
};

template <typename T>
struct IsOperation<T, std::void_t<decltype(T::operands)>> {
static constexpr bool value{true};
};
} // namespace detail

template <typename T>
constexpr bool is_operation_v{detail::IsOperation<T>::value};

/// Individual Expr<T> rewriter that simply constructs an expression that is
/// identical to the input. This is a suitable base class for all user-defined
/// rewriters.
struct Identity {
template <typename T, typename U>
Expr<T> operator()(Expr<T> &&x, const U &op) {
return std::move(x);
}
};

/// Bottom-up Expr<T> rewriter.
///
/// The Mutator traverses and reconstructs given Expr<T>. Going bottom-up,
/// whenever the traversal visits a sub-node of type Expr<U> (for some U),
/// it will invoke the user-provided rewriter via the () operator.
///
/// If x is of type Expr<U>, it will call (in pseudo-code):
/// rewriter_(x, active_member_of(x.u))
/// The second parameter is there to make it easier to overload the () operator
/// for specific operations in Expr<...>.
///
/// The user rewriter is only invoked for Expr<U>, not for Operation, nor any
/// other subobject.
template <typename Rewriter> struct Mutator {
Mutator(Rewriter &rewriter) : rewriter_(rewriter) {}

template <typename T, typename U = llvm::remove_cvref_t<T>>
U operator()(T &&x) {
if constexpr (std::is_lvalue_reference_v<T>) {
return Mutate(U(x));
} else {
return Mutate(std::move(x));
}
}

private:
template <typename T> struct LambdaWithRvalueCapture {
LambdaWithRvalueCapture(Rewriter &r, Expr<T> &&c)
: rewriter_(r), capture_(std::move(c)) {}
template <typename S> Expr<T> operator()(const S &s) {
return rewriter_(std::move(capture_), s);
}

private:
Rewriter &rewriter_;
Expr<T> &&capture_;
};

template <typename T, typename = std::enable_if_t<!is_operation_v<T>>>
T Mutate(T &&x) const {
return std::move(x);
}

template <typename D, typename = std::enable_if_t<is_operation_v<D>>>
D Mutate(D &&op, std::make_index_sequence<D::operands> t = {}) const {
return MutateOp(std::move(op), t);
}

template <typename T> //
Expr<T> Mutate(Expr<T> &&x) const {
// First construct the new expression with the rewritten op.
Expr<T> n{common::visit(
[&](auto &&s) { //
return Expr<T>(Mutate(std::move(s)));
},
std::move(x.u))};
// Return the rewritten expression. The second visit it to make sure
// that the second argument in the call to the rewriter is a part of
// the Expr<T> passed to it.
return common::visit(
LambdaWithRvalueCapture<T>(rewriter_, std::move(n)), std::move(n.u));
}

template <typename... Ts>
std::variant<Ts...> Mutate(std::variant<Ts...> &&u) const {
return common::visit(
[this](auto &&s) { return Mutate(std::move(s)); }, std::move(u));
}

template <typename... Ts>
std::tuple<Ts...> Mutate(std::tuple<Ts...> &&t) const {
return MutateTuple(std::move(t), std::index_sequence_for<Ts...>{});
}

template <typename... Ts, size_t... Is>
std::tuple<Ts...> MutateTuple(
std::tuple<Ts...> &&t, std::index_sequence<Is...>) const {
return std::make_tuple(Mutate(std::move(std::get<Is>(t))...));
}

template <typename D, size_t... Is>
D MutateOp(D &&op, std::index_sequence<Is...>) const {
return D(Mutate(std::move(op.template operand<Is>()))...);
}

template <typename T, size_t... Is>
Extremum<T> MutateOp(Extremum<T> &&op, std::index_sequence<Is...>) const {
return Extremum<T>(
op.ordering, Mutate(std::move(op.template operand<Is>()))...);
}

template <int K, size_t... Is>
ComplexComponent<K> MutateOp(
ComplexComponent<K> &&op, std::index_sequence<Is...>) const {
return ComplexComponent<K>(
op.isImaginaryPart, Mutate(std::move(op.template operand<Is>()))...);
}

template <int K, size_t... Is>
LogicalOperation<K> MutateOp(
LogicalOperation<K> &&op, std::index_sequence<Is...>) const {
return LogicalOperation<K>(
op.logicalOperator, Mutate(std::move(op.template operand<Is>()))...);
}

Rewriter &rewriter_;
};

template <typename Rewriter> Mutator(Rewriter &) -> Mutator<Rewriter>;
} // namespace rewrite
} // namespace Fortran::evaluate

#endif // FORTRAN_EVALUATE_REWRITE_H_
8 changes: 8 additions & 0 deletions flang/include/flang/Semantics/openmp-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,21 @@

#include <optional>
#include <string>
#include <type_traits>
#include <utility>

namespace Fortran::semantics {
class SemanticsContext;
class Symbol;

// Add this namespace to avoid potential conflicts
namespace omp {
template <typename T, typename U = std::remove_const_t<T>> U AsRvalue(T &t) {
return U(t);
}

template <typename T> T &&AsRvalue(T &&t) { return std::move(t); }

// There is no consistent way to get the source of an ActionStmt, but there
// is "source" in Statement<T>. This structure keeps the ActionStmt with the
// extracted source for further use.
Expand Down
Loading