Skip to content

Commit 7504683

Browse files
committed
Added handling of function calls for reverse.
1 parent 3898548 commit 7504683

File tree

6 files changed

+795
-549
lines changed

6 files changed

+795
-549
lines changed

regression-tests/pure2-autodiff.cpp2

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,18 @@ ad_test_reverse: @autodiff<"reverse"> @print type = {
244244
sin_call: (x: double, y: double) -> (r: double) = {
245245
r = sin(x - y);
246246
}
247+
248+
func: (x: double, y: double) -> (ret: double) = {
249+
ret = x + y;
250+
}
251+
252+
func_call: (x: double, y: double) -> (r: double) = {
253+
r = x * func(x, y);
254+
}
255+
256+
func_outer_call: (x: double, y: double) -> (r: double) = {
257+
r = x * func_outer(x, y);
258+
}
247259
}
248260
}
249261

@@ -319,6 +331,8 @@ main: () = {
319331
write_output_reverse("x * (x + y)", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::mul_add_b(x, x_b, y, y_b, w_b));
320332
write_output_reverse("x + x * y", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::add_mul_b(x, x_b, y, y_b, w_b));
321333
write_output_reverse("sin(x-y)", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::sin_call_b(x, x_b, y, y_b, w_b));
334+
write_output_reverse("x * func(x-y)", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::func_call_b(x, x_b, y, y_b, w_b));
335+
write_output_reverse("x * func_outer(x-y)", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::func_outer_call_b(x, x_b, y, y_b, w_b));
322336

323337
_ = x_b;
324338
_ = y_b;

regression-tests/test-results/gcc-13-c++2b/pure2-autodiff.cpp.execution

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,6 @@ diff(x * y / x) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = 3.000000,
4141
diff(x * (x + y)) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = 10.000000, x_b = 7.000000, y_b = 2.000000)
4242
diff(x + x * y) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = 8.000000, x_b = 4.000000, y_b = 2.000000)
4343
diff(sin(x-y)) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = -0.841471, x_b = 0.540302, y_b = -0.540302)
44+
diff(x * func(x-y)) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = 10.000000, x_b = 7.000000, y_b = 2.000000)
45+
diff(x * func_outer(x-y)) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = 10.000000, x_b = 7.000000, y_b = 2.000000)
4446
2nd order diff of x*x at 2.000000 = 2.000000

regression-tests/test-results/pure2-autodiff.cpp

Lines changed: 112 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class ad_test;
2121
#line 194 "pure2-autodiff.cpp2"
2222
class ad_test_reverse;
2323

24-
#line 248 "pure2-autodiff.cpp2"
24+
#line 260 "pure2-autodiff.cpp2"
2525
}
2626

2727
class ad_test_twice;
@@ -355,6 +355,11 @@ public: [[nodiscard]] static auto type_outer_call_d(cpp2::impl::in<double> x, cp
355355
#line 192 "pure2-autodiff.cpp2"
356356
};
357357

358+
using func_outer_b_ret = double;
359+
360+
[[nodiscard]] auto func_outer_b(cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& ret_b) -> func_outer_b_ret;
361+
362+
#line 194 "pure2-autodiff.cpp2"
358363
class ad_test_reverse {
359364
using add_1_ret = double;
360365

@@ -421,6 +426,21 @@ using sin_call_ret = double;
421426

422427
#line 244 "pure2-autodiff.cpp2"
423428
public: [[nodiscard]] static auto sin_call(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> sin_call_ret;
429+
using func_ret = double;
430+
431+
432+
#line 248 "pure2-autodiff.cpp2"
433+
public: [[nodiscard]] static auto func(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> func_ret;
434+
using func_call_ret = double;
435+
436+
437+
#line 252 "pure2-autodiff.cpp2"
438+
public: [[nodiscard]] static auto func_call(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> func_call_ret;
439+
using func_outer_call_ret = double;
440+
441+
442+
#line 256 "pure2-autodiff.cpp2"
443+
public: [[nodiscard]] static auto func_outer_call(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> func_outer_call_ret;
424444
using add_1_b_ret = double;
425445

426446
public: [[nodiscard]] static auto add_1_b(cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& r_b) -> add_1_b_ret;
@@ -461,19 +481,28 @@ public: [[nodiscard]] static auto add_mul_b(cpp2::impl::in<double> x, double& x_
461481
using sin_call_b_ret = double;
462482
public: [[nodiscard]] static auto sin_call_b(cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& r_b) -> sin_call_b_ret;
463483

484+
using func_b_ret = double;
485+
public: [[nodiscard]] static auto func_b(cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& ret_b) -> func_b_ret;
486+
487+
using func_call_b_ret = double;
488+
public: [[nodiscard]] static auto func_call_b(cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& r_b) -> func_call_b_ret;
489+
490+
using func_outer_call_b_ret = double;
491+
public: [[nodiscard]] static auto func_outer_call_b(cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& r_b) -> func_outer_call_b_ret;
492+
464493
public: ad_test_reverse() = default;
465494
public: ad_test_reverse(ad_test_reverse const&) = delete; /* No 'that' constructor, suppress copy */
466495
public: auto operator=(ad_test_reverse const&) -> void = delete;
467496

468497

469-
#line 247 "pure2-autodiff.cpp2"
498+
#line 259 "pure2-autodiff.cpp2"
470499
};
471500
}
472501

473502
class ad_test_twice {
474503
using mul_1_ret = double;
475504

476-
#line 251 "pure2-autodiff.cpp2"
505+
#line 263 "pure2-autodiff.cpp2"
477506
public: [[nodiscard]] static auto mul_1(cpp2::impl::in<double> x) -> mul_1_ret;
478507
struct mul_1_d_ret { double r; double r_d; };
479508

@@ -493,15 +522,15 @@ public: [[nodiscard]] static auto mul_1_d_d2(cpp2::impl::in<double> x, cpp2::imp
493522
public: auto operator=(ad_test_twice const&) -> void = delete;
494523

495524

496-
#line 254 "pure2-autodiff.cpp2"
525+
#line 266 "pure2-autodiff.cpp2"
497526
};
498527

499528
auto write_output(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d, auto const& ret) -> void;
500529

501-
#line 260 "pure2-autodiff.cpp2"
530+
#line 272 "pure2-autodiff.cpp2"
502531
auto write_output_reverse(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& r_b, auto const& ret) -> void;
503532

504-
#line 267 "pure2-autodiff.cpp2"
533+
#line 279 "pure2-autodiff.cpp2"
505534
auto main() -> int;
506535

507536
//=== Cpp2 function definitions =================================================
@@ -941,9 +970,9 @@ double temp_1_d {-x_d};
941970
double r_d {0.0};
942971
auto temp_1 {func_d(x, x_d, y, y_d)};
943972

944-
double temp_2_d {temp_1.ret_d};
973+
double temp_2_d {cpp2::move(temp_1).ret_d};
945974

946-
double temp_2 {cpp2::move(temp_1).ret};
975+
double temp_2 {func(x, y)};
947976
r_d = temp_2 * x_d + x * cpp2::move(temp_2_d);
948977
r = x * cpp2::move(temp_2);
949978
return { std::move(r), std::move(r_d) };
@@ -954,9 +983,9 @@ auto temp_1 {func_d(x, x_d, y, y_d)};
954983
double r_d {0.0};
955984
auto temp_1 {func_outer_d(x, x_d, y, y_d)};
956985

957-
double temp_2_d {temp_1.ret_d};
986+
double temp_2_d {cpp2::move(temp_1).ret_d};
958987

959-
double temp_2 {cpp2::move(temp_1).ret};
988+
double temp_2 {func_outer(x, y)};
960989
r_d = temp_2 * x_d + x * cpp2::move(temp_2_d);
961990
r = x * cpp2::move(temp_2);
962991
return { std::move(r), std::move(r_d) };
@@ -1151,12 +1180,19 @@ type_outer_d t_d {};
11511180
t_d.a_d = x_d;
11521181
t.a = x;
11531182

1154-
auto temp_1 {CPP2_UFCS(add_d)(cpp2::move(t), cpp2::move(t_d), y, y_d)};
1155-
r_d = temp_1.r_d;
1156-
r = cpp2::move(temp_1).r;
1183+
auto temp_1 {CPP2_UFCS(add_d)(t, cpp2::move(t_d), y, y_d)};
1184+
r_d = cpp2::move(temp_1).r_d;
1185+
r = CPP2_UFCS(add)(cpp2::move(t), y);
11571186
return { std::move(r), std::move(r_d) };
11581187
}
11591188

1189+
[[nodiscard]] auto func_outer_b(cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& ret_b) -> func_outer_b_ret{
1190+
double ret {0.0};ret = x + y;
1191+
x_b += ret_b;
1192+
y_b += ret_b;
1193+
ret_b = 0.0;
1194+
return ret; }
1195+
11601196
#line 196 "pure2-autodiff.cpp2"
11611197
[[nodiscard]] auto ad_test_reverse::add_1(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> add_1_ret{
11621198
cpp2::impl::deferred_init<double> r;
@@ -1248,6 +1284,27 @@ type_outer_d t_d {};
12481284
r.construct(sin(x - y));
12491285
return std::move(r.value()); }
12501286

1287+
#line 248 "pure2-autodiff.cpp2"
1288+
[[nodiscard]] auto ad_test_reverse::func(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> func_ret{
1289+
cpp2::impl::deferred_init<double> ret;
1290+
#line 249 "pure2-autodiff.cpp2"
1291+
ret.construct(x + y);
1292+
return std::move(ret.value()); }
1293+
1294+
#line 252 "pure2-autodiff.cpp2"
1295+
[[nodiscard]] auto ad_test_reverse::func_call(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> func_call_ret{
1296+
cpp2::impl::deferred_init<double> r;
1297+
#line 253 "pure2-autodiff.cpp2"
1298+
r.construct(x * func(x, y));
1299+
return std::move(r.value()); }
1300+
1301+
#line 256 "pure2-autodiff.cpp2"
1302+
[[nodiscard]] auto ad_test_reverse::func_outer_call(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> func_outer_call_ret{
1303+
cpp2::impl::deferred_init<double> r;
1304+
#line 257 "pure2-autodiff.cpp2"
1305+
r.construct(x * func_outer(x, y));
1306+
return std::move(r.value()); }
1307+
12511308
[[nodiscard]] auto ad_test_reverse::add_1_b(cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& r_b) -> add_1_b_ret{
12521309
double r {0.0};r = x + y;
12531310
x_b += r_b;
@@ -1383,13 +1440,46 @@ double temp_1_b {0.0};
13831440
temp_1_b = 0.0;
13841441
return r; }
13851442

1386-
#line 248 "pure2-autodiff.cpp2"
1443+
[[nodiscard]] auto ad_test_reverse::func_b(cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& ret_b) -> func_b_ret{
1444+
double ret {0.0};ret = x + y;
1445+
x_b += ret_b;
1446+
y_b += ret_b;
1447+
ret_b = 0.0;
1448+
return ret; }
1449+
1450+
[[nodiscard]] auto ad_test_reverse::func_call_b(cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& r_b) -> func_call_b_ret{
1451+
double r {0.0};
1452+
double temp_2_b {0.0};
1453+
1454+
double temp_2 {func(x, y)};
1455+
r = x * temp_2;
1456+
x_b += cpp2::move(temp_2) * r_b;
1457+
temp_2_b += x * r_b;
1458+
r_b = 0.0;
1459+
static_cast<void>(func_b(x, x_b, y, y_b, temp_2_b));
1460+
temp_2_b = 0.0;
1461+
return r; }
1462+
1463+
[[nodiscard]] auto ad_test_reverse::func_outer_call_b(cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& r_b) -> func_outer_call_b_ret{
1464+
double r {0.0};
1465+
double temp_2_b {0.0};
1466+
1467+
double temp_2 {func_outer(x, y)};
1468+
r = x * temp_2;
1469+
x_b += cpp2::move(temp_2) * r_b;
1470+
temp_2_b += x * r_b;
1471+
r_b = 0.0;
1472+
static_cast<void>(func_outer_b(x, x_b, y, y_b, temp_2_b));
1473+
temp_2_b = 0.0;
1474+
return r; }
1475+
1476+
#line 260 "pure2-autodiff.cpp2"
13871477
}
13881478

1389-
#line 251 "pure2-autodiff.cpp2"
1479+
#line 263 "pure2-autodiff.cpp2"
13901480
[[nodiscard]] auto ad_test_twice::mul_1(cpp2::impl::in<double> x) -> mul_1_ret{
13911481
cpp2::impl::deferred_init<double> r;
1392-
#line 252 "pure2-autodiff.cpp2"
1482+
#line 264 "pure2-autodiff.cpp2"
13931483
r.construct(x * x);
13941484
return std::move(r.value()); }
13951485

@@ -1426,20 +1516,20 @@ double temp_1_d2 {x_d * x_d2 + x * x_d_d2};
14261516
return { std::move(r), std::move(r_d2), std::move(r_d), std::move(r_d_d2) };
14271517
}
14281518

1429-
#line 256 "pure2-autodiff.cpp2"
1519+
#line 268 "pure2-autodiff.cpp2"
14301520
auto write_output(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d, auto const& ret) -> void{
14311521
std::cout << "diff(" + cpp2::to_string(func) + ") at (x = " + cpp2::to_string(x) + ", x_d = " + cpp2::to_string(x_d) + ", y = " + cpp2::to_string(y) + ", y_d = " + cpp2::to_string(y_d) + ") = (r = " + cpp2::to_string(ret.r) + ", r_d = " + cpp2::to_string(ret.r_d) + ")" << std::endl;
14321522
}
14331523

1434-
#line 260 "pure2-autodiff.cpp2"
1524+
#line 272 "pure2-autodiff.cpp2"
14351525
auto write_output_reverse(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& r_b, auto const& ret) -> void{
14361526
r_b = 1.0;
14371527
std::cout << "diff(" + cpp2::to_string(func) + ") at (x = " + cpp2::to_string(x) + ", y = " + cpp2::to_string(y) + ", r_b = " + cpp2::to_string(r_b) + ") = (r = " + cpp2::to_string(ret) + ", x_b = " + cpp2::to_string(x_b) + ", y_b = " + cpp2::to_string(y_b) + ")" << std::endl;
14381528
x_b = 0.0;
14391529
y_b = 0.0;
14401530
}
14411531

1442-
#line 267 "pure2-autodiff.cpp2"
1532+
#line 279 "pure2-autodiff.cpp2"
14431533
auto main() -> int{
14441534

14451535
double x {2.0};
@@ -1494,7 +1584,9 @@ auto main() -> int{
14941584
write_output_reverse("x * y / x", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::mul_div_2_b(x, x_b, y, y_b, w_b));
14951585
write_output_reverse("x * (x + y)", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::mul_add_b(x, x_b, y, y_b, w_b));
14961586
write_output_reverse("x + x * y", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::add_mul_b(x, x_b, y, y_b, w_b));
1497-
write_output_reverse("sin(x-y)", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::sin_call_b(x, x_b, cpp2::move(y), y_b, w_b));
1587+
write_output_reverse("sin(x-y)", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::sin_call_b(x, x_b, y, y_b, w_b));
1588+
write_output_reverse("x * func(x-y)", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::func_call_b(x, x_b, y, y_b, w_b));
1589+
write_output_reverse("x * func_outer(x-y)", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::func_outer_call_b(x, x_b, cpp2::move(y), y_b, w_b));
14981590

14991591
static_cast<void>(cpp2::move(x_b));
15001592
static_cast<void>(cpp2::move(y_b));

0 commit comments

Comments
 (0)