Skip to content

Commit 00be8ff

Browse files
committed
Handling if/else statements.
1 parent 8166f0f commit 00be8ff

File tree

6 files changed

+844
-459
lines changed

6 files changed

+844
-459
lines changed

regression-tests/pure2-autodiff.cpp2

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,23 @@ ad_test: @autodiff @print type = {
6060
sin_call: (x: double, y: double) -> (r: double) = {
6161
r = sin(x - y);
6262
}
63+
64+
if_branch: (x: double, y: double) -> (r: double) = {
65+
r = x;
66+
67+
if x < 0.0 {
68+
r = y;
69+
}
70+
}
71+
72+
if_else_branch: (x: double, y: double) -> (r: double) = {
73+
if x < 0.0 {
74+
r = y;
75+
}
76+
else {
77+
r = x;
78+
}
79+
}
6380
}
6481

6582
write_output: (func: std::string, x: double, x_d: double, y: double, y_d: double, ret) = {
@@ -87,4 +104,6 @@ main: () = {
87104
write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_diff(x, x_d, y, y_d));
88105
write_output("x * func(x, y)", x, x_d, y, y_d, ad_test::func_call_diff(x, x_d, y, y_d));
89106
write_output("sin(x + y)", x, x_d, y, y_d, ad_test::sin_call_diff(x, x_d, y, y_d));
107+
write_output("if branch", x, x_d, y, y_d, ad_test::if_branch_diff(x, x_d, y, y_d));
108+
write_output("if else branch", x, x_d, y, y_d, ad_test::if_else_branch_diff(x, x_d, y, y_d));
90109
}

regression-tests/test-results/gcc-13-c++2b/pure2-autodiff.cpp.execution

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@ diff(x * (x + y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000
1212
diff(x + x * y) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 8.000000, r_d = 8.000000)
1313
diff(x * func(x, y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 10.000000, r_d = 11.000000)
1414
diff(sin(x + y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = -0.841471, r_d = -0.540302)
15+
diff(if branch) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 2.000000, r_d = 1.000000)
16+
diff(if else branch) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 2.000000, r_d = 1.000000)

0 commit comments

Comments
 (0)