1-
21ad_order : int == 6;
32ad_type : type == cpp2::taylor<double, ad_order>;
43
4+ ad_name: namespace = {
5+
6+ func_outer: (x: double, y: double) -> (ret: double) = {
7+ ret = x + y;
8+ }
9+
10+ type_outer: type = {
11+ public a: double = 0.0;
12+ }
13+
514ad_test: @autodiff<"order=6"> @print type = {
615
716 add_1: (x: double, y: double) -> (r: double) = {
@@ -60,6 +69,10 @@ ad_test: @autodiff<"order=6"> @print type = {
6069 r = x * func(x, y);
6170 }
6271
72+ func_outer_call: (x: double, y: double) -> (r: double) = {
73+ r = x * func_outer(x, y);
74+ }
75+
6376 sin_call: (x: double, y: double) -> (r: double) = {
6477 r = sin(x - y);
6578 }
@@ -153,6 +166,14 @@ ad_test: @autodiff<"order=6"> @print type = {
153166 r = r + t;
154167 }
155168 }
169+
170+ type_outer_use: (x: double, y: double) -> (r: double) = {
171+ t : type_outer = ();
172+ t.a = x;
173+
174+ r = t.a + y;
175+ }
176+ }
156177}
157178
158179write_output: (func: std::string, x: double, x_d: ad_type, y: double, y_d: ad_type, ret) = {
@@ -172,29 +193,31 @@ main: () = {
172193 y: double = 3.0;
173194 y_d: ad_type = 2.0;
174195
175- write_output("x + y", x, x_d, y, y_d, ad_test::add_1_d(x, x_d, y, y_d));
176- write_output("x + y + x", x, x_d, y, y_d, ad_test::add_2_d(x, x_d, y, y_d));
177- write_output("x - y", x, x_d, y, y_d, ad_test::sub_1_d(x, x_d, y, y_d));
178- write_output("x - y - x", x, x_d, y, y_d, ad_test::sub_2_d(x, x_d, y, y_d));
179- write_output("x + y - x", x, x_d, y, y_d, ad_test::add_sub_2_d(x, x_d, y, y_d));
180- write_output("x * y", x, x_d, y, y_d, ad_test::mul_1_d(x, x_d, y, y_d));
181- write_output("x * y * x", x, x_d, y, y_d, ad_test::mul_2_d(x, x_d, y, y_d));
182- write_output("x / y", x, x_d, y, y_d, ad_test::div_1_d(x, x_d, y, y_d));
183- write_output("x / y / y", x, x_d, y, y_d, ad_test::div_2_d(x, x_d, y, y_d));
184- write_output("x * y / x", x, x_d, y, y_d, ad_test::mul_div_2_d(x, x_d, y, y_d));
185- write_output("x * (x + y)", x, x_d, y, y_d, ad_test::mul_add_d(x, x_d, y, y_d));
186- write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_d(x, x_d, y, y_d));
187- write_output("x * func(x, y)", x, x_d, y, y_d, ad_test::func_call_d(x, x_d, y, y_d));
188- write_output("sin(x - y)", x, x_d, y, y_d, ad_test::sin_call_d(x, x_d, y, y_d));
189- write_output("if branch", x, x_d, y, y_d, ad_test::if_branch_d(x, x_d, y, y_d));
190- write_output("if else branch", x, x_d, y, y_d, ad_test::if_else_branch_d(x, x_d, y, y_d));
191- write_output("direct return", x, x_d, y, y_d, ad_test::direct_return_d(x, x_d, y, y_d));
192- write_output("intermediate var", x, x_d, y, y_d, ad_test::intermediate_var_d(x, x_d, y, y_d));
193- write_output("intermediate passive var", x, x_d, y, y_d, ad_test::intermediate_passive_var_d(x, x_d, y, y_d));
194- write_output("intermediate untyped", x, x_d, y, y_d, ad_test::intermediate_untyped_d(x, x_d, y, y_d));
195- write_output("intermediate default init", x, x_d, y, y_d, ad_test::intermediate_default_init_d(x, x_d, y, y_d));
196- write_output("intermediate no init", x, x_d, y, y_d, ad_test::intermediate_no_init_d(x, x_d, y, y_d));
197- write_output("while loop", x, x_d, y, y_d, ad_test::while_loop_d(x, x_d, y, y_d));
198- write_output("do while loop", x, x_d, y, y_d, ad_test::do_while_loop_d(x, x_d, y, y_d));
199- write_output("for loop", x, x_d, y, y_d, ad_test::for_loop_d(x, x_d, y, y_d));
196+ write_output("x + y", x, x_d, y, y_d, ad_name::ad_test::add_1_d(x, x_d, y, y_d));
197+ write_output("x + y + x", x, x_d, y, y_d, ad_name::ad_test::add_2_d(x, x_d, y, y_d));
198+ write_output("x - y", x, x_d, y, y_d, ad_name::ad_test::sub_1_d(x, x_d, y, y_d));
199+ write_output("x - y - x", x, x_d, y, y_d, ad_name::ad_test::sub_2_d(x, x_d, y, y_d));
200+ write_output("x + y - x", x, x_d, y, y_d, ad_name::ad_test::add_sub_2_d(x, x_d, y, y_d));
201+ write_output("x * y", x, x_d, y, y_d, ad_name::ad_test::mul_1_d(x, x_d, y, y_d));
202+ write_output("x * y * x", x, x_d, y, y_d, ad_name::ad_test::mul_2_d(x, x_d, y, y_d));
203+ write_output("x / y", x, x_d, y, y_d, ad_name::ad_test::div_1_d(x, x_d, y, y_d));
204+ write_output("x / y / y", x, x_d, y, y_d, ad_name::ad_test::div_2_d(x, x_d, y, y_d));
205+ write_output("x * y / x", x, x_d, y, y_d, ad_name::ad_test::mul_div_2_d(x, x_d, y, y_d));
206+ write_output("x * (x + y)", x, x_d, y, y_d, ad_name::ad_test::mul_add_d(x, x_d, y, y_d));
207+ write_output("x + x * y", x, x_d, y, y_d, ad_name::ad_test::add_mul_d(x, x_d, y, y_d));
208+ write_output("x * func(x, y)", x, x_d, y, y_d, ad_name::ad_test::func_call_d(x, x_d, y, y_d));
209+ write_output("x * func_outer(x, y)", x, x_d, y, y_d, ad_name::ad_test::func_outer_call_d(x, x_d, y, y_d));
210+ write_output("sin(x - y)", x, x_d, y, y_d, ad_name::ad_test::sin_call_d(x, x_d, y, y_d));
211+ write_output("if branch", x, x_d, y, y_d, ad_name::ad_test::if_branch_d(x, x_d, y, y_d));
212+ write_output("if else branch", x, x_d, y, y_d, ad_name::ad_test::if_else_branch_d(x, x_d, y, y_d));
213+ write_output("direct return", x, x_d, y, y_d, ad_name::ad_test::direct_return_d(x, x_d, y, y_d));
214+ write_output("intermediate var", x, x_d, y, y_d, ad_name::ad_test::intermediate_var_d(x, x_d, y, y_d));
215+ write_output("intermediate passive var", x, x_d, y, y_d, ad_name::ad_test::intermediate_passive_var_d(x, x_d, y, y_d));
216+ write_output("intermediate untyped", x, x_d, y, y_d, ad_name::ad_test::intermediate_untyped_d(x, x_d, y, y_d));
217+ write_output("intermediate default init", x, x_d, y, y_d, ad_name::ad_test::intermediate_default_init_d(x, x_d, y, y_d));
218+ write_output("intermediate no init", x, x_d, y, y_d, ad_name::ad_test::intermediate_no_init_d(x, x_d, y, y_d));
219+ write_output("while loop", x, x_d, y, y_d, ad_name::ad_test::while_loop_d(x, x_d, y, y_d));
220+ write_output("do while loop", x, x_d, y, y_d, ad_name::ad_test::do_while_loop_d(x, x_d, y, y_d));
221+ write_output("for loop", x, x_d, y, y_d, ad_name::ad_test::for_loop_d(x, x_d, y, y_d));
222+ write_output("tye_outer.a + y", x, x_d, y, y_d, ad_name::ad_test::type_outer_use_d(x, x_d, y, y_d));
200223}
0 commit comments