Skip to content

Commit b930fe7

Browse files
committed
Remaining tests for higher order derivatives.
1 parent 05cff29 commit b930fe7

File tree

10 files changed

+1526
-550
lines changed

10 files changed

+1526
-550
lines changed

include/cpp2taylor.h

Lines changed: 98 additions & 66 deletions
Large diffs are not rendered by default.

include/cpp2taylor.h2

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@ taylor: <R, dim: int> type = {
1212
}
1313
operator=:(out this, that) = {}
1414

15+
operator=:(out this, l: std::initializer_list<R>) = {
16+
(copy i := 1)
17+
for l do (cur) {
18+
set(i, cur);
19+
}
20+
}
21+
1522
// C++ interface
1623

1724
operator[]: (this, k: int) -> R = {

regression-tests/pure2-autodiff-higher-order.cpp2

Lines changed: 101 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,96 @@ ad_test: @autodiff<"order=6"> @print type = {
6363
sin_call: (x: double, y: double) -> (r: double) = {
6464
r = sin(x - y);
6565
}
66+
67+
if_branch: (x: double, y: double) -> (r: double) = {
68+
r = x;
69+
70+
if x < 0.0 {
71+
r = y;
72+
}
73+
}
74+
75+
if_else_branch: (x: double, y: double) -> (r: double) = {
76+
if x < 0.0 {
77+
r = y;
78+
}
79+
else {
80+
r = x;
81+
}
82+
}
83+
84+
direct_return: (x: double, y: double) -> double = {
85+
return x + y;
86+
}
87+
88+
intermediate_var: (x: double, y: double) -> (r: double) = {
89+
t: double = x + y;
90+
91+
r = t;
92+
}
93+
94+
intermediate_passive_var: (x: double, y: double) -> (r: double) = {
95+
i: int = (); // TODO: Handle as passive when type information on call side is available.
96+
r = x + y;
97+
i = 2;
98+
99+
_ = i;
100+
}
101+
102+
intermediate_untyped: (x: double, y: double) -> (r: double) = {
103+
t := 0.0;
104+
t = x + y;
105+
106+
r = t;
107+
}
108+
109+
intermediate_default_init: (x: double, y: double) -> (r: double) = {
110+
t: double = ();
111+
t = x + y;
112+
113+
r = t;
114+
}
115+
116+
intermediate_no_init: (x: double, y: double) -> (r: double) = {
117+
t: double;
118+
t = x + y;
119+
120+
r = t;
121+
}
122+
123+
while_loop: (x: double, y: double) -> (r: double) = {
124+
i: int = 0;
125+
126+
r = x;
127+
while i < 2 next (i += 1) {
128+
r = r + y ;
129+
}
130+
}
131+
132+
do_while_loop: (x: double, y: double) -> (r: double) = {
133+
i: int = 0;
134+
135+
r = x;
136+
do {
137+
r = r + y ;
138+
}
139+
next (i += 1)
140+
while i < 2;
141+
}
142+
143+
for_loop: (x: double, y: double) -> (r: double) = {
144+
v: std::vector<double> = ();
145+
146+
v.push_back(x);
147+
v.push_back(y);
148+
149+
r = 0.0;
150+
for v
151+
do (t)
152+
{
153+
r = r + t;
154+
}
155+
}
66156
}
67157

68158
write_output: (func: std::string, x: double, x_d: ad_type, y: double, y_d: ad_type, ret) = {
@@ -96,15 +186,15 @@ main: () = {
96186
write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_d(x, x_d, y, y_d));
97187
write_output("x * func(x, y)", x, x_d, y, y_d, ad_test::func_call_d(x, x_d, y, y_d));
98188
write_output("sin(x - y)", x, x_d, y, y_d, ad_test::sin_call_d(x, x_d, y, y_d));
99-
// write_output("if branch", x, x_d, y, y_d, ad_test::if_branch_d(x, x_d, y, y_d));
100-
// write_output("if else branch", x, x_d, y, y_d, ad_test::if_else_branch_d(x, x_d, y, y_d));
101-
// write_output("direct return", x, x_d, y, y_d, ad_test::direct_return_d(x, x_d, y, y_d));
102-
// write_output("intermediate var", x, x_d, y, y_d, ad_test::intermediate_var_d(x, x_d, y, y_d));
103-
// write_output("intermediate passive var", x, x_d, y, y_d, ad_test::intermediate_passive_var_d(x, x_d, y, y_d));
104-
// write_output("intermediate untyped", x, x_d, y, y_d, ad_test::intermediate_untyped_d(x, x_d, y, y_d));
105-
// write_output("intermediate default init", x, x_d, y, y_d, ad_test::intermediate_default_init_d(x, x_d, y, y_d));
106-
// write_output("intermediate no init", x, x_d, y, y_d, ad_test::intermediate_no_init_d(x, x_d, y, y_d));
107-
// write_output("while loop", x, x_d, y, y_d, ad_test::while_loop_d(x, x_d, y, y_d));
108-
// write_output("do while loop", x, x_d, y, y_d, ad_test::do_while_loop_d(x, x_d, y, y_d));
109-
// write_output("for loop", x, x_d, y, y_d, ad_test::for_loop_d(x, x_d, y, y_d));
189+
write_output("if branch", x, x_d, y, y_d, ad_test::if_branch_d(x, x_d, y, y_d));
190+
write_output("if else branch", x, x_d, y, y_d, ad_test::if_else_branch_d(x, x_d, y, y_d));
191+
write_output("direct return", x, x_d, y, y_d, ad_test::direct_return_d(x, x_d, y, y_d));
192+
write_output("intermediate var", x, x_d, y, y_d, ad_test::intermediate_var_d(x, x_d, y, y_d));
193+
write_output("intermediate passive var", x, x_d, y, y_d, ad_test::intermediate_passive_var_d(x, x_d, y, y_d));
194+
write_output("intermediate untyped", x, x_d, y, y_d, ad_test::intermediate_untyped_d(x, x_d, y, y_d));
195+
write_output("intermediate default init", x, x_d, y, y_d, ad_test::intermediate_default_init_d(x, x_d, y, y_d));
196+
write_output("intermediate no init", x, x_d, y, y_d, ad_test::intermediate_no_init_d(x, x_d, y, y_d));
197+
write_output("while loop", x, x_d, y, y_d, ad_test::while_loop_d(x, x_d, y, y_d));
198+
write_output("do while loop", x, x_d, y, y_d, ad_test::do_while_loop_d(x, x_d, y, y_d));
199+
write_output("for loop", x, x_d, y, y_d, ad_test::for_loop_d(x, x_d, y, y_d));
110200
}

regression-tests/test-results/gcc-13-c++2b/pure2-autodiff-higher-order.cpp.execution

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,91 @@ diff(sin(x - y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0
110110
d4 = -0.841471
111111
d5 = -0.540302
112112
d6 = 0.841471
113+
diff(if branch) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
114+
r = 2.000000
115+
d1 = 1.000000
116+
d2 = 0.000000
117+
d3 = 0.000000
118+
d4 = 0.000000
119+
d5 = 0.000000
120+
d6 = 0.000000
121+
diff(if else branch) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
122+
r = 2.000000
123+
d1 = 1.000000
124+
d2 = 0.000000
125+
d3 = 0.000000
126+
d4 = 0.000000
127+
d5 = 0.000000
128+
d6 = 0.000000
129+
diff(direct return) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
130+
r = 5.000000
131+
d1 = 3.000000
132+
d2 = 0.000000
133+
d3 = 0.000000
134+
d4 = 0.000000
135+
d5 = 0.000000
136+
d6 = 0.000000
137+
diff(intermediate var) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
138+
r = 5.000000
139+
d1 = 3.000000
140+
d2 = 0.000000
141+
d3 = 0.000000
142+
d4 = 0.000000
143+
d5 = 0.000000
144+
d6 = 0.000000
145+
diff(intermediate passive var) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
146+
r = 5.000000
147+
d1 = 3.000000
148+
d2 = 0.000000
149+
d3 = 0.000000
150+
d4 = 0.000000
151+
d5 = 0.000000
152+
d6 = 0.000000
153+
diff(intermediate untyped) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
154+
r = 5.000000
155+
d1 = 3.000000
156+
d2 = 0.000000
157+
d3 = 0.000000
158+
d4 = 0.000000
159+
d5 = 0.000000
160+
d6 = 0.000000
161+
diff(intermediate default init) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
162+
r = 5.000000
163+
d1 = 3.000000
164+
d2 = 0.000000
165+
d3 = 0.000000
166+
d4 = 0.000000
167+
d5 = 0.000000
168+
d6 = 0.000000
169+
diff(intermediate no init) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
170+
r = 5.000000
171+
d1 = 3.000000
172+
d2 = 0.000000
173+
d3 = 0.000000
174+
d4 = 0.000000
175+
d5 = 0.000000
176+
d6 = 0.000000
177+
diff(while loop) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
178+
r = 8.000000
179+
d1 = 5.000000
180+
d2 = 0.000000
181+
d3 = 0.000000
182+
d4 = 0.000000
183+
d5 = 0.000000
184+
d6 = 0.000000
185+
diff(do while loop) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
186+
r = 8.000000
187+
d1 = 5.000000
188+
d2 = 0.000000
189+
d3 = 0.000000
190+
d4 = 0.000000
191+
d5 = 0.000000
192+
d6 = 0.000000
193+
diff(for loop) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
194+
r = 5.000000
195+
d1 = 3.000000
196+
d2 = 0.000000
197+
d3 = 0.000000
198+
d4 = 0.000000
199+
d5 = 0.000000
200+
d6 = 0.000000

0 commit comments

Comments
 (0)