@@ -106,6 +106,11 @@ using if_else_branch_ret = double;
106106
107107#line 81 "pure2-autodiff.cpp2"
108108 public: [[nodiscard]] static auto direct_return (cpp2::impl::in<double > x, cpp2::impl::in<double > y) -> double;
109+ using intermediate_var_ret = double ;
110+
111+
112+ #line 85 "pure2-autodiff.cpp2"
113+ public: [[nodiscard]] static auto intermediate_var (cpp2::impl::in<double > x, cpp2::impl::in<double > y) -> intermediate_var_ret;
109114struct add_1_diff_ret { double r; double r_d; };
110115
111116
@@ -179,17 +184,21 @@ struct direct_return_diff_ret { double r; double r_d; };
179184
180185public: [[nodiscard]] static auto direct_return_diff (cpp2::impl::in<double > x, cpp2::impl::in<double > x_d, cpp2::impl::in<double > y, cpp2::impl::in<double > y_d) -> direct_return_diff_ret;
181186
187+ struct intermediate_var_diff_ret { double r; double r_d; };
188+
189+ public: [[nodiscard]] static auto intermediate_var_diff (cpp2::impl::in<double > x, cpp2::impl::in<double > x_d, cpp2::impl::in<double > y, cpp2::impl::in<double > y_d) -> intermediate_var_diff_ret;
190+
182191 public: ad_test() = default ;
183192 public: ad_test(ad_test const &) = delete ; /* No 'that' constructor, suppress copy */
184193 public: auto operator =(ad_test const &) -> void = delete ;
185194
186195
187- #line 84 "pure2-autodiff.cpp2"
196+ #line 91 "pure2-autodiff.cpp2"
188197};
189198
190199auto write_output (cpp2::impl::in<std::string> func, cpp2::impl::in<double > x, cpp2::impl::in<double > x_d, cpp2::impl::in<double > y, cpp2::impl::in<double > y_d, auto const & ret) -> void;
191200
192- #line 90 "pure2-autodiff.cpp2"
201+ #line 97 "pure2-autodiff.cpp2"
193202auto main () -> int;
194203
195204// === Cpp2 function definitions =================================================
@@ -329,6 +338,16 @@ auto main() -> int;
329338 return x + y;
330339 }
331340
341+ #line 85 "pure2-autodiff.cpp2"
342+ [[nodiscard]] auto ad_test::intermediate_var (cpp2::impl::in<double > x, cpp2::impl::in<double > y) -> intermediate_var_ret{
343+ cpp2::impl::deferred_init<double > r;
344+ #line 86 "pure2-autodiff.cpp2"
345+ double t {}; // TODO: change to x initializer when we have access to the initializer expression.
346+ t = x + y;
347+
348+ r.construct (cpp2::move (t));
349+ return std::move (r.value ()); }
350+
332351 [[nodiscard]] auto ad_test::add_1_diff (cpp2::impl::in<double > x, cpp2::impl::in<double > x_d, cpp2::impl::in<double > y, cpp2::impl::in<double > y_d) -> add_1_diff_ret{
333352 double r {0.0 };
334353 double r_d {0.0 };r_d = x_d + y_d;r = x + y;return { std::move (r), std::move (r_d) };
@@ -447,13 +466,21 @@ auto temp_1 {x - y}; r_d = cos(temp_1) * cpp2::move(temp_1_d);
447466 double r_d {};r_d = x_d + y_d;r = x + y;
448467 return { std::move (r), std::move (r_d) };
449468 }
469+ [[nodiscard]] auto ad_test::intermediate_var_diff (cpp2::impl::in<double > x, cpp2::impl::in<double > x_d, cpp2::impl::in<double > y, cpp2::impl::in<double > y_d) -> intermediate_var_diff_ret{
470+ double r {0.0 };
471+ double r_d {0.0 };
472+ double t_d {};
450473
451- #line 86 "pure2-autodiff.cpp2"
474+ double t {};
475+ t_d = x_d + y_d;t = x + y;r_d = cpp2::move (t_d);r = cpp2::move (t);return { std::move (r), std::move (r_d) };
476+ }
477+
478+ #line 93 "pure2-autodiff.cpp2"
452479auto write_output (cpp2::impl::in<std::string> func, cpp2::impl::in<double > x, cpp2::impl::in<double > x_d, cpp2::impl::in<double > y, cpp2::impl::in<double > y_d, auto const & ret) -> void{
453480 std::cout << " diff(" + cpp2::to_string (func) + " ) at (x = " + cpp2::to_string (x) + " , x_d = " + cpp2::to_string (x_d) + " , y = " + cpp2::to_string (y) + " , y_d = " + cpp2::to_string (y_d) + " ) = (r = " + cpp2::to_string (ret.r ) + " , r_d = " + cpp2::to_string (ret.r_d ) + " )" << std::endl;
454481}
455482
456- #line 90 "pure2-autodiff.cpp2"
483+ #line 97 "pure2-autodiff.cpp2"
457484auto main () -> int{
458485
459486 double x {2.0 };
@@ -477,6 +504,7 @@ auto main() -> int{
477504 write_output (" sin(x + y)" , x, x_d, y, y_d, ad_test::sin_call_diff (x, x_d, y, y_d));
478505 write_output (" if branch" , x, x_d, y, y_d, ad_test::if_branch_diff (x, x_d, y, y_d));
479506 write_output (" if else branch" , x, x_d, y, y_d, ad_test::if_else_branch_diff (x, x_d, y, y_d));
480- write_output (" direct return" , x, x_d, y, y_d, ad_test::direct_return_diff (cpp2::move (x), cpp2::move (x_d), cpp2::move (y), cpp2::move (y_d)));
507+ write_output (" direct return" , x, x_d, y, y_d, ad_test::direct_return_diff (x, x_d, y, y_d));
508+ write_output (" intermediate var" , x, x_d, y, y_d, ad_test::intermediate_var_diff (cpp2::move (x), cpp2::move (x_d), cpp2::move (y), cpp2::move (y_d)));
481509}
482510
0 commit comments