Skip to content

Commit 05cff29

Browse files
committed
Higher order handling for special functions.
1 parent 8b5c1e9 commit 05cff29

File tree

9 files changed

+805
-580
lines changed

9 files changed

+805
-580
lines changed

regression-tests/pure2-autodiff-higher-order.cpp2

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,18 @@ ad_test: @autodiff<"order=6"> @print type = {
5151
add_mul: (x: double, y: double) -> (r: double) = {
5252
r = x + x * y;
5353
}
54+
55+
func: (x: double, y: double) -> (r: double) = {
56+
r = x + y;
57+
}
58+
59+
func_call: (x: double, y: double) -> (r: double) = {
60+
r = x * func(x, y);
61+
}
62+
63+
sin_call: (x: double, y: double) -> (r: double) = {
64+
r = sin(x - y);
65+
}
5466
}
5567

5668
write_output: (func: std::string, x: double, x_d: ad_type, y: double, y_d: ad_type, ret) = {
@@ -82,8 +94,8 @@ main: () = {
8294
write_output("x * y / x", x, x_d, y, y_d, ad_test::mul_div_2_d(x, x_d, y, y_d));
8395
write_output("x * (x + y)", x, x_d, y, y_d, ad_test::mul_add_d(x, x_d, y, y_d));
8496
write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_d(x, x_d, y, y_d));
85-
// write_output("x * func(x, y)", x, x_d, y, y_d, ad_test::func_call_d(x, x_d, y, y_d));
86-
// write_output("sin(x + y)", x, x_d, y, y_d, ad_test::sin_call_d(x, x_d, y, y_d));
97+
write_output("x * func(x, y)", x, x_d, y, y_d, ad_test::func_call_d(x, x_d, y, y_d));
98+
write_output("sin(x - y)", x, x_d, y, y_d, ad_test::sin_call_d(x, x_d, y, y_d));
8799
// write_output("if branch", x, x_d, y, y_d, ad_test::if_branch_d(x, x_d, y, y_d));
88100
// write_output("if else branch", x, x_d, y, y_d, ad_test::if_else_branch_d(x, x_d, y, y_d));
89101
// write_output("direct return", x, x_d, y, y_d, ad_test::direct_return_d(x, x_d, y, y_d));

regression-tests/pure2-autodiff.cpp2

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ main: () = {
182182
write_output("x * (x + y)", x, x_d, y, y_d, ad_test::mul_add_d(x, x_d, y, y_d));
183183
write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_d(x, x_d, y, y_d));
184184
write_output("x * func(x, y)", x, x_d, y, y_d, ad_test::func_call_d(x, x_d, y, y_d));
185-
write_output("sin(x + y)", x, x_d, y, y_d, ad_test::sin_call_d(x, x_d, y, y_d));
185+
write_output("sin(x - y)", x, x_d, y, y_d, ad_test::sin_call_d(x, x_d, y, y_d));
186186
write_output("if branch", x, x_d, y, y_d, ad_test::if_branch_d(x, x_d, y, y_d));
187187
write_output("if else branch", x, x_d, y, y_d, ad_test::if_else_branch_d(x, x_d, y, y_d));
188188
write_output("direct return", x, x_d, y, y_d, ad_test::direct_return_d(x, x_d, y, y_d));

regression-tests/test-results/gcc-13-c++2b/pure2-autodiff-higher-order.cpp.execution

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,19 @@ diff(x + x * y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.
9494
d4 = 0.000000
9595
d5 = 0.000000
9696
d6 = 0.000000
97+
diff(x * func(x, y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
98+
r = 10.000000
99+
d1 = 11.000000
100+
d2 = 6.000000
101+
d3 = 0.000000
102+
d4 = 0.000000
103+
d5 = 0.000000
104+
d6 = 0.000000
105+
diff(sin(x - y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
106+
r = -0.841471
107+
d1 = -0.540302
108+
d2 = 0.841471
109+
d3 = 0.540302
110+
d4 = -0.841471
111+
d5 = -0.540302
112+
d6 = 0.841471

regression-tests/test-results/gcc-13-c++2b/pure2-autodiff.cpp.execution

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ diff(x * y / x) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000)
1111
diff(x * (x + y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 10.000000, r_d = 11.000000)
1212
diff(x + x * y) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 8.000000, r_d = 8.000000)
1313
diff(x * func(x, y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 10.000000, r_d = 11.000000)
14-
diff(sin(x + y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = -0.841471, r_d = -0.540302)
14+
diff(sin(x - y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = -0.841471, r_d = -0.540302)
1515
diff(if branch) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 2.000000, r_d = 1.000000)
1616
diff(if else branch) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 2.000000, r_d = 1.000000)
1717
diff(direct return) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 5.000000, r_d = 3.000000)

regression-tests/test-results/pure2-autodiff-higher-order.cpp

Lines changed: 89 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,21 @@ using add_mul_ret = double;
8282

8383
#line 51 "pure2-autodiff-higher-order.cpp2"
8484
public: [[nodiscard]] static auto add_mul(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> add_mul_ret;
85+
using func_ret = double;
86+
87+
88+
#line 55 "pure2-autodiff-higher-order.cpp2"
89+
public: [[nodiscard]] static auto func(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> func_ret;
90+
using func_call_ret = double;
91+
92+
93+
#line 59 "pure2-autodiff-higher-order.cpp2"
94+
public: [[nodiscard]] static auto func_call(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> func_call_ret;
95+
using sin_call_ret = double;
96+
97+
98+
#line 63 "pure2-autodiff-higher-order.cpp2"
99+
public: [[nodiscard]] static auto sin_call(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> sin_call_ret;
85100
struct add_1_d_ret { double r; cpp2::taylor<double,6> r_d; };
86101

87102

@@ -131,17 +146,29 @@ struct add_mul_d_ret { double r; cpp2::taylor<double,6> r_d; };
131146

132147
public: [[nodiscard]] static auto add_mul_d(cpp2::impl::in<double> x, cpp2::impl::in<cpp2::taylor<double,6>> x_d, cpp2::impl::in<double> y, cpp2::impl::in<cpp2::taylor<double,6>> y_d) -> add_mul_d_ret;
133148

149+
struct func_d_ret { double r; cpp2::taylor<double,6> r_d; };
150+
151+
public: [[nodiscard]] static auto func_d(cpp2::impl::in<double> x, cpp2::impl::in<cpp2::taylor<double,6>> x_d, cpp2::impl::in<double> y, cpp2::impl::in<cpp2::taylor<double,6>> y_d) -> func_d_ret;
152+
153+
struct func_call_d_ret { double r; cpp2::taylor<double,6> r_d; };
154+
155+
public: [[nodiscard]] static auto func_call_d(cpp2::impl::in<double> x, cpp2::impl::in<cpp2::taylor<double,6>> x_d, cpp2::impl::in<double> y, cpp2::impl::in<cpp2::taylor<double,6>> y_d) -> func_call_d_ret;
156+
157+
struct sin_call_d_ret { double r; cpp2::taylor<double,6> r_d; };
158+
159+
public: [[nodiscard]] static auto sin_call_d(cpp2::impl::in<double> x, cpp2::impl::in<cpp2::taylor<double,6>> x_d, cpp2::impl::in<double> y, cpp2::impl::in<cpp2::taylor<double,6>> y_d) -> sin_call_d_ret;
160+
134161
public: ad_test() = default;
135162
public: ad_test(ad_test const&) = delete; /* No 'that' constructor, suppress copy */
136163
public: auto operator=(ad_test const&) -> void = delete;
137164

138165

139-
#line 54 "pure2-autodiff-higher-order.cpp2"
166+
#line 66 "pure2-autodiff-higher-order.cpp2"
140167
};
141168

142169
auto write_output(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, cpp2::impl::in<ad_type> x_d, cpp2::impl::in<double> y, cpp2::impl::in<ad_type> y_d, auto const& ret) -> void;
143170

144-
#line 65 "pure2-autodiff-higher-order.cpp2"
171+
#line 77 "pure2-autodiff-higher-order.cpp2"
145172
auto main() -> int;
146173

147174
//=== Cpp2 function definitions =================================================
@@ -232,6 +259,27 @@ auto main() -> int;
232259
r.construct(x + x * y);
233260
return std::move(r.value()); }
234261

262+
#line 55 "pure2-autodiff-higher-order.cpp2"
263+
[[nodiscard]] auto ad_test::func(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> func_ret{
264+
cpp2::impl::deferred_init<double> r;
265+
#line 56 "pure2-autodiff-higher-order.cpp2"
266+
r.construct(x + y);
267+
return std::move(r.value()); }
268+
269+
#line 59 "pure2-autodiff-higher-order.cpp2"
270+
[[nodiscard]] auto ad_test::func_call(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> func_call_ret{
271+
cpp2::impl::deferred_init<double> r;
272+
#line 60 "pure2-autodiff-higher-order.cpp2"
273+
r.construct(x * func(x, y));
274+
return std::move(r.value()); }
275+
276+
#line 63 "pure2-autodiff-higher-order.cpp2"
277+
[[nodiscard]] auto ad_test::sin_call(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> sin_call_ret{
278+
cpp2::impl::deferred_init<double> r;
279+
#line 64 "pure2-autodiff-higher-order.cpp2"
280+
r.construct(sin(x - y));
281+
return std::move(r.value()); }
282+
235283
[[nodiscard]] auto ad_test::add_1_d(cpp2::impl::in<double> x, cpp2::impl::in<cpp2::taylor<double,6>> x_d, cpp2::impl::in<double> y, cpp2::impl::in<cpp2::taylor<double,6>> y_d) -> add_1_d_ret{
236284
double r {0.0};
237285
cpp2::taylor<double,6> r_d {0.0};r_d = x_d + y_d;
@@ -330,25 +378,56 @@ auto temp_1_d {CPP2_UFCS(mul)(x_d, y_d, x, y)};
330378
return { std::move(r), std::move(r_d) };
331379
}
332380

333-
#line 56 "pure2-autodiff-higher-order.cpp2"
381+
[[nodiscard]] auto ad_test::func_d(cpp2::impl::in<double> x, cpp2::impl::in<cpp2::taylor<double,6>> x_d, cpp2::impl::in<double> y, cpp2::impl::in<cpp2::taylor<double,6>> y_d) -> func_d_ret{
382+
double r {0.0};
383+
cpp2::taylor<double,6> r_d {0.0};r_d = x_d + y_d;
384+
r = x + y;
385+
return { std::move(r), std::move(r_d) };
386+
}
387+
388+
[[nodiscard]] auto ad_test::func_call_d(cpp2::impl::in<double> x, cpp2::impl::in<cpp2::taylor<double,6>> x_d, cpp2::impl::in<double> y, cpp2::impl::in<cpp2::taylor<double,6>> y_d) -> func_call_d_ret{
389+
double r {0.0};
390+
cpp2::taylor<double,6> r_d {0.0};
391+
auto temp_2 {func_d(x, x_d, y, y_d)};
392+
393+
auto temp_1 {temp_2.r};
394+
395+
auto temp_1_d {cpp2::move(temp_2).r_d};
396+
r_d = CPP2_UFCS(mul)(x_d, cpp2::move(temp_1_d), x, temp_1);
397+
r = x * cpp2::move(temp_1);
398+
return { std::move(r), std::move(r_d) };
399+
}
400+
401+
[[nodiscard]] auto ad_test::sin_call_d(cpp2::impl::in<double> x, cpp2::impl::in<cpp2::taylor<double,6>> x_d, cpp2::impl::in<double> y, cpp2::impl::in<cpp2::taylor<double,6>> y_d) -> sin_call_d_ret{
402+
double r {0.0};
403+
cpp2::taylor<double,6> r_d {0.0};
404+
auto temp_1_d {x_d - y_d};
405+
406+
auto temp_1 {x - y};
407+
r_d = CPP2_UFCS(sin)(cpp2::move(temp_1_d), temp_1);
408+
r = sin(cpp2::move(temp_1));
409+
return { std::move(r), std::move(r_d) };
410+
}
411+
412+
#line 68 "pure2-autodiff-higher-order.cpp2"
334413
auto write_output(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, cpp2::impl::in<ad_type> x_d, cpp2::impl::in<double> y, cpp2::impl::in<ad_type> y_d, auto const& ret) -> void{
335414
std::cout << "diff(" + cpp2::to_string(func) + ") at (x = " + cpp2::to_string(x) + ", x_d = " + cpp2::to_string(x_d) + ", y = " + cpp2::to_string(y) + ", y_d = " + cpp2::to_string(y_d) + "):" << std::endl;
336415
std::cout << " r = " + cpp2::to_string(ret.r) + "" << std::endl;
337416
{
338417
auto i{1};
339418

340-
#line 60 "pure2-autodiff-higher-order.cpp2"
419+
#line 72 "pure2-autodiff-higher-order.cpp2"
341420
for( ; cpp2::impl::cmp_less_eq(i,ad_order); i += 1 ) {
342421
std::cout << " d" + cpp2::to_string(i) + " = " + cpp2::to_string(CPP2_ASSERT_IN_BOUNDS(ret.r_d, i)) + "" << std::endl;
343422
}
344423
}
345-
#line 63 "pure2-autodiff-higher-order.cpp2"
424+
#line 75 "pure2-autodiff-higher-order.cpp2"
346425
}
347426

348-
#line 65 "pure2-autodiff-higher-order.cpp2"
427+
#line 77 "pure2-autodiff-higher-order.cpp2"
349428
auto main() -> int{
350429

351-
#line 68 "pure2-autodiff-higher-order.cpp2"
430+
#line 80 "pure2-autodiff-higher-order.cpp2"
352431
double x {2.0};
353432
ad_type x_d {1.0};
354433
double y {3.0};
@@ -365,9 +444,9 @@ auto main() -> int{
365444
write_output("x / y / y", x, x_d, y, y_d, ad_test::div_2_d(x, x_d, y, y_d));
366445
write_output("x * y / x", x, x_d, y, y_d, ad_test::mul_div_2_d(x, x_d, y, y_d));
367446
write_output("x * (x + y)", x, x_d, y, y_d, ad_test::mul_add_d(x, x_d, y, y_d));
368-
write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_d(cpp2::move(x), cpp2::move(x_d), cpp2::move(y), cpp2::move(y_d)));
369-
// write_output("x * func(x, y)", x, x_d, y, y_d, ad_test::func_call_d(x, x_d, y, y_d));
370-
// write_output("sin(x + y)", x, x_d, y, y_d, ad_test::sin_call_d(x, x_d, y, y_d));
447+
write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_d(x, x_d, y, y_d));
448+
write_output("x * func(x, y)", x, x_d, y, y_d, ad_test::func_call_d(x, x_d, y, y_d));
449+
write_output("sin(x - y)", x, x_d, y, y_d, ad_test::sin_call_d(cpp2::move(x), cpp2::move(x_d), cpp2::move(y), cpp2::move(y_d)));
371450
// write_output("if branch", x, x_d, y, y_d, ad_test::if_branch_d(x, x_d, y, y_d));
372451
// write_output("if else branch", x, x_d, y, y_d, ad_test::if_else_branch_d(x, x_d, y, y_d));
373452
// write_output("direct return", x, x_d, y, y_d, ad_test::direct_return_d(x, x_d, y, y_d));

regression-tests/test-results/pure2-autodiff-higher-order.cpp2.output

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,33 @@ ad_test:/* @autodiff<"order=6"> @print */ type =
110110
return;
111111
}
112112

113+
func:(
114+
in x: double,
115+
in y: double,
116+
) -> (out r: double, ) =
117+
{
118+
r = x + y;
119+
return;
120+
}
121+
122+
func_call:(
123+
in x: double,
124+
in y: double,
125+
) -> (out r: double, ) =
126+
{
127+
r = x * func(x, y);
128+
return;
129+
}
130+
131+
sin_call:(
132+
in x: double,
133+
in y: double,
134+
) -> (out r: double, ) =
135+
{
136+
r = sin(x - y);
137+
return;
138+
}
139+
113140
add_1_d:(
114141
in x: double,
115142
in x_d: cpp2::taylor<double, 6>,
@@ -299,6 +326,56 @@ ad_test:/* @autodiff<"order=6"> @print */ type =
299326
r = x + temp_1;
300327
return;
301328
}
329+
330+
func_d:(
331+
in x: double,
332+
in x_d: cpp2::taylor<double, 6>,
333+
in y: double,
334+
in y_d: cpp2::taylor<double, 6>,
335+
) -> (
336+
out r: double = 0.0,
337+
out r_d: cpp2::taylor<double, 6> = 0.0,
338+
) =
339+
{
340+
r_d = x_d + y_d;
341+
r = x + y;
342+
return;
343+
}
344+
345+
func_call_d:(
346+
in x: double,
347+
in x_d: cpp2::taylor<double, 6>,
348+
in y: double,
349+
in y_d: cpp2::taylor<double, 6>,
350+
) -> (
351+
out r: double = 0.0,
352+
out r_d: cpp2::taylor<double, 6> = 0.0,
353+
) =
354+
{
355+
temp_2: _ = func_d(x, x_d, y, y_d);
356+
temp_1: _ = temp_2.r;
357+
temp_1_d: _ = temp_2.r_d;
358+
r_d = x_d.mul(temp_1_d, x, temp_1);
359+
r = x * temp_1;
360+
return;
361+
}
362+
363+
sin_call_d:(
364+
in x: double,
365+
in x_d: cpp2::taylor<double, 6>,
366+
in y: double,
367+
in y_d: cpp2::taylor<double, 6>,
368+
) -> (
369+
out r: double = 0.0,
370+
out r_d: cpp2::taylor<double, 6> = 0.0,
371+
) =
372+
{
373+
temp_1_d: _ = x_d - y_d;
374+
temp_1: _ = x - y;
375+
r_d = temp_1_d.sin(temp_1);
376+
r = sin(temp_1);
377+
return;
378+
}
302379
}
303380
ok (all Cpp2, passes safety checks)
304381

regression-tests/test-results/pure2-autodiff.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,7 @@ auto main() -> int{
878878
write_output("x * (x + y)", x, x_d, y, y_d, ad_test::mul_add_d(x, x_d, y, y_d));
879879
write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_d(x, x_d, y, y_d));
880880
write_output("x * func(x, y)", x, x_d, y, y_d, ad_test::func_call_d(x, x_d, y, y_d));
881-
write_output("sin(x + y)", x, x_d, y, y_d, ad_test::sin_call_d(x, x_d, y, y_d));
881+
write_output("sin(x - y)", x, x_d, y, y_d, ad_test::sin_call_d(x, x_d, y, y_d));
882882
write_output("if branch", x, x_d, y, y_d, ad_test::if_branch_d(x, x_d, y, y_d));
883883
write_output("if else branch", x, x_d, y, y_d, ad_test::if_else_branch_d(x, x_d, y, y_d));
884884
write_output("direct return", x, x_d, y, y_d, ad_test::direct_return_d(x, x_d, y, y_d));

0 commit comments

Comments
 (0)