Skip to content

Commit 7b3b9c8

Browse files
committed
Added tests for combined binary expressions.
1 parent e0b4a6a commit 7b3b9c8

File tree

6 files changed

+1153
-545
lines changed

6 files changed

+1153
-545
lines changed

regression-tests/pure2-autodiff.cpp2

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,49 @@ ad_test_reverse: @autodiff<"reverse"> @print type = {
197197
r = x + y;
198198
}
199199

200+
add_2: (x: double, y: double) -> (r: double) = {
201+
r = x + y + x;
202+
}
203+
204+
sub_1: (x: double, y: double) -> (r: double) = {
205+
r = x - y;
206+
}
207+
208+
sub_2: (x: double, y: double) -> (r: double) = {
209+
r = x - y - x;
210+
}
211+
212+
add_sub_2: (x: double, y: double) -> (r: double) = {
213+
r = x + y - x;
214+
}
215+
200216
mul_1: (x: double, y: double) -> (r: double) = {
201217
r = x * y;
202218
}
219+
220+
mul_2: (x: double, y: double) -> (r: double) = {
221+
r = x * y * x;
222+
}
223+
224+
div_1: (x: double, y: double) -> (r: double) = {
225+
r = x / y;
226+
}
227+
228+
div_2: (x: double, y: double) -> (r: double) = {
229+
r = x / y / y;
230+
}
231+
232+
mul_div_2: (x: double, y: double) -> (r: double) = {
233+
r = x * y / x;
234+
}
235+
236+
mul_add: (x: double, y: double) -> (r: double) = {
237+
r = x * (x + y);
238+
}
239+
240+
add_mul: (x: double, y: double) -> (r: double) = {
241+
r = x + x * y;
242+
}
203243
}
204244
}
205245

@@ -213,7 +253,8 @@ write_output: (func: std::string, x: double, x_d: double, y: double, y_d: double
213253
std::cout << "diff((func)$) at (x = (x)$, x_d = (x_d)$, y = (y)$, y_d = (y_d)$) = (r = (ret.r)$, r_d = (ret.r_d)$)" << std::endl;
214254
}
215255

216-
write_output_reverse: (func: std::string, x: double, inout x_b: double, y: double, inout y_b: double, in r_b: double, ret) = {
256+
write_output_reverse: (func: std::string, x: double, inout x_b: double, y: double, inout y_b: double, inout r_b: double, ret) = {
257+
r_b = 1.0;
217258
std::cout << "diff((func)$) at (x = (x)$, y = (y)$, r_b = (r_b)$) = (r = (ret)$, x_b = (x_b)$, y_b = (y_b)$)" << std::endl;
218259
x_b = 0.0;
219260
y_b = 0.0;
@@ -262,7 +303,17 @@ main: () = {
262303
w_b: double = 1.0;
263304

264305
write_output_reverse("x + y", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::add_1_b(x, x_b, y, y_b, w_b));
306+
write_output_reverse("x + y + x", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::add_2_b(x, x_b, y, y_b, w_b));
307+
write_output_reverse("x - y", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::sub_1_b(x, x_b, y, y_b, w_b));
308+
write_output_reverse("x - y - x", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::sub_2_b(x, x_b, y, y_b, w_b));
309+
write_output_reverse("x + y - x", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::add_sub_2_b(x, x_b, y, y_b, w_b));
265310
write_output_reverse("x * y", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::mul_1_b(x, x_b, y, y_b, w_b));
311+
write_output_reverse("x * y * x", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::mul_2_b(x, x_b, y, y_b, w_b));
312+
write_output_reverse("x / y", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::div_1_b(x, x_b, y, y_b, w_b));
313+
write_output_reverse("x / y / y", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::div_2_b(x, x_b, y, y_b, w_b));
314+
write_output_reverse("x * y / x", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::mul_div_2_b(x, x_b, y, y_b, w_b));
315+
write_output_reverse("x * (x + y)", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::mul_add_b(x, x_b, y, y_b, w_b));
316+
write_output_reverse("x + x * y", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::add_mul_b(x, x_b, y, y_b, w_b));
266317

267318
_ = x_b;
268319
_ = y_b;

regression-tests/test-results/gcc-13-c++2b/pure2-autodiff.cpp.execution

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,15 @@ diff(for loop) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) =
2929
diff(tye_outer.a + y) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 5.000000, r_d = 3.000000)
3030
diff(type_outer.add(y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 5.000000, r_d = 3.000000)
3131
diff(x + y) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = 5.000000, x_b = 1.000000, y_b = 1.000000)
32+
diff(x + y + x) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = 7.000000, x_b = 2.000000, y_b = 1.000000)
33+
diff(x - y) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = -1.000000, x_b = 1.000000, y_b = -1.000000)
34+
diff(x - y - x) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = -3.000000, x_b = 0.000000, y_b = -1.000000)
35+
diff(x + y - x) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = 3.000000, x_b = 0.000000, y_b = 1.000000)
3236
diff(x * y) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = 6.000000, x_b = 3.000000, y_b = 2.000000)
37+
diff(x * y * x) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = 12.000000, x_b = 12.000000, y_b = 4.000000)
38+
diff(x / y) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = 0.666667, x_b = 0.333333, y_b = -0.222222)
39+
diff(x / y / y) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = 0.222222, x_b = 0.111111, y_b = -0.148148)
40+
diff(x * y / x) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = 3.000000, x_b = 0.000000, y_b = 1.000000)
41+
diff(x * (x + y)) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = 10.000000, x_b = 7.000000, y_b = 2.000000)
42+
diff(x + x * y) at (x = 2.000000, y = 3.000000, r_b = 1.000000) = (r = 8.000000, x_b = 4.000000, y_b = 2.000000)
3343
2nd order diff of x*x at 2.000000 = 2.000000

0 commit comments

Comments
 (0)