@@ -197,9 +197,49 @@ ad_test_reverse: @autodiff<"reverse"> @print type = {
197197 r = x + y;
198198 }
199199
200+ add_2: (x: double, y: double) -> (r: double) = {
201+ r = x + y + x;
202+ }
203+
204+ sub_1: (x: double, y: double) -> (r: double) = {
205+ r = x - y;
206+ }
207+
208+ sub_2: (x: double, y: double) -> (r: double) = {
209+ r = x - y - x;
210+ }
211+
212+ add_sub_2: (x: double, y: double) -> (r: double) = {
213+ r = x + y - x;
214+ }
215+
200216 mul_1: (x: double, y: double) -> (r: double) = {
201217 r = x * y;
202218 }
219+
220+ mul_2: (x: double, y: double) -> (r: double) = {
221+ r = x * y * x;
222+ }
223+
224+ div_1: (x: double, y: double) -> (r: double) = {
225+ r = x / y;
226+ }
227+
228+ div_2: (x: double, y: double) -> (r: double) = {
229+ r = x / y / y;
230+ }
231+
232+ mul_div_2: (x: double, y: double) -> (r: double) = {
233+ r = x * y / x;
234+ }
235+
236+ mul_add: (x: double, y: double) -> (r: double) = {
237+ r = x * (x + y);
238+ }
239+
240+ add_mul: (x: double, y: double) -> (r: double) = {
241+ r = x + x * y;
242+ }
203243}
204244}
205245
@@ -213,7 +253,8 @@ write_output: (func: std::string, x: double, x_d: double, y: double, y_d: double
213253 std::cout << "diff((func)$) at (x = (x)$, x_d = (x_d)$, y = (y)$, y_d = (y_d)$) = (r = (ret.r)$, r_d = (ret.r_d)$)" << std::endl;
214254}
215255
216- write_output_reverse: (func: std::string, x: double, inout x_b: double, y: double, inout y_b: double, in r_b: double, ret) = {
256+ write_output_reverse: (func: std::string, x: double, inout x_b: double, y: double, inout y_b: double, inout r_b: double, ret) = {
257+ r_b = 1.0;
217258 std::cout << "diff((func)$) at (x = (x)$, y = (y)$, r_b = (r_b)$) = (r = (ret)$, x_b = (x_b)$, y_b = (y_b)$)" << std::endl;
218259 x_b = 0.0;
219260 y_b = 0.0;
@@ -262,7 +303,17 @@ main: () = {
262303 w_b: double = 1.0;
263304
264305 write_output_reverse("x + y", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::add_1_b(x, x_b, y, y_b, w_b));
306+ write_output_reverse("x + y + x", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::add_2_b(x, x_b, y, y_b, w_b));
307+ write_output_reverse("x - y", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::sub_1_b(x, x_b, y, y_b, w_b));
308+ write_output_reverse("x - y - x", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::sub_2_b(x, x_b, y, y_b, w_b));
309+ write_output_reverse("x + y - x", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::add_sub_2_b(x, x_b, y, y_b, w_b));
265310 write_output_reverse("x * y", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::mul_1_b(x, x_b, y, y_b, w_b));
311+ write_output_reverse("x * y * x", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::mul_2_b(x, x_b, y, y_b, w_b));
312+ write_output_reverse("x / y", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::div_1_b(x, x_b, y, y_b, w_b));
313+ write_output_reverse("x / y / y", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::div_2_b(x, x_b, y, y_b, w_b));
314+ write_output_reverse("x * y / x", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::mul_div_2_b(x, x_b, y, y_b, w_b));
315+ write_output_reverse("x * (x + y)", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::mul_add_b(x, x_b, y, y_b, w_b));
316+ write_output_reverse("x + x * y", x, x_b, y, y_b, w_b, ad_name::ad_test_reverse::add_mul_b(x, x_b, y, y_b, w_b));
266317
267318 _ = x_b;
268319 _ = y_b;
0 commit comments