Skip to content

Commit 3898548

Browse files
committed
Handling of special functions for reverse mode.
1 parent 7b3b9c8 commit 3898548

File tree

6 files changed

+858
-721
lines changed

6 files changed

+858
-721
lines changed

regression-tests/pure2-autodiff.cpp2

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,10 @@ ad_test_reverse: @autodiff<"reverse"> @print type = {
240240
add_mul: (x: double, y: double) -> (r: double) = {
241241
r = x + x * y;
242242
}
243+
244+
sin_call: (x: double, y: double) -> (r: double) = {
245+
r = sin(x - y);
246+
}
243247
}
244248
}
245249

@@ -314,6 +318,7 @@ main: () = {
314318
write_output_reverse("x * y / x", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::mul_div_2_b(x, x_b, y, y_b, w_b));
315319
write_output_reverse("x * (x + y)", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::mul_add_b(x, x_b, y, y_b, w_b));
316320
write_output_reverse("x + x * y", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::add_mul_b(x, x_b, y, y_b, w_b));
321+
write_output_reverse("sin(x-y)", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::sin_call_b(x, x_b, y, y_b, w_b));
317322

318323
_ = x_b;
319324
_ = y_b;

regression-tests/test-results/gcc-13-c++2b/pure2-autodiff.cpp.execution

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,5 @@ diff(x / y / y) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = 0.222222,
4040
diff(x * y / x) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = 3.000000, x_b = 0.000000, y_b = 1.000000)
4141
diff(x * (x + y)) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = 10.000000, x_b = 7.000000, y_b = 2.000000)
4242
diff(x + x * y) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = 8.000000, x_b = 4.000000, y_b = 2.000000)
43+
diff(sin(x-y)) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = -0.841471, x_b = 0.540302, y_b = -0.540302)
4344
2nd order diff of x*x at 2.000000 = 2.000000

regression-tests/test-results/pure2-autodiff.cpp

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class ad_test;
2121
#line 194 "pure2-autodiff.cpp2"
2222
class ad_test_reverse;
2323

24-
#line 244 "pure2-autodiff.cpp2"
24+
#line 248 "pure2-autodiff.cpp2"
2525
}
2626

2727
class ad_test_twice;
@@ -416,6 +416,11 @@ using add_mul_ret = double;
416416

417417
#line 240 "pure2-autodiff.cpp2"
418418
public: [[nodiscard]] static auto add_mul(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> add_mul_ret;
419+
using sin_call_ret = double;
420+
421+
422+
#line 244 "pure2-autodiff.cpp2"
423+
public: [[nodiscard]] static auto sin_call(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> sin_call_ret;
419424
using add_1_b_ret = double;
420425

421426
public: [[nodiscard]] static auto add_1_b(cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& r_b) -> add_1_b_ret;
@@ -453,19 +458,22 @@ public: [[nodiscard]] static auto mul_add_b(cpp2::impl::in<double> x, double& x_
453458
using add_mul_b_ret = double;
454459
public: [[nodiscard]] static auto add_mul_b(cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& r_b) -> add_mul_b_ret;
455460

461+
using sin_call_b_ret = double;
462+
public: [[nodiscard]] static auto sin_call_b(cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& r_b) -> sin_call_b_ret;
463+
456464
public: ad_test_reverse() = default;
457465
public: ad_test_reverse(ad_test_reverse const&) = delete; /* No 'that' constructor, suppress copy */
458466
public: auto operator=(ad_test_reverse const&) -> void = delete;
459467

460468

461-
#line 243 "pure2-autodiff.cpp2"
469+
#line 247 "pure2-autodiff.cpp2"
462470
};
463471
}
464472

465473
class ad_test_twice {
466474
using mul_1_ret = double;
467475

468-
#line 247 "pure2-autodiff.cpp2"
476+
#line 251 "pure2-autodiff.cpp2"
469477
public: [[nodiscard]] static auto mul_1(cpp2::impl::in<double> x) -> mul_1_ret;
470478
struct mul_1_d_ret { double r; double r_d; };
471479

@@ -485,15 +493,15 @@ public: [[nodiscard]] static auto mul_1_d_d2(cpp2::impl::in<double> x, cpp2::imp
485493
public: auto operator=(ad_test_twice const&) -> void = delete;
486494

487495

488-
#line 250 "pure2-autodiff.cpp2"
496+
#line 254 "pure2-autodiff.cpp2"
489497
};
490498

491499
auto write_output(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d, auto const& ret) -> void;
492500

493-
#line 256 "pure2-autodiff.cpp2"
501+
#line 260 "pure2-autodiff.cpp2"
494502
auto write_output_reverse(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& r_b, auto const& ret) -> void;
495503

496-
#line 263 "pure2-autodiff.cpp2"
504+
#line 267 "pure2-autodiff.cpp2"
497505
auto main() -> int;
498506

499507
//=== Cpp2 function definitions =================================================
@@ -1233,6 +1241,13 @@ type_outer_d t_d {};
12331241
r.construct(x + x * y);
12341242
return std::move(r.value()); }
12351243

1244+
#line 244 "pure2-autodiff.cpp2"
1245+
[[nodiscard]] auto ad_test_reverse::sin_call(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> sin_call_ret{
1246+
cpp2::impl::deferred_init<double> r;
1247+
#line 245 "pure2-autodiff.cpp2"
1248+
r.construct(sin(x - y));
1249+
return std::move(r.value()); }
1250+
12361251
[[nodiscard]] auto ad_test_reverse::add_1_b(cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& r_b) -> add_1_b_ret{
12371252
double r {0.0};r = x + y;
12381253
x_b += r_b;
@@ -1355,13 +1370,26 @@ double temp_1_b {0.0};
13551370
temp_1_b = 0.0;
13561371
return r; }
13571372

1358-
#line 244 "pure2-autodiff.cpp2"
1373+
[[nodiscard]] auto ad_test_reverse::sin_call_b(cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& r_b) -> sin_call_b_ret{
1374+
double r {0.0};
1375+
double temp_1_b {0.0};
1376+
1377+
double temp_1 {x - y};
1378+
r = sin(temp_1);
1379+
temp_1_b += cos(cpp2::move(temp_1)) * r_b;
1380+
r_b = 0.0;
1381+
x_b += temp_1_b;
1382+
y_b -= temp_1_b;
1383+
temp_1_b = 0.0;
1384+
return r; }
1385+
1386+
#line 248 "pure2-autodiff.cpp2"
13591387
}
13601388

1361-
#line 247 "pure2-autodiff.cpp2"
1389+
#line 251 "pure2-autodiff.cpp2"
13621390
[[nodiscard]] auto ad_test_twice::mul_1(cpp2::impl::in<double> x) -> mul_1_ret{
13631391
cpp2::impl::deferred_init<double> r;
1364-
#line 248 "pure2-autodiff.cpp2"
1392+
#line 252 "pure2-autodiff.cpp2"
13651393
r.construct(x * x);
13661394
return std::move(r.value()); }
13671395

@@ -1398,20 +1426,20 @@ double temp_1_d2 {x_d * x_d2 + x * x_d_d2};
13981426
return { std::move(r), std::move(r_d2), std::move(r_d), std::move(r_d_d2) };
13991427
}
14001428

1401-
#line 252 "pure2-autodiff.cpp2"
1429+
#line 256 "pure2-autodiff.cpp2"
14021430
auto write_output(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d, auto const& ret) -> void{
14031431
std::cout << "diff(" + cpp2::to_string(func) + ") at (x = " + cpp2::to_string(x) + ", x_d = " + cpp2::to_string(x_d) + ", y = " + cpp2::to_string(y) + ", y_d = " + cpp2::to_string(y_d) + ") = (r = " + cpp2::to_string(ret.r) + ", r_d = " + cpp2::to_string(ret.r_d) + ")" << std::endl;
14041432
}
14051433

1406-
#line 256 "pure2-autodiff.cpp2"
1434+
#line 260 "pure2-autodiff.cpp2"
14071435
auto write_output_reverse(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, double& x_b, cpp2::impl::in<double> y, double& y_b, double& r_b, auto const& ret) -> void{
14081436
r_b = 1.0;
14091437
std::cout << "diff(" + cpp2::to_string(func) + ") at (x = " + cpp2::to_string(x) + ", y = " + cpp2::to_string(y) + ", r_b = " + cpp2::to_string(r_b) + ") = (r = " + cpp2::to_string(ret) + ", x_b = " + cpp2::to_string(x_b) + ", y_b = " + cpp2::to_string(y_b) + ")" << std::endl;
14101438
x_b = 0.0;
14111439
y_b = 0.0;
14121440
}
14131441

1414-
#line 263 "pure2-autodiff.cpp2"
1442+
#line 267 "pure2-autodiff.cpp2"
14151443
auto main() -> int{
14161444

14171445
double x {2.0};
@@ -1465,7 +1493,8 @@ auto main() -> int{
14651493
write_output_reverse("x / y / y", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::div_2_b(x, x_b, y, y_b, w_b));
14661494
write_output_reverse("x * y / x", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::mul_div_2_b(x, x_b, y, y_b, w_b));
14671495
write_output_reverse("x * (x + y)", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::mul_add_b(x, x_b, y, y_b, w_b));
1468-
write_output_reverse("x + x * y", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::add_mul_b(x, x_b, cpp2::move(y), y_b, w_b));
1496+
write_output_reverse("x + x * y", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::add_mul_b(x, x_b, y, y_b, w_b));
1497+
write_output_reverse("sin(x-y)", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::sin_call_b(x, x_b, cpp2::move(y), y_b, w_b));
14691498

14701499
static_cast<void>(cpp2::move(x_b));
14711500
static_cast<void>(cpp2::move(y_b));

regression-tests/test-results/pure2-autodiff.cpp2.output

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,15 @@ ad_test_reverse:/* @autodiff<"reverse"> @print */ type =
10011001
return;
10021002
}
10031003

1004+
sin_call:(
1005+
in x: double,
1006+
in y: double,
1007+
) -> (out r: double, ) =
1008+
{
1009+
r = sin(x - y);
1010+
return;
1011+
}
1012+
10041013
add_1_b:(
10051014
in x: double,
10061015
inout x_b: double,
@@ -1208,6 +1217,25 @@ ad_test_reverse:/* @autodiff<"reverse"> @print */ type =
12081217
temp_1_b = 0.0;
12091218
return;
12101219
}
1220+
1221+
sin_call_b:(
1222+
in x: double,
1223+
inout x_b: double,
1224+
in y: double,
1225+
inout y_b: double,
1226+
inout r_b: double,
1227+
) -> (out r: double = 0.0, ) =
1228+
{
1229+
temp_1_b: double = 0.0;
1230+
temp_1: double = x - y;
1231+
r = sin(temp_1);
1232+
temp_1_b += cos(temp_1) * r_b;
1233+
r_b = 0.0;
1234+
x_b += temp_1_b;
1235+
y_b -= temp_1_b;
1236+
temp_1_b = 0.0;
1237+
return;
1238+
}
12111239
}
12121240

12131241

0 commit comments

Comments
 (0)