Skip to content

Commit 97470ba

Browse files
committed
Handling of expression lists.
1 parent 5169cbf commit 97470ba

File tree

6 files changed

+546
-452
lines changed

6 files changed

+546
-452
lines changed

regression-tests/pure2-autodiff.cpp2

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ ad_test: @autodiff @print type = {
4040
mul_div_2: (x: double, y: double) -> (r: double) = {
4141
r = x * y / x;
4242
}
43+
44+
mul_add: (x: double, y: double) -> (r: double) = {
45+
r = x * (x + y);
46+
}
4347
}
4448

4549
write_output: (func: std::string, x: double, x_d: double, y: double, y_d: double, ret) = {
@@ -63,4 +67,5 @@ main: () = {
6367
write_output("x / y", x, x_d, y, y_d, ad_test::div_1_diff(x, x_d, y, y_d));
6468
write_output("x / y / y", x, x_d, y, y_d, ad_test::div_2_diff(x, x_d, y, y_d));
6569
write_output("x * y / x", x, x_d, y, y_d, ad_test::mul_div_2_diff(x, x_d, y, y_d));
70+
write_output("x * (x + y)", x, x_d, y, y_d, ad_test::mul_add_diff(x, x_d, y, y_d));
6671
}

regression-tests/test-results/gcc-13-c++2b/pure2-autodiff.cpp.execution

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ diff(x * y * x) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000)
88
diff(x / y) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 0.666667, r_d = -0.111111)
99
diff(x / y / y) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 0.222222, r_d = -0.185185)
1010
diff(x * y / x) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 3.000000, r_d = 2.000000)
11+
diff(x * (x + y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 10.000000, r_d = 11.000000)

regression-tests/test-results/pure2-autodiff.cpp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ using mul_div_2_ret = double;
6868

6969
#line 40 "pure2-autodiff.cpp2"
7070
public: [[nodiscard]] static auto mul_div_2(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> mul_div_2_ret;
71+
using mul_add_ret = double;
72+
73+
74+
#line 44 "pure2-autodiff.cpp2"
75+
public: [[nodiscard]] static auto mul_add(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> mul_add_ret;
7176
struct add_1_diff_ret { double r; double r_d; };
7277

7378

@@ -109,17 +114,21 @@ struct mul_div_2_diff_ret { double r; double r_d; };
109114

110115
public: [[nodiscard]] static auto mul_div_2_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> mul_div_2_diff_ret;
111116

117+
struct mul_add_diff_ret { double r; double r_d; };
118+
119+
public: [[nodiscard]] static auto mul_add_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> mul_add_diff_ret;
120+
112121
public: ad_test() = default;
113122
public: ad_test(ad_test const&) = delete; /* No 'that' constructor, suppress copy */
114123
public: auto operator=(ad_test const&) -> void = delete;
115124

116125

117-
#line 43 "pure2-autodiff.cpp2"
126+
#line 47 "pure2-autodiff.cpp2"
118127
};
119128

120129
auto write_output(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d, auto const& ret) -> void;
121130

122-
#line 49 "pure2-autodiff.cpp2"
131+
#line 53 "pure2-autodiff.cpp2"
123132
auto main() -> int;
124133

125134
//=== Cpp2 function definitions =================================================
@@ -196,6 +205,13 @@ auto main() -> int;
196205
r.construct(x * y / CPP2_ASSERT_NOT_ZERO(CPP2_TYPEOF(y),x));
197206
return std::move(r.value()); }
198207

208+
#line 44 "pure2-autodiff.cpp2"
209+
[[nodiscard]] auto ad_test::mul_add(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> mul_add_ret{
210+
cpp2::impl::deferred_init<double> r;
211+
#line 45 "pure2-autodiff.cpp2"
212+
r.construct(x * (x + y));
213+
return std::move(r.value()); }
214+
199215
[[nodiscard]] auto ad_test::add_1_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> add_1_diff_ret{
200216
double r {0.0};
201217
double r_d {0.0};r_d = x_d + y_d;r = x + y;
@@ -252,13 +268,20 @@ auto temp_1_d {x * y_d + y * x_d};
252268
auto temp_1 {x * y}; r_d = cpp2::move(temp_1_d) / CPP2_ASSERT_NOT_ZERO(CPP2_TYPEOF(cpp2::move(temp_1_d)),x) - temp_1 * x_d / CPP2_ASSERT_NOT_ZERO(CPP2_TYPEOF(x_d),(x * x));r = cpp2::move(temp_1) / CPP2_ASSERT_NOT_ZERO(CPP2_TYPEOF(cpp2::move(temp_1)),x);
253269
return { std::move(r), std::move(r_d) };
254270
}
271+
[[nodiscard]] auto ad_test::mul_add_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> mul_add_diff_ret{
272+
double r {0.0};
273+
double r_d {0.0};
274+
auto temp_1_d {x_d + y_d};
275+
auto temp_1 {x + y}; r_d = x * cpp2::move(temp_1_d) + temp_1 * x_d;r = x * cpp2::move(temp_1);
276+
return { std::move(r), std::move(r_d) };
277+
}
255278

256-
#line 45 "pure2-autodiff.cpp2"
279+
#line 49 "pure2-autodiff.cpp2"
257280
auto write_output(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d, auto const& ret) -> void{
258281
std::cout << "diff(" + cpp2::to_string(func) + ") at (x = " + cpp2::to_string(x) + ", x_d = " + cpp2::to_string(x_d) + ", y = " + cpp2::to_string(y) + ", y_d = " + cpp2::to_string(y_d) + ") = (r = " + cpp2::to_string(ret.r) + ", r_d = " + cpp2::to_string(ret.r_d) + ")" << std::endl;
259282
}
260283

261-
#line 49 "pure2-autodiff.cpp2"
284+
#line 53 "pure2-autodiff.cpp2"
262285
auto main() -> int{
263286

264287
double x {2.0};
@@ -275,6 +298,7 @@ auto main() -> int{
275298
write_output("x * y * x", x, x_d, y, y_d, ad_test::mul_2_diff(x, x_d, y, y_d));
276299
write_output("x / y", x, x_d, y, y_d, ad_test::div_1_diff(x, x_d, y, y_d));
277300
write_output("x / y / y", x, x_d, y, y_d, ad_test::div_2_diff(x, x_d, y, y_d));
278-
write_output("x * y / x", x, x_d, y, y_d, ad_test::mul_div_2_diff(cpp2::move(x), cpp2::move(x_d), cpp2::move(y), cpp2::move(y_d)));
301+
write_output("x * y / x", x, x_d, y, y_d, ad_test::mul_div_2_diff(x, x_d, y, y_d));
302+
write_output("x * (x + y)", x, x_d, y, y_d, ad_test::mul_add_diff(cpp2::move(x), cpp2::move(x_d), cpp2::move(y), cpp2::move(y_d)));
279303
}
280304

regression-tests/test-results/pure2-autodiff.cpp2.output

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,15 @@ ad_test:/* @autodiff @print */ type =
9292
return;
9393
}
9494

95+
mul_add:(
96+
in x: double,
97+
in y: double,
98+
) -> (out r: double, ) =
99+
{
100+
r = x * (x + y);
101+
return;
102+
}
103+
95104
add_1_diff:(
96105
in x: double,
97106
in x_d: double,
@@ -247,6 +256,23 @@ ad_test:/* @autodiff @print */ type =
247256
r = temp_1 / x;
248257
return;
249258
}
259+
260+
mul_add_diff:(
261+
in x: double,
262+
in x_d: double,
263+
in y: double,
264+
in y_d: double,
265+
) -> (
266+
out r: double = 0.0,
267+
out r_d: double = 0.0,
268+
) =
269+
{
270+
temp_1_d: _ = x_d + y_d;
271+
temp_1: _ = x + y;
272+
r_d = x * temp_1_d + temp_1 * x_d;
273+
r = x * temp_1;
274+
return;
275+
}
250276
}
251277
ok (all Cpp2, passes safety checks)
252278

0 commit comments

Comments
 (0)