@@ -23,6 +23,34 @@ ad_test: @autodiff<"order=6"> @print type = {
2323 add_sub_2: (x: double, y: double) -> (r: double) = {
2424 r = x + y - x;
2525 }
26+
27+ mul_1: (x: double, y: double) -> (r: double) = {
28+ r = x * y;
29+ }
30+
31+ mul_2: (x: double, y: double) -> (r: double) = {
32+ r = x * y * x;
33+ }
34+
35+ div_1: (x: double, y: double) -> (r: double) = {
36+ r = x / y;
37+ }
38+
39+ div_2: (x: double, y: double) -> (r: double) = {
40+ r = x / y / y;
41+ }
42+
43+ mul_div_2: (x: double, y: double) -> (r: double) = {
44+ r = x * y / x;
45+ }
46+
47+ mul_add: (x: double, y: double) -> (r: double) = {
48+ r = x * (x + y);
49+ }
50+
51+ add_mul: (x: double, y: double) -> (r: double) = {
52+ r = x + x * y;
53+ }
2654}
2755
2856write_output: (func: std::string, x: double, x_d: ad_type, y: double, y_d: ad_type, ret) = {
@@ -47,13 +75,13 @@ main: () = {
4775 write_output("x - y", x, x_d, y, y_d, ad_test::sub_1_d(x, x_d, y, y_d));
4876 write_output("x - y - x", x, x_d, y, y_d, ad_test::sub_2_d(x, x_d, y, y_d));
4977 write_output("x + y - x", x, x_d, y, y_d, ad_test::add_sub_2_d(x, x_d, y, y_d));
50- // write_output("x * y", x, x_d, y, y_d, ad_test::mul_1_d(x, x_d, y, y_d));
51- // write_output("x * y * x", x, x_d, y, y_d, ad_test::mul_2_d(x, x_d, y, y_d));
52- // write_output("x / y", x, x_d, y, y_d, ad_test::div_1_d(x, x_d, y, y_d));
53- // write_output("x / y / y", x, x_d, y, y_d, ad_test::div_2_d(x, x_d, y, y_d));
54- // write_output("x * y / x", x, x_d, y, y_d, ad_test::mul_div_2_d(x, x_d, y, y_d));
55- // write_output("x * (x + y)", x, x_d, y, y_d, ad_test::mul_add_d(x, x_d, y, y_d));
56- // write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_d(x, x_d, y, y_d));
78+ write_output("x * y", x, x_d, y, y_d, ad_test::mul_1_d(x, x_d, y, y_d));
79+ write_output("x * y * x", x, x_d, y, y_d, ad_test::mul_2_d(x, x_d, y, y_d));
80+ write_output("x / y", x, x_d, y, y_d, ad_test::div_1_d(x, x_d, y, y_d));
81+ write_output("x / y / y", x, x_d, y, y_d, ad_test::div_2_d(x, x_d, y, y_d));
82+ write_output("x * y / x", x, x_d, y, y_d, ad_test::mul_div_2_d(x, x_d, y, y_d));
83+ write_output("x * (x + y)", x, x_d, y, y_d, ad_test::mul_add_d(x, x_d, y, y_d));
84+ write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_d(x, x_d, y, y_d));
5785// write_output("x * func(x, y)", x, x_d, y, y_d, ad_test::func_call_d(x, x_d, y, y_d));
5886// write_output("sin(x + y)", x, x_d, y, y_d, ad_test::sin_call_d(x, x_d, y, y_d));
5987// write_output("if branch", x, x_d, y, y_d, ad_test::if_branch_d(x, x_d, y, y_d));
0 commit comments