Skip to content

Commit b3dc845

Browse files
committed
Handling of function calls.
1 parent 965af06 commit b3dc845

File tree

6 files changed

+842
-608
lines changed

6 files changed

+842
-608
lines changed

regression-tests/pure2-autodiff.cpp2

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,14 @@ ad_test: @autodiff @print type = {
4848
add_mul: (x: double, y: double) -> (r: double) = {
4949
r = x + x * y;
5050
}
51+
52+
func: (x: double, y: double) -> (r: double) = {
53+
r = x + y;
54+
}
55+
56+
func_call: (x: double, y: double) -> (r: double) = {
57+
r = x * func(x, y);
58+
}
5159
}
5260

5361
write_output: (func: std::string, x: double, x_d: double, y: double, y_d: double, ret) = {
@@ -73,4 +81,5 @@ main: () = {
7381
write_output("x * y / x", x, x_d, y, y_d, ad_test::mul_div_2_diff(x, x_d, y, y_d));
7482
write_output("x * (x + y)", x, x_d, y, y_d, ad_test::mul_add_diff(x, x_d, y, y_d));
7583
write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_diff(x, x_d, y, y_d));
84+
write_output("x * func(x, y)", x, x_d, y, y_d, ad_test::func_call_diff(x, x_d, y, y_d));
7685
}

regression-tests/test-results/gcc-13-c++2b/pure2-autodiff.cpp.execution

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ diff(x / y / y) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000)
1010
diff(x * y / x) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 3.000000, r_d = 2.000000)
1111
diff(x * (x + y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 10.000000, r_d = 11.000000)
1212
diff(x + x * y) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 8.000000, r_d = 8.000000)
13+
diff(x * func(x, y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 10.000000, r_d = 11.000000)

regression-tests/test-results/pure2-autodiff.cpp

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,16 @@ using add_mul_ret = double;
7878

7979
#line 48 "pure2-autodiff.cpp2"
8080
public: [[nodiscard]] static auto add_mul(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> add_mul_ret;
81+
using func_ret = double;
82+
83+
84+
#line 52 "pure2-autodiff.cpp2"
85+
public: [[nodiscard]] static auto func(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> func_ret;
86+
using func_call_ret = double;
87+
88+
89+
#line 56 "pure2-autodiff.cpp2"
90+
public: [[nodiscard]] static auto func_call(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> func_call_ret;
8191
struct add_1_diff_ret { double r; double r_d; };
8292

8393

@@ -127,17 +137,25 @@ struct add_mul_diff_ret { double r; double r_d; };
127137

128138
public: [[nodiscard]] static auto add_mul_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> add_mul_diff_ret;
129139

140+
struct func_diff_ret { double r; double r_d; };
141+
142+
public: [[nodiscard]] static auto func_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> func_diff_ret;
143+
144+
struct func_call_diff_ret { double r; double r_d; };
145+
146+
public: [[nodiscard]] static auto func_call_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> func_call_diff_ret;
147+
130148
public: ad_test() = default;
131149
public: ad_test(ad_test const&) = delete; /* No 'that' constructor, suppress copy */
132150
public: auto operator=(ad_test const&) -> void = delete;
133151

134152

135-
#line 51 "pure2-autodiff.cpp2"
153+
#line 59 "pure2-autodiff.cpp2"
136154
};
137155

138156
auto write_output(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d, auto const& ret) -> void;
139157

140-
#line 57 "pure2-autodiff.cpp2"
158+
#line 65 "pure2-autodiff.cpp2"
141159
auto main() -> int;
142160

143161
//=== Cpp2 function definitions =================================================
@@ -228,6 +246,20 @@ auto main() -> int;
228246
r.construct(x + x * y);
229247
return std::move(r.value()); }
230248

249+
#line 52 "pure2-autodiff.cpp2"
250+
[[nodiscard]] auto ad_test::func(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> func_ret{
251+
cpp2::impl::deferred_init<double> r;
252+
#line 53 "pure2-autodiff.cpp2"
253+
r.construct(x + y);
254+
return std::move(r.value()); }
255+
256+
#line 56 "pure2-autodiff.cpp2"
257+
[[nodiscard]] auto ad_test::func_call(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> func_call_ret{
258+
cpp2::impl::deferred_init<double> r;
259+
#line 57 "pure2-autodiff.cpp2"
260+
r.construct(x * func(x, y));
261+
return std::move(r.value()); }
262+
231263
[[nodiscard]] auto ad_test::add_1_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> add_1_diff_ret{
232264
double r {0.0};
233265
double r_d {0.0};r_d = x_d + y_d;r = x + y;
@@ -298,13 +330,29 @@ auto temp_1_d {x * y_d + y * x_d};
298330
auto temp_1 {x * y}; r_d = x_d + cpp2::move(temp_1_d);r = x + cpp2::move(temp_1);
299331
return { std::move(r), std::move(r_d) };
300332
}
333+
[[nodiscard]] auto ad_test::func_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> func_diff_ret{
334+
double r {0.0};
335+
double r_d {0.0};r_d = x_d + y_d;r = x + y;
336+
return { std::move(r), std::move(r_d) };
337+
}
338+
[[nodiscard]] auto ad_test::func_call_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> func_call_diff_ret{
339+
double r {0.0};
340+
double r_d {0.0};
341+
auto temp_2 {func_diff(x, x_d, y, y_d)};
342+
343+
auto temp_1 {temp_2.r};
344+
345+
auto temp_1_d {cpp2::move(temp_2).r_d};
346+
r_d = x * cpp2::move(temp_1_d) + temp_1 * x_d;r = x * cpp2::move(temp_1);
347+
return { std::move(r), std::move(r_d) };
348+
}
301349

302-
#line 53 "pure2-autodiff.cpp2"
350+
#line 61 "pure2-autodiff.cpp2"
303351
auto write_output(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d, auto const& ret) -> void{
304352
std::cout << "diff(" + cpp2::to_string(func) + ") at (x = " + cpp2::to_string(x) + ", x_d = " + cpp2::to_string(x_d) + ", y = " + cpp2::to_string(y) + ", y_d = " + cpp2::to_string(y_d) + ") = (r = " + cpp2::to_string(ret.r) + ", r_d = " + cpp2::to_string(ret.r_d) + ")" << std::endl;
305353
}
306354

307-
#line 57 "pure2-autodiff.cpp2"
355+
#line 65 "pure2-autodiff.cpp2"
308356
auto main() -> int{
309357

310358
double x {2.0};
@@ -323,6 +371,7 @@ auto main() -> int{
323371
write_output("x / y / y", x, x_d, y, y_d, ad_test::div_2_diff(x, x_d, y, y_d));
324372
write_output("x * y / x", x, x_d, y, y_d, ad_test::mul_div_2_diff(x, x_d, y, y_d));
325373
write_output("x * (x + y)", x, x_d, y, y_d, ad_test::mul_add_diff(x, x_d, y, y_d));
326-
write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_diff(cpp2::move(x), cpp2::move(x_d), cpp2::move(y), cpp2::move(y_d)));
374+
write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_diff(x, x_d, y, y_d));
375+
write_output("x * func(x, y)", x, x_d, y, y_d, ad_test::func_call_diff(cpp2::move(x), cpp2::move(x_d), cpp2::move(y), cpp2::move(y_d)));
327376
}
328377

regression-tests/test-results/pure2-autodiff.cpp2.output

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,24 @@ ad_test:/* @autodiff @print */ type =
110110
return;
111111
}
112112

113+
func:(
114+
in x: double,
115+
in y: double,
116+
) -> (out r: double, ) =
117+
{
118+
r = x + y;
119+
return;
120+
}
121+
122+
func_call:(
123+
in x: double,
124+
in y: double,
125+
) -> (out r: double, ) =
126+
{
127+
r = x * func(x, y);
128+
return;
129+
}
130+
113131
add_1_diff:(
114132
in x: double,
115133
in x_d: double,
@@ -299,6 +317,39 @@ ad_test:/* @autodiff @print */ type =
299317
r = x + temp_1;
300318
return;
301319
}
320+
321+
func_diff:(
322+
in x: double,
323+
in x_d: double,
324+
in y: double,
325+
in y_d: double,
326+
) -> (
327+
out r: double = 0.0,
328+
out r_d: double = 0.0,
329+
) =
330+
{
331+
r_d = x_d + y_d;
332+
r = x + y;
333+
return;
334+
}
335+
336+
func_call_diff:(
337+
in x: double,
338+
in x_d: double,
339+
in y: double,
340+
in y_d: double,
341+
) -> (
342+
out r: double = 0.0,
343+
out r_d: double = 0.0,
344+
) =
345+
{
346+
temp_2: _ = func_diff(x, x_d, y, y_d);
347+
temp_1: _ = temp_2.r;
348+
temp_1_d: _ = temp_2.r_d;
349+
r_d = x * temp_1_d + temp_1 * x_d;
350+
r = x * temp_1;
351+
return;
352+
}
302353
}
303354
ok (all Cpp2, passes safety checks)
304355

0 commit comments

Comments
 (0)