Skip to content

Commit 358226d

Browse files
committed
Added special handling of math functions.
1 parent b3dc845 commit 358226d

File tree

6 files changed

+578
-451
lines changed

6 files changed

+578
-451
lines changed

regression-tests/pure2-autodiff.cpp2

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ ad_test: @autodiff @print type = {
5656
func_call: (x: double, y: double) -> (r: double) = {
5757
r = x * func(x, y);
5858
}
59+
60+
sin_call: (x: double, y: double) -> (r: double) = {
61+
r = sin(x - y);
62+
}
5963
}
6064

6165
write_output: (func: std::string, x: double, x_d: double, y: double, y_d: double, ret) = {
@@ -82,4 +86,5 @@ main: () = {
8286
write_output("x * (x + y)", x, x_d, y, y_d, ad_test::mul_add_diff(x, x_d, y, y_d));
8387
write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_diff(x, x_d, y, y_d));
8488
write_output("x * func(x, y)", x, x_d, y, y_d, ad_test::func_call_diff(x, x_d, y, y_d));
89+
write_output("sin(x + y)", x, x_d, y, y_d, ad_test::sin_call_diff(x, x_d, y, y_d));
8590
}

regression-tests/test-results/gcc-13-c++2b/pure2-autodiff.cpp.execution

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ diff(x * y / x) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000)
1111
diff(x * (x + y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 10.000000, r_d = 11.000000)
1212
diff(x + x * y) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 8.000000, r_d = 8.000000)
1313
diff(x * func(x, y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 10.000000, r_d = 11.000000)
14+
diff(sin(x + y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = -0.841471, r_d = -0.540302)

regression-tests/test-results/pure2-autodiff.cpp

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@ using func_call_ret = double;
8888

8989
#line 56 "pure2-autodiff.cpp2"
9090
public: [[nodiscard]] static auto func_call(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> func_call_ret;
91+
using sin_call_ret = double;
92+
93+
94+
#line 60 "pure2-autodiff.cpp2"
95+
public: [[nodiscard]] static auto sin_call(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> sin_call_ret;
9196
struct add_1_diff_ret { double r; double r_d; };
9297

9398

@@ -145,17 +150,21 @@ struct func_call_diff_ret { double r; double r_d; };
145150

146151
public: [[nodiscard]] static auto func_call_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> func_call_diff_ret;
147152

153+
struct sin_call_diff_ret { double r; double r_d; };
154+
155+
public: [[nodiscard]] static auto sin_call_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> sin_call_diff_ret;
156+
148157
public: ad_test() = default;
149158
public: ad_test(ad_test const&) = delete; /* No 'that' constructor, suppress copy */
150159
public: auto operator=(ad_test const&) -> void = delete;
151160

152161

153-
#line 59 "pure2-autodiff.cpp2"
162+
#line 63 "pure2-autodiff.cpp2"
154163
};
155164

156165
auto write_output(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d, auto const& ret) -> void;
157166

158-
#line 65 "pure2-autodiff.cpp2"
167+
#line 69 "pure2-autodiff.cpp2"
159168
auto main() -> int;
160169

161170
//=== Cpp2 function definitions =================================================
@@ -260,6 +269,13 @@ auto main() -> int;
260269
r.construct(x * func(x, y));
261270
return std::move(r.value()); }
262271

272+
#line 60 "pure2-autodiff.cpp2"
273+
[[nodiscard]] auto ad_test::sin_call(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> sin_call_ret{
274+
cpp2::impl::deferred_init<double> r;
275+
#line 61 "pure2-autodiff.cpp2"
276+
r.construct(sin(x - y));
277+
return std::move(r.value()); }
278+
263279
[[nodiscard]] auto ad_test::add_1_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> add_1_diff_ret{
264280
double r {0.0};
265281
double r_d {0.0};r_d = x_d + y_d;r = x + y;
@@ -347,12 +363,20 @@ auto temp_2 {func_diff(x, x_d, y, y_d)};
347363
return { std::move(r), std::move(r_d) };
348364
}
349365

350-
#line 61 "pure2-autodiff.cpp2"
366+
[[nodiscard]] auto ad_test::sin_call_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> sin_call_diff_ret{
367+
double r {0.0};
368+
double r_d {0.0};
369+
auto temp_1_d {x_d - y_d};
370+
auto temp_1 {x - y}; r_d = cos(temp_1) * cpp2::move(temp_1_d);
371+
r = sin(cpp2::move(temp_1));
372+
return { std::move(r), std::move(r_d) }; }
373+
374+
#line 65 "pure2-autodiff.cpp2"
351375
auto write_output(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d, auto const& ret) -> void{
352376
std::cout << "diff(" + cpp2::to_string(func) + ") at (x = " + cpp2::to_string(x) + ", x_d = " + cpp2::to_string(x_d) + ", y = " + cpp2::to_string(y) + ", y_d = " + cpp2::to_string(y_d) + ") = (r = " + cpp2::to_string(ret.r) + ", r_d = " + cpp2::to_string(ret.r_d) + ")" << std::endl;
353377
}
354378

355-
#line 65 "pure2-autodiff.cpp2"
379+
#line 69 "pure2-autodiff.cpp2"
356380
auto main() -> int{
357381

358382
double x {2.0};
@@ -372,6 +396,7 @@ auto main() -> int{
372396
write_output("x * y / x", x, x_d, y, y_d, ad_test::mul_div_2_diff(x, x_d, y, y_d));
373397
write_output("x * (x + y)", x, x_d, y, y_d, ad_test::mul_add_diff(x, x_d, y, y_d));
374398
write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_diff(x, x_d, y, y_d));
375-
write_output("x * func(x, y)", x, x_d, y, y_d, ad_test::func_call_diff(cpp2::move(x), cpp2::move(x_d), cpp2::move(y), cpp2::move(y_d)));
399+
write_output("x * func(x, y)", x, x_d, y, y_d, ad_test::func_call_diff(x, x_d, y, y_d));
400+
write_output("sin(x + y)", x, x_d, y, y_d, ad_test::sin_call_diff(cpp2::move(x), cpp2::move(x_d), cpp2::move(y), cpp2::move(y_d)));
376401
}
377402

regression-tests/test-results/pure2-autodiff.cpp2.output

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,15 @@ ad_test:/* @autodiff @print */ type =
128128
return;
129129
}
130130

131+
sin_call:(
132+
in x: double,
133+
in y: double,
134+
) -> (out r: double, ) =
135+
{
136+
r = sin(x - y);
137+
return;
138+
}
139+
131140
add_1_diff:(
132141
in x: double,
133142
in x_d: double,
@@ -350,6 +359,23 @@ ad_test:/* @autodiff @print */ type =
350359
r = x * temp_1;
351360
return;
352361
}
362+
363+
sin_call_diff:(
364+
in x: double,
365+
in x_d: double,
366+
in y: double,
367+
in y_d: double,
368+
) -> (
369+
out r: double = 0.0,
370+
out r_d: double = 0.0,
371+
) =
372+
{
373+
temp_1_d: _ = x_d - y_d;
374+
temp_1: _ = x - y;
375+
r_d = cos(temp_1) * temp_1_d;
376+
r = sin(temp_1);
377+
return;
378+
}
353379
}
354380
ok (all Cpp2, passes safety checks)
355381

0 commit comments

Comments
 (0)