Skip to content

Commit 965af06

Browse files
committed
Handling of expresson terms.
1 parent 97470ba commit 965af06

File tree

6 files changed

+500
-430
lines changed

6 files changed

+500
-430
lines changed

regression-tests/pure2-autodiff.cpp2

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ ad_test: @autodiff @print type = {
4444
mul_add: (x: double, y: double) -> (r: double) = {
4545
r = x * (x + y);
4646
}
47+
48+
add_mul: (x: double, y: double) -> (r: double) = {
49+
r = x + x * y;
50+
}
4751
}
4852

4953
write_output: (func: std::string, x: double, x_d: double, y: double, y_d: double, ret) = {
@@ -68,4 +72,5 @@ main: () = {
6872
write_output("x / y / y", x, x_d, y, y_d, ad_test::div_2_diff(x, x_d, y, y_d));
6973
write_output("x * y / x", x, x_d, y, y_d, ad_test::mul_div_2_diff(x, x_d, y, y_d));
7074
write_output("x * (x + y)", x, x_d, y, y_d, ad_test::mul_add_diff(x, x_d, y, y_d));
75+
write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_diff(x, x_d, y, y_d));
7176
}

regression-tests/test-results/gcc-13-c++2b/pure2-autodiff.cpp.execution

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ diff(x / y) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r
99
diff(x / y / y) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 0.222222, r_d = -0.185185)
1010
diff(x * y / x) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 3.000000, r_d = 2.000000)
1111
diff(x * (x + y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 10.000000, r_d = 11.000000)
12+
diff(x + x * y) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 8.000000, r_d = 8.000000)

regression-tests/test-results/pure2-autodiff.cpp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ using mul_add_ret = double;
7373

7474
#line 44 "pure2-autodiff.cpp2"
7575
public: [[nodiscard]] static auto mul_add(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> mul_add_ret;
76+
using add_mul_ret = double;
77+
78+
79+
#line 48 "pure2-autodiff.cpp2"
80+
public: [[nodiscard]] static auto add_mul(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> add_mul_ret;
7681
struct add_1_diff_ret { double r; double r_d; };
7782

7883

@@ -118,17 +123,21 @@ struct mul_add_diff_ret { double r; double r_d; };
118123

119124
public: [[nodiscard]] static auto mul_add_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> mul_add_diff_ret;
120125

126+
struct add_mul_diff_ret { double r; double r_d; };
127+
128+
public: [[nodiscard]] static auto add_mul_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> add_mul_diff_ret;
129+
121130
public: ad_test() = default;
122131
public: ad_test(ad_test const&) = delete; /* No 'that' constructor, suppress copy */
123132
public: auto operator=(ad_test const&) -> void = delete;
124133

125134

126-
#line 47 "pure2-autodiff.cpp2"
135+
#line 51 "pure2-autodiff.cpp2"
127136
};
128137

129138
auto write_output(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d, auto const& ret) -> void;
130139

131-
#line 53 "pure2-autodiff.cpp2"
140+
#line 57 "pure2-autodiff.cpp2"
132141
auto main() -> int;
133142

134143
//=== Cpp2 function definitions =================================================
@@ -212,6 +221,13 @@ auto main() -> int;
212221
r.construct(x * (x + y));
213222
return std::move(r.value()); }
214223

224+
#line 48 "pure2-autodiff.cpp2"
225+
[[nodiscard]] auto ad_test::add_mul(cpp2::impl::in<double> x, cpp2::impl::in<double> y) -> add_mul_ret{
226+
cpp2::impl::deferred_init<double> r;
227+
#line 49 "pure2-autodiff.cpp2"
228+
r.construct(x + x * y);
229+
return std::move(r.value()); }
230+
215231
[[nodiscard]] auto ad_test::add_1_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> add_1_diff_ret{
216232
double r {0.0};
217233
double r_d {0.0};r_d = x_d + y_d;r = x + y;
@@ -275,13 +291,20 @@ auto temp_1_d {x_d + y_d};
275291
auto temp_1 {x + y}; r_d = x * cpp2::move(temp_1_d) + temp_1 * x_d;r = x * cpp2::move(temp_1);
276292
return { std::move(r), std::move(r_d) };
277293
}
294+
[[nodiscard]] auto ad_test::add_mul_diff(cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d) -> add_mul_diff_ret{
295+
double r {0.0};
296+
double r_d {0.0};
297+
auto temp_1_d {x * y_d + y * x_d};
298+
auto temp_1 {x * y}; r_d = x_d + cpp2::move(temp_1_d);r = x + cpp2::move(temp_1);
299+
return { std::move(r), std::move(r_d) };
300+
}
278301

279-
#line 49 "pure2-autodiff.cpp2"
302+
#line 53 "pure2-autodiff.cpp2"
280303
auto write_output(cpp2::impl::in<std::string> func, cpp2::impl::in<double> x, cpp2::impl::in<double> x_d, cpp2::impl::in<double> y, cpp2::impl::in<double> y_d, auto const& ret) -> void{
281304
std::cout << "diff(" + cpp2::to_string(func) + ") at (x = " + cpp2::to_string(x) + ", x_d = " + cpp2::to_string(x_d) + ", y = " + cpp2::to_string(y) + ", y_d = " + cpp2::to_string(y_d) + ") = (r = " + cpp2::to_string(ret.r) + ", r_d = " + cpp2::to_string(ret.r_d) + ")" << std::endl;
282305
}
283306

284-
#line 53 "pure2-autodiff.cpp2"
307+
#line 57 "pure2-autodiff.cpp2"
285308
auto main() -> int{
286309

287310
double x {2.0};
@@ -299,6 +322,7 @@ auto main() -> int{
299322
write_output("x / y", x, x_d, y, y_d, ad_test::div_1_diff(x, x_d, y, y_d));
300323
write_output("x / y / y", x, x_d, y, y_d, ad_test::div_2_diff(x, x_d, y, y_d));
301324
write_output("x * y / x", x, x_d, y, y_d, ad_test::mul_div_2_diff(x, x_d, y, y_d));
302-
write_output("x * (x + y)", x, x_d, y, y_d, ad_test::mul_add_diff(cpp2::move(x), cpp2::move(x_d), cpp2::move(y), cpp2::move(y_d)));
325+
write_output("x * (x + y)", x, x_d, y, y_d, ad_test::mul_add_diff(x, x_d, y, y_d));
326+
write_output("x + x * y", x, x_d, y, y_d, ad_test::add_mul_diff(cpp2::move(x), cpp2::move(x_d), cpp2::move(y), cpp2::move(y_d)));
303327
}
304328

regression-tests/test-results/pure2-autodiff.cpp2.output

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,15 @@ ad_test:/* @autodiff @print */ type =
101101
return;
102102
}
103103

104+
add_mul:(
105+
in x: double,
106+
in y: double,
107+
) -> (out r: double, ) =
108+
{
109+
r = x + x * y;
110+
return;
111+
}
112+
104113
add_1_diff:(
105114
in x: double,
106115
in x_d: double,
@@ -273,6 +282,23 @@ ad_test:/* @autodiff @print */ type =
273282
r = x * temp_1;
274283
return;
275284
}
285+
286+
add_mul_diff:(
287+
in x: double,
288+
in x_d: double,
289+
in y: double,
290+
in y_d: double,
291+
) -> (
292+
out r: double = 0.0,
293+
out r_d: double = 0.0,
294+
) =
295+
{
296+
temp_1_d: _ = x * y_d + y * x_d;
297+
temp_1: _ = x * y;
298+
r_d = x_d + temp_1_d;
299+
r = x + temp_1;
300+
return;
301+
}
276302
}
277303
ok (all Cpp2, passes safety checks)
278304

0 commit comments

Comments
 (0)