Skip to content

Commit f0471bc

Browse files
kparzysztblah
andauthored
[flang][Evaluate] Implement rewriting framework for evaluate::Expr (#153037)
The structure of evaluate::Expr is highly customized for the specific operation or entity that it represents. The different cases are expressed with different types, which makes the traversal and modifications somewhat complicated. There exists a framework for read-only traversal (traverse.h), but there is nothing that helps with modifying evaluate::Expr. It's rare that evaluate::Expr needs to be modified, but for the cases where it needs to be, this code will make it easier. --------- Co-authored-by: Tom Eccles <[email protected]>
1 parent 3746bd2 commit f0471bc

File tree

1 file changed

+160
-0
lines changed

1 file changed

+160
-0
lines changed
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
//===-- include/flang/Evaluate/rewrite.h ------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef FORTRAN_EVALUATE_REWRITE_H_
9+
#define FORTRAN_EVALUATE_REWRITE_H_
10+
11+
#include "flang/Common/visit.h"
12+
#include "flang/Evaluate/expression.h"
13+
#include "flang/Support/Fortran.h"
14+
#include "llvm/ADT/STLExtras.h"
15+
16+
#include <tuple>
17+
#include <type_traits>
18+
#include <utility>
19+
#include <variant>
20+
21+
namespace Fortran::evaluate {
22+
namespace rewrite {
23+
namespace detail {
24+
template <typename, typename = void> //
25+
struct IsOperation {
26+
static constexpr bool value{false};
27+
};
28+
29+
template <typename T>
30+
struct IsOperation<T, std::void_t<decltype(T::operands)>> {
31+
static constexpr bool value{true};
32+
};
33+
} // namespace detail
34+
35+
template <typename T>
36+
constexpr bool is_operation_v{detail::IsOperation<T>::value};
37+
38+
/// Individual Expr<T> rewriter that simply constructs an expression that is
39+
/// identical to the input. This is a suitable base class for all user-defined
40+
/// rewriters.
41+
struct Identity {
42+
template <typename T, typename U>
43+
Expr<T> operator()(Expr<T> &&x, const U &op) {
44+
return std::move(x);
45+
}
46+
};
47+
48+
/// Bottom-up Expr<T> rewriter.
49+
///
50+
/// The Mutator traverses and reconstructs given Expr<T>. Going bottom-up,
51+
/// whenever the traversal visits a sub-node of type Expr<U> (for some U),
52+
/// it will invoke the user-provided rewriter via the () operator.
53+
///
54+
/// If x is of type Expr<U>, it will call (in pseudo-code):
55+
/// rewriter_(x, active_member_of(x.u))
56+
/// The second parameter is there to make it easier to overload the () operator
57+
/// for specific operations in Expr<...>.
58+
///
59+
/// The user rewriter is only invoked for Expr<U>, not for Operation, nor any
60+
/// other subobject.
61+
template <typename Rewriter> struct Mutator {
62+
Mutator(Rewriter &rewriter) : rewriter_(rewriter) {}
63+
64+
template <typename T, typename U = llvm::remove_cvref_t<T>>
65+
U operator()(T &&x) {
66+
if constexpr (std::is_lvalue_reference_v<T>) {
67+
return Mutate(U(x));
68+
} else {
69+
return Mutate(std::move(x));
70+
}
71+
}
72+
73+
private:
74+
template <typename T> struct LambdaWithRvalueCapture {
75+
LambdaWithRvalueCapture(Rewriter &r, Expr<T> &&c)
76+
: rewriter_(r), capture_(std::move(c)) {}
77+
template <typename S> Expr<T> operator()(const S &s) {
78+
return rewriter_(std::move(capture_), s);
79+
}
80+
81+
private:
82+
Rewriter &rewriter_;
83+
Expr<T> &&capture_;
84+
};
85+
86+
template <typename T, typename = std::enable_if_t<!is_operation_v<T>>>
87+
T Mutate(T &&x) const {
88+
return std::move(x);
89+
}
90+
91+
template <typename D, typename = std::enable_if_t<is_operation_v<D>>>
92+
D Mutate(D &&op, std::make_index_sequence<D::operands> t = {}) const {
93+
return MutateOp(std::move(op), t);
94+
}
95+
96+
template <typename T> //
97+
Expr<T> Mutate(Expr<T> &&x) const {
98+
// First construct the new expression with the rewritten op.
99+
Expr<T> n{common::visit(
100+
[&](auto &&s) { //
101+
return Expr<T>(Mutate(std::move(s)));
102+
},
103+
std::move(x.u))};
104+
// Return the rewritten expression. The second visit is to make sure
105+
// that the second argument in the call to the rewriter is a part of
106+
// the Expr<T> passed to it.
107+
return common::visit(
108+
LambdaWithRvalueCapture<T>(rewriter_, std::move(n)), std::move(n.u));
109+
}
110+
111+
template <typename... Ts>
112+
std::variant<Ts...> Mutate(std::variant<Ts...> &&u) const {
113+
return common::visit(
114+
[this](auto &&s) { return Mutate(std::move(s)); }, std::move(u));
115+
}
116+
117+
template <typename... Ts>
118+
std::tuple<Ts...> Mutate(std::tuple<Ts...> &&t) const {
119+
return MutateTuple(std::move(t), std::index_sequence_for<Ts...>{});
120+
}
121+
122+
template <typename... Ts, size_t... Is>
123+
std::tuple<Ts...> MutateTuple(
124+
std::tuple<Ts...> &&t, std::index_sequence<Is...>) const {
125+
return std::make_tuple(Mutate(std::move(std::get<Is>(t))...));
126+
}
127+
128+
template <typename D, size_t... Is>
129+
D MutateOp(D &&op, std::index_sequence<Is...>) const {
130+
return D(Mutate(std::move(op.template operand<Is>()))...);
131+
}
132+
133+
template <typename T, size_t... Is>
134+
Extremum<T> MutateOp(Extremum<T> &&op, std::index_sequence<Is...>) const {
135+
return Extremum<T>(
136+
op.ordering, Mutate(std::move(op.template operand<Is>()))...);
137+
}
138+
139+
template <int K, size_t... Is>
140+
ComplexComponent<K> MutateOp(
141+
ComplexComponent<K> &&op, std::index_sequence<Is...>) const {
142+
return ComplexComponent<K>(
143+
op.isImaginaryPart, Mutate(std::move(op.template operand<Is>()))...);
144+
}
145+
146+
template <int K, size_t... Is>
147+
LogicalOperation<K> MutateOp(
148+
LogicalOperation<K> &&op, std::index_sequence<Is...>) const {
149+
return LogicalOperation<K>(
150+
op.logicalOperator, Mutate(std::move(op.template operand<Is>()))...);
151+
}
152+
153+
Rewriter &rewriter_;
154+
};
155+
156+
template <typename Rewriter> Mutator(Rewriter &) -> Mutator<Rewriter>;
157+
} // namespace rewrite
158+
} // namespace Fortran::evaluate
159+
160+
#endif // FORTRAN_EVALUATE_REWRITE_H_

0 commit comments

Comments
 (0)