Skip to content

Commit dad71ef

Browse files
committed
Added handling of direct return.
1 parent 00be8ff commit dad71ef

File tree

6 files changed

+466
-394
lines changed

6 files changed

+466
-394
lines changed

regression-tests/pure2-autodiff.cpp2

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ ad_test: @autodiff @print type = {
7777
r = x;
7878
}
7979
}
80+
81+
direct_return: (x: double, y: double) -> double = {
82+
return x + y;
83+
}
8084
}
8185

8286
write_output: (func: std::string, x: double, x_d: double, y: double, y_d: double, ret) = {
@@ -106,4 +110,5 @@ main: () = {
106110
write_output("sin(x + y)", x, x_d, y, y_d, ad_test::sin_call_diff(x, x_d, y, y_d));
107111
write_output("if branch", x, x_d, y, y_d, ad_test::if_branch_diff(x, x_d, y, y_d));
108112
write_output("if else branch", x, x_d, y, y_d, ad_test::if_else_branch_diff(x, x_d, y, y_d));
113+
write_output("direct return", x, x_d, y, y_d, ad_test::direct_return_diff(x, x_d, y, y_d));
109114
}

regression-tests/test-results/gcc-13-c++2b/pure2-autodiff.cpp.execution

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ diff(x * func(x, y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000
1414
diff(sin(x + y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = -0.841471, r_d = -0.540302)
1515
diff(if branch) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 2.000000, r_d = 1.000000)
1616
diff(if else branch) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 2.000000, r_d = 1.000000)
17+
diff(direct return) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 5.000000, r_d = 3.000000)

regression-tests/test-results/pure2-autodiff.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ using if_else_branch_ret = double;
103103

104104
#line 72 "pure2-autodiff.cpp2"
105105
public: [[nodiscard]] static auto if_else_branch(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> if_else_branch_ret;
106+
107+
#line 81 "pure2-autodiff.cpp2"
108+
public: [[nodiscard]] static auto direct_return(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> double;
106109
struct add_1_diff_ret { double r; double r_d; };
107110

108111

@@ -172,17 +175,21 @@ struct if_else_branch_diff_ret { double r; double r_d; };
172175

173176
public: [[nodiscard]] static auto if_else_branch_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> if_else_branch_diff_ret;
174177

178+
struct direct_return_diff_ret { double r; double r_d; };
179+
180+
public: [[nodiscard]] static auto direct_return_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> direct_return_diff_ret;
181+
175182
public: ad_test() = default;
176183
public: ad_test(ad_test const&) = delete; /* No 'that' constructor, suppress copy */
177184
public: auto operator=(ad_test const&) -> void = delete;
178185

179186

180-
#line 80 "pure2-autodiff.cpp2"
187+
#line 84 "pure2-autodiff.cpp2"
181188
};
182189

183190
auto write_output(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d, auto const& ret) -> void;
184191

185-
#line 86 "pure2-autodiff.cpp2"
192+
#line 90 "pure2-autodiff.cpp2"
186193
auto main() -> int;
187194

188195
//=== Cpp2 function definitions =================================================
@@ -317,6 +324,11 @@ auto main() -> int;
317324
}return std::move(r.value());
318325
}
319326

327+
#line 81 "pure2-autodiff.cpp2"
328+
[[nodiscard]] auto ad_test::direct_return(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> double{
329+
return x + y;
330+
}
331+
320332
[[nodiscard]] auto ad_test::add_1_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> add_1_diff_ret{
321333
double r {0.0};
322334
double r_d {0.0};r_d = x_d + y_d;r = x + y;return { std::move(r), std::move(r_d) };
@@ -430,12 +442,18 @@ auto temp_1 {x - y}; r_d = cos(temp_1) * cpp2::move(temp_1_d);
430442
return { std::move(r), std::move(r_d) };
431443
}
432444

433-
#line 82 "pure2-autodiff.cpp2"
445+
[[nodiscard]] auto ad_test::direct_return_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> direct_return_diff_ret{
446+
double r {};
447+
double r_d {};r_d = x_d + y_d;r = x + y;
448+
return { std::move(r), std::move(r_d) };
449+
}
450+
451+
#line 86 "pure2-autodiff.cpp2"
434452
auto write_output(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d, auto const& ret) -> void{
435453
std::cout << "diff(" + cpp2::to_string(func) + ") at (x = " + cpp2::to_string(x) + ", x_d = " + cpp2::to_string(x_d) + ", y = " + cpp2::to_string(y) + ", y_d = " + cpp2::to_string(y_d) + ") = (r = " + cpp2::to_string(ret.r) + ", r_d = " + cpp2::to_string(ret.r_d) + ")" << std::endl;
436454
}
437455

438-
#line 86 "pure2-autodiff.cpp2"
456+
#line 90 "pure2-autodiff.cpp2"
439457
auto main() -> int{
440458

441459
double x {2.0};
@@ -458,6 +476,7 @@ auto main() -> int{
458476
write_output("x * func(x, y)", x, x_d, y, y_d, ad_test::func_call_diff(x, x_d, y, y_d));
459477
write_output("sin(x + y)", x, x_d, y, y_d, ad_test::sin_call_diff(x, x_d, y, y_d));
460478
write_output("if branch", x, x_d, y, y_d, ad_test::if_branch_diff(x, x_d, y, y_d));
461-
write_output("if else branch", x, x_d, y, y_d, ad_test::if_else_branch_diff(cpp2::move(x), cpp2::move(x_d), cpp2::move(y), cpp2::move(y_d)));
479+
write_output("if else branch", x, x_d, y, y_d, ad_test::if_else_branch_diff(x, x_d, y, y_d));
480+
write_output("direct return", x, x_d, y, y_d, ad_test::direct_return_diff(cpp2::move(x), cpp2::move(x_d), cpp2::move(y), cpp2::move(y_d)));
462481
}
463482

regression-tests/test-results/pure2-autodiff.cpp2.output

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,14 @@ ad_test:/* @autodiff @print */ type =
166166
return;
167167
}
168168

169+
direct_return:(
170+
in x: double,
171+
in y: double,
172+
) -> move double =
173+
{
174+
return x + y;
175+
}
176+
169177
add_1_diff:(
170178
in x: double,
171179
in x_d: double,
@@ -451,6 +459,21 @@ ad_test:/* @autodiff @print */ type =
451459
}
452460
return;
453461
}
462+
463+
direct_return_diff:(
464+
in x: double,
465+
in x_d: double,
466+
in y: double,
467+
in y_d: double,
468+
) -> (
469+
out r: double = (),
470+
out r_d: double = (),
471+
) =
472+
{
473+
r_d = x_d + y_d;
474+
r = x + y;
475+
return;
476+
}
454477
}
455478
ok (all Cpp2, passes safety checks)
456479

0 commit comments

Comments
 (0)