Skip to content

Commit 8b5c1e9

Browse files
committed
Added handling for multiply and division.
1 parent 617a089 commit 8b5c1e9

File tree

6 files changed

+1032
-558
lines changed

6 files changed

+1032
-558
lines changed

regression-tests/pure2-autodiff-higher-order.cpp2

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,34 @@ ad_test: @autodiff<"order=6"> @print type = {
2323
add_sub_2: (x: double, y: double) -> (r: double) = {
2424
r = x + y - x;
2525
}
26+
27+
mul_1: (x: double, y: double) -> (r: double) = {
28+
r = x * y;
29+
}
30+
31+
mul_2: (x: double, y: double) -> (r: double) = {
32+
r = x * y * x;
33+
}
34+
35+
div_1: (x: double, y: double) -> (r: double) = {
36+
r = x / y;
37+
}
38+
39+
div_2: (x: double, y: double) -> (r: double) = {
40+
r = x / y / y;
41+
}
42+
43+
mul_div_2: (x: double, y: double) -> (r: double) = {
44+
r = x * y / x;
45+
}
46+
47+
mul_add: (x: double, y: double) -> (r: double) = {
48+
r = x * (x + y);
49+
}
50+
51+
add_mul: (x: double, y: double) -> (r: double) = {
52+
r = x + x * y;
53+
}
2654
}
2755

2856
write_output: (func: std::string, x: double, x_d: ad_type, y: double, y_d: ad_type, ret) = {
@@ -47,13 +75,13 @@ main: () = {
4775
write_output("x - y", x, x_d, y, y_d, ad_test::sub_1_d(x, x_d, y, y_d));
4876
write_output("x - y - x", x, x_d, y, y_d, ad_test::sub_2_d(x, x_d, y, y_d));
4977
write_output("x + y - x", x, x_d, y, y_d, ad_test::add_sub_2_d(x, x_d, y, y_d));
50-
// write_output("x * y", x, x_d, y, y_d, ad_test::mul_1_d(x, x_d, y, y_d));
51-
// write_output("x * y * x", x, x_d, y, y_d, ad_test::mul_2_d(x, x_d, y, y_d));
52-
// write_output("x / y", x, x_d, y, y_d, ad_test::div_1_d(x, x_d, y, y_d));
53-
// write_output("x / y / y", x, x_d, y, y_d, ad_test::div_2_d(x, x_d, y, y_d));
54-
// write_output("x * y / x", x, x_d, y, y_d, ad_test::mul_div_2_d(x, x_d, y, y_d));
55-
// write_output("x * (x + y)", x, x_d, y, y_d, ad_test::mul_add_d(x, x_d, y, y_d));
56-
// write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_d(x, x_d, y, y_d));
78+
write_output("x * y", x, x_d, y, y_d, ad_test::mul_1_d(x, x_d, y, y_d));
79+
write_output("x * y * x", x, x_d, y, y_d, ad_test::mul_2_d(x, x_d, y, y_d));
80+
write_output("x / y", x, x_d, y, y_d, ad_test::div_1_d(x, x_d, y, y_d));
81+
write_output("x / y / y", x, x_d, y, y_d, ad_test::div_2_d(x, x_d, y, y_d));
82+
write_output("x * y / x", x, x_d, y, y_d, ad_test::mul_div_2_d(x, x_d, y, y_d));
83+
write_output("x * (x + y)", x, x_d, y, y_d, ad_test::mul_add_d(x, x_d, y, y_d));
84+
write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_d(x, x_d, y, y_d));
5785
// write_output("x * func(x, y)", x, x_d, y, y_d, ad_test::func_call_d(x, x_d, y, y_d));
5886
// write_output("sin(x + y)", x, x_d, y, y_d, ad_test::sin_call_d(x, x_d, y, y_d));
5987
// write_output("if branch", x, x_d, y, y_d, ad_test::if_branch_d(x, x_d, y, y_d));

regression-tests/test-results/gcc-13-c++2b/pure2-autodiff-higher-order.cpp.execution

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,59 @@ diff(x + y - x) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.
3838
d4 = 0.000000
3939
d5 = 0.000000
4040
d6 = 0.000000
41+
diff(x * y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
42+
r = 6.000000
43+
d1 = 7.000000
44+
d2 = 4.000000
45+
d3 = 0.000000
46+
d4 = 0.000000
47+
d5 = 0.000000
48+
d6 = 0.000000
49+
diff(x * y * x) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
50+
r = 12.000000
51+
d1 = 20.000000
52+
d2 = 22.000000
53+
d3 = 12.000000
54+
d4 = 0.000000
55+
d5 = 0.000000
56+
d6 = 0.000000
57+
diff(x / y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
58+
r = 0.666667
59+
d1 = -0.111111
60+
d2 = 0.148148
61+
d3 = -0.296296
62+
d4 = 0.790123
63+
d5 = -2.633745
64+
d6 = 10.534979
65+
diff(x / y / y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
66+
r = 0.222222
67+
d1 = -0.185185
68+
d2 = 0.296296
69+
d3 = -0.691358
70+
d4 = 2.106996
71+
d5 = -7.901235
72+
d6 = 35.116598
73+
diff(x * y / x) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
74+
r = 3.000000
75+
d1 = 2.000000
76+
d2 = 0.000000
77+
d3 = 0.000000
78+
d4 = 0.000000
79+
d5 = 0.000000
80+
d6 = 0.000000
81+
diff(x * (x + y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
82+
r = 10.000000
83+
d1 = 11.000000
84+
d2 = 6.000000
85+
d3 = 0.000000
86+
d4 = 0.000000
87+
d5 = 0.000000
88+
d6 = 0.000000
89+
diff(x + x * y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
90+
r = 8.000000
91+
d1 = 8.000000
92+
d2 = 4.000000
93+
d3 = 0.000000
94+
d4 = 0.000000
95+
d5 = 0.000000
96+
d6 = 0.000000

0 commit comments

Comments
 (0)