Skip to content

Commit 065c3e9

Browse files
committed
Added second order test.
1 parent 093f291 commit 065c3e9

File tree

4 files changed

+142
-3
lines changed

4 files changed

+142
-3
lines changed

regression-tests/pure2-autodiff.cpp2

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,11 @@ ad_test: @autodiff @print type = {
153153
}
154154
}
155155

156+
ad_test_twice: @autodiff @autodiff<"suffix=_d2"> @print type = {
157+
mul_1: (x: double) -> (r: double) = {
158+
r = x * x;
159+
}
160+
}
156161

157162
write_output: (func: std::string, x: double, x_d: double, y: double, y_d: double, ret) = {
158163
std::cout << "diff((func)$) at (x = (x)$, x_d = (x_d)$, y = (y)$, y_d = (y_d)$) = (r = (ret.r)$, r_d = (ret.r_d)$)" << std::endl;
@@ -190,4 +195,7 @@ main: () = {
190195
write_output("while loop", x, x_d, y, y_d, ad_test::while_loop_d(x, x_d, y, y_d));
191196
write_output("do while loop", x, x_d, y, y_d, ad_test::do_while_loop_d(x, x_d, y, y_d));
192197
write_output("for loop", x, x_d, y, y_d, ad_test::for_loop_d(x, x_d, y, y_d));
198+
199+
r_twice := ad_test_twice::mul_1_d_d2(x, x_d, x_d, 0.0);
200+
std::cout << "2nd order diff of x*x at (x)$ = (r_twice.r_d_d2)$" << std::endl;
193201
}

regression-tests/test-results/gcc-13-c++2b/pure2-autodiff.cpp.execution

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ diff(intermediate no init) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d =
2323
diff(while loop) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 8.000000, r_d = 5.000000)
2424
diff(do while loop) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 8.000000, r_d = 5.000000)
2525
diff(for loop) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 5.000000, r_d = 3.000000)
26+
2nd order diff of x*x at 2.000000 = 2.000000

regression-tests/test-results/pure2-autodiff.cpp

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
#line 2 "pure2-autodiff.cpp2"
1212
class ad_test;
1313

14+
#line 156 "pure2-autodiff.cpp2"
15+
class ad_test_twice;
16+
1417

1518
//=== Cpp2 type definitions and function declarations ===========================
1619

@@ -259,10 +262,35 @@ public: [[nodiscard]] static auto for_loop_d(cpp2::impl::in<double> x, cpp2::imp
259262
#line 154 "pure2-autodiff.cpp2"
260263
};
261264

265+
class ad_test_twice {
266+
using mul_1_ret = double;
267+
262268
#line 157 "pure2-autodiff.cpp2"
269+
public: [[nodiscard]] static auto mul_1(cpp2::impl::in<double> x) -> mul_1_ret;
270+
struct mul_1_d_ret { double r; double r_d; };
271+
272+
273+
public: [[nodiscard]] static auto mul_1_d(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d) -> mul_1_d_ret;
274+
275+
struct mul_1_d2_ret { double r; double r_d2; };
276+
277+
public: [[nodiscard]] static auto mul_1_d2(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d2) -> mul_1_d2_ret;
278+
279+
struct mul_1_d_d2_ret { double r; double r_d2; double r_d; double r_d_d2; };
280+
281+
public: [[nodiscard]] static auto mul_1_d_d2(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d2, cpp2::impl::in<double> x_d, cpp2::impl::in<double> x_d_d2) -> mul_1_d_d2_ret;
282+
283+
public: ad_test_twice() = default;
284+
public: ad_test_twice(ad_test_twice const&) = delete; /* No 'that' constructor, suppress copy */
285+
public: auto operator=(ad_test_twice const&) -> void = delete;
286+
287+
288+
#line 160 "pure2-autodiff.cpp2"
289+
};
290+
263291
auto write_output(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d, auto const& ret) -> void;
264292

265-
#line 161 "pure2-autodiff.cpp2"
293+
#line 166 "pure2-autodiff.cpp2"
266294
auto main() -> int;
267295

268296
//=== Cpp2 function definitions =================================================
@@ -788,11 +816,51 @@ auto const& t_d{*cpp2::impl::assert_not_null(t_d_iter)};
788816
}
789817

790818
#line 157 "pure2-autodiff.cpp2"
819+
[[nodiscard]] auto ad_test_twice::mul_1(cpp2::impl::in<double> x) -> mul_1_ret{
820+
cpp2::impl::deferred_init<double> r;
821+
#line 158 "pure2-autodiff.cpp2"
822+
r.construct(x * x);
823+
return std::move(r.value()); }
824+
825+
[[nodiscard]] auto ad_test_twice::mul_1_d(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d) -> mul_1_d_ret{
826+
double r {0.0};
827+
double r_d {0.0};r_d = x * x_d + x * x_d;
828+
r = x * x;
829+
return { std::move(r), std::move(r_d) };
830+
}
831+
832+
[[nodiscard]] auto ad_test_twice::mul_1_d2(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d2) -> mul_1_d2_ret{
833+
double r {0.0};
834+
double r_d2 {0.0};r_d2 = x * x_d2 + x * x_d2;
835+
r = x * x;
836+
return { std::move(r), std::move(r_d2) };
837+
}
838+
839+
[[nodiscard]] auto ad_test_twice::mul_1_d_d2(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d2, cpp2::impl::in<double> x_d, cpp2::impl::in<double> x_d_d2) -> mul_1_d_d2_ret{
840+
double r {0.0};
841+
double r_d2 {0.0};
842+
double r_d {0.0};
843+
double r_d_d2 {0.0};
844+
auto temp_1_d2 {x * x_d_d2 + x_d * x_d2};
845+
846+
auto temp_1 {x * x_d};
847+
848+
auto temp_2_d2 {x * x_d_d2 + x_d * x_d2};
849+
850+
auto temp_2 {x * x_d};
851+
r_d_d2 = cpp2::move(temp_1_d2) + cpp2::move(temp_2_d2);
852+
r_d = cpp2::move(temp_1) + cpp2::move(temp_2);
853+
r_d2 = x * x_d2 + x * x_d2;
854+
r = x * x;
855+
return { std::move(r), std::move(r_d2), std::move(r_d), std::move(r_d_d2) };
856+
}
857+
858+
#line 162 "pure2-autodiff.cpp2"
791859
auto write_output(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d, auto const& ret) -> void{
792860
std::cout << "diff(" + cpp2::to_string(func) + ") at (x = " + cpp2::to_string(x) + ", x_d = " + cpp2::to_string(x_d) + ", y = " + cpp2::to_string(y) + ", y_d = " + cpp2::to_string(y_d) + ") = (r = " + cpp2::to_string(ret.r) + ", r_d = " + cpp2::to_string(ret.r_d) + ")" << std::endl;
793861
}
794862

795-
#line 161 "pure2-autodiff.cpp2"
863+
#line 166 "pure2-autodiff.cpp2"
796864
auto main() -> int{
797865

798866
double x {2.0};
@@ -824,6 +892,9 @@ auto main() -> int{
824892
write_output("intermediate no init", x, x_d, y, y_d, ad_test::intermediate_no_init_d(x, x_d, y, y_d));
825893
write_output("while loop", x, x_d, y, y_d, ad_test::while_loop_d(x, x_d, y, y_d));
826894
write_output("do while loop", x, x_d, y, y_d, ad_test::do_while_loop_d(x, x_d, y, y_d));
827-
write_output("for loop", x, x_d, y, y_d, ad_test::for_loop_d(cpp2::move(x), cpp2::move(x_d), cpp2::move(y), cpp2::move(y_d)));
895+
write_output("for loop", x, x_d, y, y_d, ad_test::for_loop_d(x, x_d, cpp2::move(y), cpp2::move(y_d)));
896+
897+
auto r_twice {ad_test_twice::mul_1_d_d2(x, x_d, cpp2::move(x_d), 0.0)};
898+
std::cout << "2nd order diff of x*x at " + cpp2::to_string(cpp2::move(x)) + " = " + cpp2::to_string(cpp2::move(r_twice).r_d_d2) + "" << std::endl;
828899
}
829900

regression-tests/test-results/pure2-autodiff.cpp2.output

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,5 +755,64 @@ ad_test:/* @autodiff @print */ type =
755755
return;
756756
}
757757
}
758+
759+
760+
ad_test_twice:/* @autodiff @autodiff<"suffix=_d2"> @print */ type =
761+
{
762+
mul_1:(in x: double, ) -> (out r: double, ) =
763+
{
764+
r = x * x;
765+
return;
766+
}
767+
768+
mul_1_d:(
769+
in x: double,
770+
in x_d: double,
771+
) -> (
772+
out r: double = 0.0,
773+
out r_d: double = 0.0,
774+
) =
775+
{
776+
r_d = x * x_d + x * x_d;
777+
r = x * x;
778+
return;
779+
}
780+
781+
mul_1_d2:(
782+
in x: double,
783+
in x_d2: double,
784+
) -> (
785+
out r: double = 0.0,
786+
out r_d2: double = 0.0,
787+
) =
788+
{
789+
r_d2 = x * x_d2 + x * x_d2;
790+
r = x * x;
791+
return;
792+
}
793+
794+
mul_1_d_d2:(
795+
in x: double,
796+
in x_d2: double,
797+
in x_d: double,
798+
in x_d_d2: double,
799+
) -> (
800+
out r: double = 0.0,
801+
out r_d2: double = 0.0,
802+
out r_d: double = 0.0,
803+
out r_d_d2: double = 0.0,
804+
) =
805+
{
806+
temp_1_d2: _ = x * x_d_d2 + x_d * x_d2;
807+
temp_1: _ = x * x_d;
808+
temp_2_d2: _ = x * x_d_d2 + x_d * x_d2;
809+
temp_2: _ = x * x_d;
810+
r_d_d2 = temp_1_d2 + temp_2_d2;
811+
r_d = temp_1 + temp_2;
812+
r_d2 = x * x_d2 + x * x_d2;
813+
r = x * x;
814+
return;
815+
}
816+
}
758817
ok (all Cpp2, passes safety checks)
759818

0 commit comments

Comments
 (0)