Skip to content

Commit 247a79d

Browse files
committed
Added type differentiation for types without member functions.
1 parent bc3597a commit 247a79d

File tree

8 files changed

+1045
-759
lines changed

8 files changed

+1045
-759
lines changed

regression-tests/pure2-autodiff-higher-order.cpp2

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
1-
21
ad_order : int == 6;
32
ad_type : type == cpp2::taylor<double, ad_order>;
43

4+
ad_name: namespace = {
5+
6+
func_outer: (x: double, y: double) -> (ret: double) = {
7+
ret = x + y;
8+
}
9+
10+
type_outer: type = {
11+
public a: double = 0.0;
12+
}
13+
514
ad_test: @autodiff<"order=6"> @print type = {
615

716
add_1: (x: double, y: double) -> (r: double) = {
@@ -60,6 +69,10 @@ ad_test: @autodiff<"order=6"> @print type = {
6069
r = x * func(x, y);
6170
}
6271

72+
func_outer_call: (x: double, y: double) -> (r: double) = {
73+
r = x * func_outer(x, y);
74+
}
75+
6376
sin_call: (x: double, y: double) -> (r: double) = {
6477
r = sin(x - y);
6578
}
@@ -153,6 +166,14 @@ ad_test: @autodiff<"order=6"> @print type = {
153166
r = r + t;
154167
}
155168
}
169+
170+
type_outer_use: (x: double, y: double) -> (r: double) = {
171+
t : type_outer = ();
172+
t.a = x;
173+
174+
r = t.a + y;
175+
}
176+
}
156177
}
157178

158179
write_output: (func: std::string, x: double, x_d: ad_type, y: double, y_d: ad_type, ret) = {
@@ -172,29 +193,31 @@ main: () = {
172193
y: double = 3.0;
173194
y_d: ad_type = 2.0;
174195

175-
write_output("x + y", x, x_d, y, y_d, ad_test::add_1_d(x, x_d, y, y_d));
176-
write_output("x + y + x", x, x_d, y, y_d, ad_test::add_2_d(x, x_d, y, y_d));
177-
write_output("x - y", x, x_d, y, y_d, ad_test::sub_1_d(x, x_d, y, y_d));
178-
write_output("x - y - x", x, x_d, y, y_d, ad_test::sub_2_d(x, x_d, y, y_d));
179-
write_output("x + y - x", x, x_d, y, y_d, ad_test::add_sub_2_d(x, x_d, y, y_d));
180-
write_output("x * y", x, x_d, y, y_d, ad_test::mul_1_d(x, x_d, y, y_d));
181-
write_output("x * y * x", x, x_d, y, y_d, ad_test::mul_2_d(x, x_d, y, y_d));
182-
write_output("x / y", x, x_d, y, y_d, ad_test::div_1_d(x, x_d, y, y_d));
183-
write_output("x / y / y", x, x_d, y, y_d, ad_test::div_2_d(x, x_d, y, y_d));
184-
write_output("x * y / x", x, x_d, y, y_d, ad_test::mul_div_2_d(x, x_d, y, y_d));
185-
write_output("x * (x + y)", x, x_d, y, y_d, ad_test::mul_add_d(x, x_d, y, y_d));
186-
write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_d(x, x_d, y, y_d));
187-
write_output("x * func(x, y)", x, x_d, y, y_d, ad_test::func_call_d(x, x_d, y, y_d));
188-
write_output("sin(x - y)", x, x_d, y, y_d, ad_test::sin_call_d(x, x_d, y, y_d));
189-
write_output("if branch", x, x_d, y, y_d, ad_test::if_branch_d(x, x_d, y, y_d));
190-
write_output("if else branch", x, x_d, y, y_d, ad_test::if_else_branch_d(x, x_d, y, y_d));
191-
write_output("direct return", x, x_d, y, y_d, ad_test::direct_return_d(x, x_d, y, y_d));
192-
write_output("intermediate var", x, x_d, y, y_d, ad_test::intermediate_var_d(x, x_d, y, y_d));
193-
write_output("intermediate passive var", x, x_d, y, y_d, ad_test::intermediate_passive_var_d(x, x_d, y, y_d));
194-
write_output("intermediate untyped", x, x_d, y, y_d, ad_test::intermediate_untyped_d(x, x_d, y, y_d));
195-
write_output("intermediate default init", x, x_d, y, y_d, ad_test::intermediate_default_init_d(x, x_d, y, y_d));
196-
write_output("intermediate no init", x, x_d, y, y_d, ad_test::intermediate_no_init_d(x, x_d, y, y_d));
197-
write_output("while loop", x, x_d, y, y_d, ad_test::while_loop_d(x, x_d, y, y_d));
198-
write_output("do while loop", x, x_d, y, y_d, ad_test::do_while_loop_d(x, x_d, y, y_d));
199-
write_output("for loop", x, x_d, y, y_d, ad_test::for_loop_d(x, x_d, y, y_d));
196+
write_output("x + y", x, x_d, y, y_d, ad_name::ad_test::add_1_d(x, x_d, y, y_d));
197+
write_output("x + y + x", x, x_d, y, y_d, ad_name::ad_test::add_2_d(x, x_d, y, y_d));
198+
write_output("x - y", x, x_d, y, y_d, ad_name::ad_test::sub_1_d(x, x_d, y, y_d));
199+
write_output("x - y - x", x, x_d, y, y_d, ad_name::ad_test::sub_2_d(x, x_d, y, y_d));
200+
write_output("x + y - x", x, x_d, y, y_d, ad_name::ad_test::add_sub_2_d(x, x_d, y, y_d));
201+
write_output("x * y", x, x_d, y, y_d, ad_name::ad_test::mul_1_d(x, x_d, y, y_d));
202+
write_output("x * y * x", x, x_d, y, y_d, ad_name::ad_test::mul_2_d(x, x_d, y, y_d));
203+
write_output("x / y", x, x_d, y, y_d, ad_name::ad_test::div_1_d(x, x_d, y, y_d));
204+
write_output("x / y / y", x, x_d, y, y_d, ad_name::ad_test::div_2_d(x, x_d, y, y_d));
205+
write_output("x * y / x", x, x_d, y, y_d, ad_name::ad_test::mul_div_2_d(x, x_d, y, y_d));
206+
write_output("x * (x + y)", x, x_d, y, y_d, ad_name::ad_test::mul_add_d(x, x_d, y, y_d));
207+
write_output("x + x * y", x, x_d, y, y_d, ad_name::ad_test::add_mul_d(x, x_d, y, y_d));
208+
write_output("x * func(x, y)", x, x_d, y, y_d, ad_name::ad_test::func_call_d(x, x_d, y, y_d));
209+
write_output("x * func_outer(x, y)", x, x_d, y, y_d, ad_name::ad_test::func_outer_call_d(x, x_d, y, y_d));
210+
write_output("sin(x - y)", x, x_d, y, y_d, ad_name::ad_test::sin_call_d(x, x_d, y, y_d));
211+
write_output("if branch", x, x_d, y, y_d, ad_name::ad_test::if_branch_d(x, x_d, y, y_d));
212+
write_output("if else branch", x, x_d, y, y_d, ad_name::ad_test::if_else_branch_d(x, x_d, y, y_d));
213+
write_output("direct return", x, x_d, y, y_d, ad_name::ad_test::direct_return_d(x, x_d, y, y_d));
214+
write_output("intermediate var", x, x_d, y, y_d, ad_name::ad_test::intermediate_var_d(x, x_d, y, y_d));
215+
write_output("intermediate passive var", x, x_d, y, y_d, ad_name::ad_test::intermediate_passive_var_d(x, x_d, y, y_d));
216+
write_output("intermediate untyped", x, x_d, y, y_d, ad_name::ad_test::intermediate_untyped_d(x, x_d, y, y_d));
217+
write_output("intermediate default init", x, x_d, y, y_d, ad_name::ad_test::intermediate_default_init_d(x, x_d, y, y_d));
218+
write_output("intermediate no init", x, x_d, y, y_d, ad_name::ad_test::intermediate_no_init_d(x, x_d, y, y_d));
219+
write_output("while loop", x, x_d, y, y_d, ad_name::ad_test::while_loop_d(x, x_d, y, y_d));
220+
write_output("do while loop", x, x_d, y, y_d, ad_name::ad_test::do_while_loop_d(x, x_d, y, y_d));
221+
write_output("for loop", x, x_d, y, y_d, ad_name::ad_test::for_loop_d(x, x_d, y, y_d));
222+
write_output("tye_outer.a + y", x, x_d, y, y_d, ad_name::ad_test::type_outer_use_d(x, x_d, y, y_d));
200223
}

regression-tests/test-results/gcc-13-c++2b/pure2-autodiff-higher-order.cpp.execution

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,14 @@ diff(x * func(x, y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.0000
102102
d4 = 0.000000
103103
d5 = 0.000000
104104
d6 = 0.000000
105+
diff(x * func_outer(x, y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
106+
r = 10.000000
107+
d1 = 11.000000
108+
d2 = 6.000000
109+
d3 = 0.000000
110+
d4 = 0.000000
111+
d5 = 0.000000
112+
d6 = 0.000000
105113
diff(sin(x - y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
106114
r = -0.841471
107115
d1 = -0.540302
@@ -198,3 +206,11 @@ diff(for loop) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.0
198206
d4 = 0.000000
199207
d5 = 0.000000
200208
d6 = 0.000000
209+
diff(tye_outer.a + y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
210+
r = 5.000000
211+
d1 = 3.000000
212+
d2 = 0.000000
213+
d3 = 0.000000
214+
d4 = 0.000000
215+
d5 = 0.000000
216+
d6 = 0.000000

0 commit comments

Comments
 (0)