@@ -63,6 +63,96 @@ ad_test: @autodiff<"order=6"> @print type = {
6363 sin_call: (x: double, y: double) -> (r: double) = {
6464 r = sin(x - y);
6565 }
66+
67+ if_branch: (x: double, y: double) -> (r: double) = {
68+ r = x;
69+
70+ if x < 0.0 {
71+ r = y;
72+ }
73+ }
74+
75+ if_else_branch: (x: double, y: double) -> (r: double) = {
76+ if x < 0.0 {
77+ r = y;
78+ }
79+ else {
80+ r = x;
81+ }
82+ }
83+
84+ direct_return: (x: double, y: double) -> double = {
85+ return x + y;
86+ }
87+
88+ intermediate_var: (x: double, y: double) -> (r: double) = {
89+ t: double = x + y;
90+
91+ r = t;
92+ }
93+
94+ intermediate_passive_var: (x: double, y: double) -> (r: double) = {
95+ i: int = (); // TODO: Handle as passive when type information on call side is available.
96+ r = x + y;
97+ i = 2;
98+
99+ _ = i;
100+ }
101+
102+ intermediate_untyped: (x: double, y: double) -> (r: double) = {
103+ t := 0.0;
104+ t = x + y;
105+
106+ r = t;
107+ }
108+
109+ intermediate_default_init: (x: double, y: double) -> (r: double) = {
110+ t: double = ();
111+ t = x + y;
112+
113+ r = t;
114+ }
115+
116+ intermediate_no_init: (x: double, y: double) -> (r: double) = {
117+ t: double;
118+ t = x + y;
119+
120+ r = t;
121+ }
122+
123+ while_loop: (x: double, y: double) -> (r: double) = {
124+ i: int = 0;
125+
126+ r = x;
127+ while i < 2 next (i += 1) {
128+ r = r + y ;
129+ }
130+ }
131+
132+ do_while_loop: (x: double, y: double) -> (r: double) = {
133+ i: int = 0;
134+
135+ r = x;
136+ do {
137+ r = r + y ;
138+ }
139+ next (i += 1)
140+ while i < 2;
141+ }
142+
143+ for_loop: (x: double, y: double) -> (r: double) = {
144+ v: std::vector<double> = ();
145+
146+ v.push_back(x);
147+ v.push_back(y);
148+
149+ r = 0.0;
150+ for v
151+ do (t)
152+ {
153+ r = r + t;
154+ }
155+ }
66156}
67157
68158write_output: (func: std::string, x: double, x_d: ad_type, y: double, y_d: ad_type, ret) = {
@@ -96,15 +186,15 @@ main: () = {
96186 write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_d(x, x_d, y, y_d));
97187 write_output("x * func(x, y)", x, x_d, y, y_d, ad_test::func_call_d(x, x_d, y, y_d));
98188 write_output("sin(x - y)", x, x_d, y, y_d, ad_test::sin_call_d(x, x_d, y, y_d));
99- // write_output("if branch", x, x_d, y, y_d, ad_test::if_branch_d(x, x_d, y, y_d));
100- // write_output("if else branch", x, x_d, y, y_d, ad_test::if_else_branch_d(x, x_d, y, y_d));
101- // write_output("direct return", x, x_d, y, y_d, ad_test::direct_return_d(x, x_d, y, y_d));
102- // write_output("intermediate var", x, x_d, y, y_d, ad_test::intermediate_var_d(x, x_d, y, y_d));
103- // write_output("intermediate passive var", x, x_d, y, y_d, ad_test::intermediate_passive_var_d(x, x_d, y, y_d));
104- // write_output("intermediate untyped", x, x_d, y, y_d, ad_test::intermediate_untyped_d(x, x_d, y, y_d));
105- // write_output("intermediate default init", x, x_d, y, y_d, ad_test::intermediate_default_init_d(x, x_d, y, y_d));
106- // write_output("intermediate no init", x, x_d, y, y_d, ad_test::intermediate_no_init_d(x, x_d, y, y_d));
107- // write_output("while loop", x, x_d, y, y_d, ad_test::while_loop_d(x, x_d, y, y_d));
108- // write_output("do while loop", x, x_d, y, y_d, ad_test::do_while_loop_d(x, x_d, y, y_d));
109- // write_output("for loop", x, x_d, y, y_d, ad_test::for_loop_d(x, x_d, y, y_d));
189+ write_output("if branch", x, x_d, y, y_d, ad_test::if_branch_d(x, x_d, y, y_d));
190+ write_output("if else branch", x, x_d, y, y_d, ad_test::if_else_branch_d(x, x_d, y, y_d));
191+ write_output("direct return", x, x_d, y, y_d, ad_test::direct_return_d(x, x_d, y, y_d));
192+ write_output("intermediate var", x, x_d, y, y_d, ad_test::intermediate_var_d(x, x_d, y, y_d));
193+ write_output("intermediate passive var", x, x_d, y, y_d, ad_test::intermediate_passive_var_d(x, x_d, y, y_d));
194+ write_output("intermediate untyped", x, x_d, y, y_d, ad_test::intermediate_untyped_d(x, x_d, y, y_d));
195+ write_output("intermediate default init", x, x_d, y, y_d, ad_test::intermediate_default_init_d(x, x_d, y, y_d));
196+ write_output("intermediate no init", x, x_d, y, y_d, ad_test::intermediate_no_init_d(x, x_d, y, y_d));
197+ write_output("while loop", x, x_d, y, y_d, ad_test::while_loop_d(x, x_d, y, y_d));
198+ write_output("do while loop", x, x_d, y, y_d, ad_test::do_while_loop_d(x, x_d, y, y_d));
199+ write_output("for loop", x, x_d, y, y_d, ad_test::for_loop_d(x, x_d, y, y_d));
110200}
0 commit comments