Skip to content

Commit beb3629

Browse files
committed
Reverse differentiation of additive expressions.
1 parent e48a43f commit beb3629

File tree

2 files changed

+87
-50
lines changed

2 files changed

+87
-50
lines changed

include/cpp2taylor.h2

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,18 @@ taylor: <R, dim: int> type = {
5454
return v[i - 1];
5555
}
5656

57+
// Overload for reverse AD.
58+
operator+=: (inout this, o: taylor) -> forward_ref _ = {
59+
this = this + o;
60+
return this;
61+
}
62+
63+
// Overload for reverse AD.
64+
operator-=: (inout this, o: taylor) -> forward_ref _ = {
65+
this = this - o;
66+
return this;
67+
}
68+
5769
// Overload for simple handling of connected adds.
5870
operator+: (this, o: taylor) -> taylor = {
5971
return add(o, 0.0, 0.0); // Primal values are not required.

source/reflect.h2

Lines changed: 75 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -4383,19 +4383,22 @@ autodiff_context: type = {
43834383
autodiff_diff_code: type = {
43844384
public ctx: *autodiff_context;
43854385

4386-
public fwd: std::string = "";
4387-
public rws: std::string = "";
4386+
public fwd : std::string = "";
4387+
public rws_primal : std::string = "";
4388+
public rws_backprop: std::string = "";
43884389

43894390
operator=:(out this, ctx_: *autodiff_context) = {
43904391
ctx = ctx_;
43914392
}
43924393

4393-
add_forward: (inout this, v: std::string) = { if ctx*.is_forward() { fwd += v; }}
4394-
add_reverse: (inout this, v: std::string) = { if ctx*.is_reverse() { rws += v; }}
4394+
add_forward : (inout this, v: std::string) = { if ctx*.is_forward() { fwd += v; }}
4395+
add_reverse_primal : (inout this, v: std::string) = { if ctx*.is_reverse() { rws_primal += v; }}
4396+
add_reverse_backprop: (inout this, v: std::string) = { if ctx*.is_reverse() { rws_backprop += v; }}
43954397

43964398
reset: (inout this) = {
4397-
fwd = "";
4398-
rws = "";
4399+
fwd = "";
4400+
rws_primal = "";
4401+
rws_backprop = "";
43994402
}
44004403

44014404
// Temporary: TODO: remove when everything has been adapted to primal, fwd, rws pushes.
@@ -4437,7 +4440,9 @@ autodiff_handler_base: type = {
44374440

44384441
// Temporary: TODO: remove when everything has been adapted to primal, fwd, rws pushes.
44394442
append: (inout this, in_ref o: autodiff_handler_base) = {
4440-
diff += o.diff.fwd;
4443+
diff.fwd += o.diff.fwd;
4444+
diff.rws_primal += o.diff.rws_primal;
4445+
diff.rws_backprop = o.diff.rws_backprop + diff.rws_backprop;
44414446
}
44424447
}
44434448

@@ -4449,28 +4454,35 @@ autodiff_expression_handler: type = {
44494454

44504455
public primal_expr: std::string = "";
44514456
public fwd_expr : std::string = "";
4457+
public rws_expr : std::string = "";
44524458

44534459
operator=: (out this, ctx_: *autodiff_context) = {
44544460
autodiff_handler_base = (ctx_);
44554461
}
44564462

4457-
add_suffix_if_not_wildcard: (this, lhs: std::string) -> std::string = {
4463+
add_suffix_if_not_wildcard: (this, lhs: std::string, suffix: std::string) -> std::string = {
44584464
if "_" == lhs {
44594465
return lhs;
44604466
}
44614467
else {
4462-
return lhs + ctx*.fwd_suffix;
4468+
return lhs + suffix;
44634469
}
44644470
}
44654471

4466-
gen_assignment: (inout this, lhs: std::string, lhs_d: std::string, rhs: std::string, rhs_d: std::string) = {
4467-
diff += "(lhs_d)$ = (rhs_d)$;\n";
4468-
diff += "(lhs)$ = (rhs)$;\n";
4472+
gen_assignment: (inout this, lhs: std::string, lhs_d: std::string, lhs_b: std::string, rhs: std::string, rhs_d: std::string, rhs_b: std::string) = {
4473+
diff.add_forward("(lhs_d)$ = (rhs_d)$;\n");
4474+
diff.add_forward("(lhs)$ = (rhs)$;\n");
4475+
4476+
if ctx*.is_taylor() {
4477+
diff.add_reverse_primal("(lhs_d)$ = (rhs_d)$;\n");
4478+
}
4479+
diff.add_reverse_primal("(lhs)$ = (rhs)$;\n");
4480+
diff.add_reverse_backprop(string_util::replace_all(rhs_b, "_rb_", lhs_b));
44694481
}
4470-
gen_assignment: (inout this, lhs: std::string, lhs_d: std::string)
4471-
= gen_assignment(lhs, lhs_d, primal_expr, fwd_expr);
4482+
gen_assignment: (inout this, lhs: std::string, lhs_d: std::string, lhs_b: std::string)
4483+
= gen_assignment(lhs, lhs_d, lhs_b, primal_expr, fwd_expr, rws_expr);
44724484
gen_assignment: (inout this, lhs: std::string)
4473-
= gen_assignment(lhs, add_suffix_if_not_wildcard(lhs), primal_expr, fwd_expr);
4485+
= gen_assignment(lhs, add_suffix_if_not_wildcard(lhs, ctx*.fwd_suffix), add_suffix_if_not_wildcard(lhs, ctx*.rws_suffix), primal_expr, fwd_expr, rws_expr);
44744486

44754487

44764488
gen_declaration: (inout this, lhs: std::string, lhs_d: std::string, rhs: std::string, rhs_d: std::string, type: std::string, type_d: std::string) = {
@@ -4486,43 +4498,46 @@ autodiff_expression_handler: type = {
44864498

44874499

44884500

4489-
primal_fwd_name: @struct type = {
4501+
primal_fwd_rws_name: @struct type = {
44904502
primal: std::string = "";
44914503
fwd : std::string = "";
4504+
rws : std::string = "";
44924505
}
44934506

4494-
handle_expression_list: (inout this, list: meta::expression_list) -> std::vector<primal_fwd_name> = {
4495-
args : std::vector<primal_fwd_name> = ();
4507+
handle_expression_list: (inout this, list: meta::expression_list) -> std::vector<primal_fwd_rws_name> = {
4508+
args : std::vector<primal_fwd_rws_name> = ();
44964509
for list.get_expressions() do (expr) {
44974510
args.push_back(handle_expression_term(expr));
44984511
}
44994512

45004513
return args;
45014514
}
45024515

4503-
handle_expression_term :(inout this, term) -> primal_fwd_name = {
4516+
handle_expression_term :(inout this, term) -> primal_fwd_rws_name = {
45044517
if term.is_identifier() {
45054518
primal := term.to_string();
45064519
fwd := primal + ctx*.fwd_suffix;
4520+
rws := primal + ctx*.rws_suffix;
45074521

45084522
decl := ctx*.lookup_variable_declaration(primal);
45094523
if decl.is_member {
45104524
fwd = "this(ctx*.fwd_suffix)$." + fwd;
4525+
rws = "this(ctx*.rws_suffix)$." + rws;
45114526
}
4512-
return (primal, fwd);
4527+
return (primal, fwd, rws);
45134528
}
45144529
else if term.is_expression_list() {
45154530
exprs := term..as_expression_list()..get_expressions();
45164531
if exprs.ssize() != 1 {
45174532
term.error("Can not handle multiple expressions. (term.to_string())");
4518-
return ("error", "");
4533+
return ("error", "", "");
45194534
}
45204535
expr := exprs[0];
45214536
bin_expr := expr..as_assignment_expression();
45224537

45234538
if bin_expr.terms_size() != 0 {
45244539
term.error("Can not handle assign expr inside of expression. (expr.to_string())$");
4525-
return ("error", "");
4540+
return ("error", "", "");
45264541
}
45274542

45284543
ad : autodiff_expression_handler = (ctx);
@@ -4531,7 +4546,7 @@ autodiff_expression_handler: type = {
45314546
ad.gen_declaration(t, "double"); // TODO: get type of expression
45324547
append(ad);
45334548

4534-
r : primal_fwd_name = (t, t + ctx*.fwd_suffix); // TODO: Check why on return (t, t + ctx*.fwd_suffix) the primal is initialized empty. Probably because of the move(t)
4549+
r : primal_fwd_rws_name = (t, t + ctx*.fwd_suffix, t + ctx*.rws_suffix); // TODO: Check why on return (t, t + ctx*.fwd_suffix) the primal is initialized empty. Probably because of the move(t)
45354550
_ = t;
45364551
return r;
45374552
}
@@ -4545,7 +4560,7 @@ autodiff_expression_handler: type = {
45454560
ad.gen_declaration(t, "double"); // TODO: get type of expression
45464561
append(ad);
45474562

4548-
r : primal_fwd_name = (t, t + ctx*.fwd_suffix); // TODO: Check why on return (t, t + ctx*.fwd_suffix) the primal is initialized empty. Probably because of the move(t)
4563+
r : primal_fwd_rws_name = (t, t + ctx*.fwd_suffix, t + ctx*.rws_suffix); // TODO: Check why on return (t, t + ctx*.fwd_suffix) the primal is initialized empty. Probably because of the move(t)
45494564
_ = t;
45504565
return r;
45514566
}
@@ -4577,7 +4592,7 @@ autodiff_expression_handler: type = {
45774592
object : std::string = "";
45784593
object_d : std::string = "";
45794594
function_name : std::string = "";
4580-
args : std::vector<primal_fwd_name> = ();
4595+
args : std::vector<primal_fwd_rws_name> = ();
45814596

45824597
primary := postfix.get_primary_expression();
45834598

@@ -4695,7 +4710,7 @@ autodiff_expression_handler: type = {
46954710
// TODO: Add function to list of functions/objects for differentiation for the no return case.
46964711
}
46974712

4698-
handle_special_function: (inout this, object: std::string, object_d: std::string, function_name: std::string, args: std::vector<primal_fwd_name>) -> bool = {
4713+
handle_special_function: (inout this, object: std::string, object_d: std::string, function_name: std::string, args: std::vector<primal_fwd_rws_name>) -> bool = {
46994714

47004715
r := ctx*.lookup_special_function_handling(function_name, args.ssize(), !object.empty());
47014716

@@ -4778,24 +4793,29 @@ autodiff_expression_handler: type = {
47784793
terms := binexpr.get_terms();
47794794

47804795
first := true;
4796+
op : std::string = "+";
47814797
fwd : std::string = "";
4798+
rws : std::string = "";
47824799
primal: std::string = "";
47834800
for terms do (term) {
47844801
if !first {
4785-
op := term.get_op().to_string();
4802+
op = term.get_op().to_string();
47864803
fwd += " (op)$ ";
47874804
primal += " (op)$ ";
4805+
47884806
}
47894807

47904808
var := handle_expression_term(term.get_term());
47914809
fwd += var.fwd;
4810+
rws += "(var.rws)$ (op)$= _rb_;\n";
47924811
primal += var.primal;
47934812

47944813
first = false;
47954814
}
47964815

47974816
primal_expr = primal;
47984817
fwd_expr = fwd;
4818+
rws_expr = rws;
47994819
}
48004820

48014821
traverse: (override inout this, binexpr: meta::multiplicative_expression) = {
@@ -4919,17 +4939,20 @@ autodiff_expression_handler: type = {
49194939
{
49204940
if primary.is_identifier() {
49214941
primal_expr = primary.to_string();
4922-
fwd_expr = add_suffix_if_not_wildcard(primal_expr);
4942+
fwd_expr = add_suffix_if_not_wildcard(primal_expr, ctx*.fwd_suffix);
4943+
rws_expr = add_suffix_if_not_wildcard(primal_expr, ctx*.rws_suffix);
49234944

49244945
decl := ctx*.lookup_variable_declaration(primal_expr);
49254946
if decl.is_member {
49264947
fwd_expr = "this(ctx*.fwd_suffix)$." + fwd_expr;
4948+
rws_expr = "this(ctx*.rws_suffix)$." + rws_expr;
49274949
}
49284950
}
49294951
else if primary.is_expression_list() {
49304952
if primary.as_expression_list().is_empty() {
49314953
primal_expr = "()";
49324954
fwd_expr = "()";
4955+
rws_expr = "()"; // TODO: Check for reverse
49334956
}
49344957
else {
49354958
primary.error("AD: Do not know how to handle non empty expression list inside of primary_expression: (primary.to_string())$");
@@ -4938,6 +4961,7 @@ autodiff_expression_handler: type = {
49384961
else if primary.is_literal() {
49394962
primal_expr = primary.to_string();
49404963
fwd_expr = "()";
4964+
rws_expr = "()"; // TODO: Check for reverse
49414965
}
49424966
else if primary.is_declaration() {
49434967
primary.error("AD: Do not know how to handle declaration inside of primary_expression: (primary.to_string())$");
@@ -5114,7 +5138,7 @@ autodiff_stmt_handler: type = {
51145138

51155139
h: autodiff_expression_handler = (ctx);
51165140
h.pre_traverse(assignment_terms[1].get_term());
5117-
h.gen_assignment(h_lhs.primal_expr, h_lhs.fwd_expr);
5141+
h.gen_assignment(h_lhs.primal_expr, h_lhs.fwd_expr, h_lhs.rws_expr);
51185142
append(h);
51195143
}
51205144

@@ -5234,7 +5258,7 @@ autodiff_declaration_handler: type = {
52345258
ctx*.enter_function();
52355259

52365260
diff.add_forward(" (f.name())$(ctx*.fwd_suffix)$: (");
5237-
diff.add_reverse(" (f.name())$(ctx*.rws_suffix)$: (");
5261+
diff.add_reverse_primal(" (f.name())$(ctx*.rws_suffix)$: (");
52385262

52395263
// 1. Generate the modified signature
52405264
// a) Parameters
@@ -5252,22 +5276,22 @@ autodiff_declaration_handler: type = {
52525276
diff.add_forward("(fwd_pass_style)$ (name)$, ");
52535277
diff.add_forward("(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$: (fwd_ad_type)$, ");
52545278

5255-
diff.add_reverse("(fwd_pass_style)$ (name)$, ");
5279+
diff.add_reverse_primal("(fwd_pass_style)$ (name)$, ");
52565280
if ctx*.is_taylor() { // Add forward type for higher order
5257-
diff.add_reverse("(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$: (fwd_ad_type)$, ");
5281+
diff.add_reverse_primal("(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$: (fwd_ad_type)$, ");
52585282
}
5259-
diff.add_reverse("(rws_pass_style)$ (name)$(ctx*.rws_suffix)$: (rws_ad_type)$, ");
5283+
diff.add_reverse_primal("(rws_pass_style)$ (name)$(ctx*.rws_suffix)$: (rws_ad_type)$, ");
52605284
}
52615285
else {
52625286
type := param.get_declaration().type();
52635287
diff.add_forward("(fwd_pass_style)$ (name)$ : (type)$, ");
52645288
diff.add_forward("(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$ : (ctx*.get_fwd_ad_type(type))$, ");
52655289

5266-
diff.add_reverse("(fwd_pass_style)$ (name)$ : (type)$, ");
5290+
diff.add_reverse_primal("(fwd_pass_style)$ (name)$ : (type)$, ");
52675291
if ctx*.is_taylor() {
5268-
diff.add_reverse("(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$ : (ctx*.get_fwd_ad_type(type))$, ");
5292+
diff.add_reverse_primal("(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$ : (ctx*.get_fwd_ad_type(type))$, ");
52695293
}
5270-
diff.add_reverse("(rws_pass_style)$ (name)$(ctx*.rws_suffix)$ : (ctx*.get_rws_ad_type(type))$, ");
5294+
diff.add_reverse_primal("(rws_pass_style)$ (name)$(ctx*.rws_suffix)$ : (ctx*.get_rws_ad_type(type))$, ");
52715295

52725296

52735297
ctx*.add_variable_declaration(name, type);
@@ -5280,10 +5304,10 @@ autodiff_declaration_handler: type = {
52805304
// TODO: check if name "r" is available. (Also needs inspection of functions at call sides.)
52815305
if f.has_deduced_return_type() {
52825306
// TODO: Take care of initialization order error.
5283-
diff.add_reverse("inout r(ctx*.rws_suffix)$, ");
5307+
diff.add_reverse_primal("inout r(ctx*.rws_suffix)$, ");
52845308
}
52855309
else {
5286-
diff.add_reverse("inout r(ctx*.rws_suffix)$: (ctx*.get_rws_ad_type(f.get_unnamed_return_type()))$, ");
5310+
diff.add_reverse_primal("inout r(ctx*.rws_suffix)$: (ctx*.get_rws_ad_type(f.get_unnamed_return_type()))$, ");
52875311
}
52885312
}
52895313
else {
@@ -5292,12 +5316,12 @@ autodiff_declaration_handler: type = {
52925316
type := param.get_declaration().type();
52935317

52945318
rws_pass_style := to_string_view(ctx*.get_reverse_passing_style(param.get_passing_style()));
5295-
diff.add_reverse("(rws_pass_style)$ (name)$(ctx*.rws_suffix)$ : (ctx*.get_rws_ad_type(type))$ , ");
5319+
diff.add_reverse_primal("(rws_pass_style)$ (name)$(ctx*.rws_suffix)$ : (ctx*.get_rws_ad_type(type))$ , ");
52965320
}
52975321
}
52985322

52995323
diff.add_forward(") -> (");
5300-
diff.add_reverse(") -> (");
5324+
diff.add_reverse_primal(") -> (");
53015325

53025326
// c) Returns
53035327

@@ -5306,16 +5330,16 @@ autodiff_declaration_handler: type = {
53065330
if f.has_deduced_return_type() {
53075331
// TODO: Take care of initialization order error.
53085332
diff.add_forward("r, r(ctx*.fwd_suffix)$, ");
5309-
diff.add_reverse("r, ");
5333+
diff.add_reverse_primal("r, ");
53105334
if ctx*.is_taylor() {
5311-
diff.add_reverse("r(ctx*.fwd_suffix)$,");
5335+
diff.add_reverse_primal("r(ctx*.fwd_suffix)$,");
53125336
}
53135337
}
53145338
else {
53155339
diff.add_forward("r: (f.get_unnamed_return_type())$ = (), r(ctx*.fwd_suffix)$: (ctx*.get_fwd_ad_type(f.get_unnamed_return_type()))$ = (), ");
5316-
diff.add_reverse("r: (f.get_unnamed_return_type())$ = (), ");
5340+
diff.add_reverse_primal("r: (f.get_unnamed_return_type())$ = (), ");
53175341
if ctx*.is_taylor() {
5318-
diff.add_reverse("r(ctx*.fwd_suffix)$: (ctx*.get_fwd_ad_type(f.get_unnamed_return_type()))$ = (), ");
5342+
diff.add_reverse_primal("r(ctx*.fwd_suffix)$: (ctx*.get_fwd_ad_type(f.get_unnamed_return_type()))$ = (), ");
53195343
}
53205344
}
53215345
}
@@ -5329,17 +5353,17 @@ autodiff_declaration_handler: type = {
53295353
diff.add_forward("(fwd_pass_style)$ (name)$ : (param.get_declaration().type())$ = 0.0, ");
53305354
diff.add_forward("(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$ : (ctx*.get_fwd_ad_type(type))$ = 0.0, ");
53315355

5332-
diff.add_reverse("(fwd_pass_style)$ (name)$ : (param.get_declaration().type())$ = 0.0, ");
5356+
diff.add_reverse_primal("(fwd_pass_style)$ (name)$ : (param.get_declaration().type())$ = 0.0, ");
53335357
if ctx*.is_taylor() {
5334-
diff.add_reverse("(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$ : (ctx*.get_fwd_ad_type(type))$ = 0.0, ");
5358+
diff.add_reverse_primal("(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$ : (ctx*.get_fwd_ad_type(type))$ = 0.0, ");
53355359
}
53365360

53375361
ctx*.add_variable_declaration("(name)$", "(type)$");
53385362
}
53395363
}
53405364

53415365
diff.add_forward(") = {");
5342-
diff.add_reverse(") = {");
5366+
diff.add_reverse_primal(") = {");
53435367

53445368
// Generate the body
53455369

@@ -5357,18 +5381,19 @@ autodiff_declaration_handler: type = {
53575381
ad_impl..pre_traverse(stmt);
53585382
}
53595383
diff.add_forward(ad_impl.diff.fwd);
5360-
diff.add_reverse(ad_impl.diff.rws);
5384+
diff.add_reverse_primal(ad_impl.diff.rws_primal);
5385+
diff.add_reverse_primal(ad_impl.diff.rws_backprop);
53615386

53625387
diff.add_forward("}");
5363-
diff.add_reverse("}");
5388+
diff.add_reverse_primal("}");
53645389

53655390
ctx*.leave_function();
53665391

53675392
if ctx*.is_forward() {
53685393
decl.add_member( diff.fwd );
53695394
}
53705395
if ctx*.is_reverse() {
5371-
decl.add_member( diff.rws );
5396+
decl.add_member( diff.rws_primal );
53725397
}
53735398
diff.reset();
53745399

0 commit comments

Comments
 (0)