@@ -4383,19 +4383,22 @@ autodiff_context: type = {
43834383autodiff_diff_code: type = {
43844384 public ctx: *autodiff_context;
43854385
4386- public fwd: std::string = "";
4387- public rws: std::string = "";
4386+ public fwd : std::string = "";
4387+ public rws_primal : std::string = "";
4388+ public rws_backprop: std::string = "";
43884389
43894390 operator=:(out this, ctx_: *autodiff_context) = {
43904391 ctx = ctx_;
43914392 }
43924393
4393- add_forward: (inout this, v: std::string) = { if ctx*.is_forward() { fwd += v; }}
4394- add_reverse: (inout this, v: std::string) = { if ctx*.is_reverse() { rws += v; }}
4394+ add_forward : (inout this, v: std::string) = { if ctx*.is_forward() { fwd += v; }}
4395+ add_reverse_primal : (inout this, v: std::string) = { if ctx*.is_reverse() { rws_primal += v; }}
4396+ add_reverse_backprop: (inout this, v: std::string) = { if ctx*.is_reverse() { rws_backprop += v; }}
43954397
43964398 reset: (inout this) = {
4397- fwd = "";
4398- rws = "";
4399+ fwd = "";
4400+ rws_primal = "";
4401+ rws_backprop = "";
43994402 }
44004403
44014404 // Temporary: TODO: remove when everything has been adapted to primal, fwd, rws pushes.
@@ -4437,7 +4440,9 @@ autodiff_handler_base: type = {
44374440
44384441 // Temporary: TODO: remove when everything has been adapted to primal, fwd, rws pushes.
44394442 append: (inout this, in_ref o: autodiff_handler_base) = {
4440- diff += o.diff.fwd;
4443+ diff.fwd += o.diff.fwd;
4444+ diff.rws_primal += o.diff.rws_primal;
4445+ diff.rws_backprop = o.diff.rws_backprop + diff.rws_backprop;
44414446 }
44424447}
44434448
@@ -4449,28 +4454,35 @@ autodiff_expression_handler: type = {
44494454
44504455 public primal_expr: std::string = "";
44514456 public fwd_expr : std::string = "";
4457+ public rws_expr : std::string = "";
44524458
44534459 operator=: (out this, ctx_: *autodiff_context) = {
44544460 autodiff_handler_base = (ctx_);
44554461 }
44564462
4457- add_suffix_if_not_wildcard: (this, lhs: std::string) -> std::string = {
4463+ add_suffix_if_not_wildcard: (this, lhs: std::string, suffix: std::string ) -> std::string = {
44584464 if "_" == lhs {
44594465 return lhs;
44604466 }
44614467 else {
4462- return lhs + ctx*.fwd_suffix ;
4468+ return lhs + suffix ;
44634469 }
44644470 }
44654471
4466- gen_assignment: (inout this, lhs: std::string, lhs_d: std::string, rhs: std::string, rhs_d: std::string) = {
4467- diff += "(lhs_d)$ = (rhs_d)$;\n";
4468- diff += "(lhs)$ = (rhs)$;\n";
4472+ gen_assignment: (inout this, lhs: std::string, lhs_d: std::string, lhs_b: std::string, rhs: std::string, rhs_d: std::string, rhs_b: std::string) = {
4473+ diff.add_forward("(lhs_d)$ = (rhs_d)$;\n");
4474+ diff.add_forward("(lhs)$ = (rhs)$;\n");
4475+
4476+ if ctx*.is_taylor() {
4477+ diff.add_reverse_primal("(lhs_d)$ = (rhs_d)$;\n");
4478+ }
4479+ diff.add_reverse_primal("(lhs)$ = (rhs)$;\n");
4480+ diff.add_reverse_backprop(string_util::replace_all(rhs_b, "_rb_", lhs_b));
44694481 }
4470- gen_assignment: (inout this, lhs: std::string, lhs_d: std::string)
4471- = gen_assignment(lhs, lhs_d, primal_expr, fwd_expr);
4482+ gen_assignment: (inout this, lhs: std::string, lhs_d: std::string, lhs_b: std::string )
4483+ = gen_assignment(lhs, lhs_d, lhs_b, primal_expr, fwd_expr, rws_expr );
44724484 gen_assignment: (inout this, lhs: std::string)
4473- = gen_assignment(lhs, add_suffix_if_not_wildcard(lhs), primal_expr, fwd_expr);
4485+ = gen_assignment(lhs, add_suffix_if_not_wildcard(lhs, ctx*.fwd_suffix ), add_suffix_if_not_wildcard(lhs, ctx*.rws_suffix), primal_expr, fwd_expr, rws_expr );
44744486
44754487
44764488 gen_declaration: (inout this, lhs: std::string, lhs_d: std::string, rhs: std::string, rhs_d: std::string, type: std::string, type_d: std::string) = {
@@ -4486,43 +4498,46 @@ autodiff_expression_handler: type = {
44864498
44874499
44884500
4489- primal_fwd_name : @struct type = {
4501+ primal_fwd_rws_name : @struct type = {
44904502 primal: std::string = "";
44914503 fwd : std::string = "";
4504+ rws : std::string = "";
44924505 }
44934506
4494- handle_expression_list: (inout this, list: meta::expression_list) -> std::vector<primal_fwd_name > = {
4495- args : std::vector<primal_fwd_name > = ();
4507+ handle_expression_list: (inout this, list: meta::expression_list) -> std::vector<primal_fwd_rws_name > = {
4508+ args : std::vector<primal_fwd_rws_name > = ();
44964509 for list.get_expressions() do (expr) {
44974510 args.push_back(handle_expression_term(expr));
44984511 }
44994512
45004513 return args;
45014514 }
45024515
4503- handle_expression_term :(inout this, term) -> primal_fwd_name = {
4516+ handle_expression_term :(inout this, term) -> primal_fwd_rws_name = {
45044517 if term.is_identifier() {
45054518 primal := term.to_string();
45064519 fwd := primal + ctx*.fwd_suffix;
4520+ rws := primal + ctx*.rws_suffix;
45074521
45084522 decl := ctx*.lookup_variable_declaration(primal);
45094523 if decl.is_member {
45104524 fwd = "this(ctx*.fwd_suffix)$." + fwd;
4525+ rws = "this(ctx*.rws_suffix)$." + rws;
45114526 }
4512- return (primal, fwd);
4527+ return (primal, fwd, rws );
45134528 }
45144529 else if term.is_expression_list() {
45154530 exprs := term..as_expression_list()..get_expressions();
45164531 if exprs.ssize() != 1 {
45174532 term.error("Can not handle multiple expressions. (term.to_string())");
4518- return ("error", "");
4533+ return ("error", "", "" );
45194534 }
45204535 expr := exprs[0];
45214536 bin_expr := expr..as_assignment_expression();
45224537
45234538 if bin_expr.terms_size() != 0 {
45244539 term.error("Can not handle assign expr inside of expression. (expr.to_string())$");
4525- return ("error", "");
4540+ return ("error", "", "" );
45264541 }
45274542
45284543 ad : autodiff_expression_handler = (ctx);
@@ -4531,7 +4546,7 @@ autodiff_expression_handler: type = {
45314546 ad.gen_declaration(t, "double"); // TODO: get type of expression
45324547 append(ad);
45334548
4534- r : primal_fwd_name = (t, t + ctx*.fwd_suffix); // TODO: Check why on return (t, t + ctx*.fwd_suffix) the primal is initialized empty. Probably because of the move(t)
4549+ r : primal_fwd_rws_name = (t, t + ctx*.fwd_suffix, t + ctx*.rws_suffix ); // TODO: Check why on return (t, t + ctx*.fwd_suffix) the primal is initialized empty. Probably because of the move(t)
45354550 _ = t;
45364551 return r;
45374552 }
@@ -4545,7 +4560,7 @@ autodiff_expression_handler: type = {
45454560 ad.gen_declaration(t, "double"); // TODO: get type of expression
45464561 append(ad);
45474562
4548- r : primal_fwd_name = (t, t + ctx*.fwd_suffix); // TODO: Check why on return (t, t + ctx*.fwd_suffix) the primal is initialized empty. Probably because of the move(t)
4563+ r : primal_fwd_rws_name = (t, t + ctx*.fwd_suffix, t + ctx*.rws_suffix ); // TODO: Check why on return (t, t + ctx*.fwd_suffix) the primal is initialized empty. Probably because of the move(t)
45494564 _ = t;
45504565 return r;
45514566 }
@@ -4577,7 +4592,7 @@ autodiff_expression_handler: type = {
45774592 object : std::string = "";
45784593 object_d : std::string = "";
45794594 function_name : std::string = "";
4580- args : std::vector<primal_fwd_name > = ();
4595+ args : std::vector<primal_fwd_rws_name > = ();
45814596
45824597 primary := postfix.get_primary_expression();
45834598
@@ -4695,7 +4710,7 @@ autodiff_expression_handler: type = {
46954710 // TODO: Add function to list of functions/objects for differentiation for the no return case.
46964711 }
46974712
4698- handle_special_function: (inout this, object: std::string, object_d: std::string, function_name: std::string, args: std::vector<primal_fwd_name >) -> bool = {
4713+ handle_special_function: (inout this, object: std::string, object_d: std::string, function_name: std::string, args: std::vector<primal_fwd_rws_name >) -> bool = {
46994714
47004715 r := ctx*.lookup_special_function_handling(function_name, args.ssize(), !object.empty());
47014716
@@ -4778,24 +4793,29 @@ autodiff_expression_handler: type = {
47784793 terms := binexpr.get_terms();
47794794
47804795 first := true;
4796+ op : std::string = "+";
47814797 fwd : std::string = "";
4798+ rws : std::string = "";
47824799 primal: std::string = "";
47834800 for terms do (term) {
47844801 if !first {
4785- op : = term.get_op().to_string();
4802+ op = term.get_op().to_string();
47864803 fwd += " (op)$ ";
47874804 primal += " (op)$ ";
4805+
47884806 }
47894807
47904808 var := handle_expression_term(term.get_term());
47914809 fwd += var.fwd;
4810+ rws += "(var.rws)$ (op)$= _rb_;\n";
47924811 primal += var.primal;
47934812
47944813 first = false;
47954814 }
47964815
47974816 primal_expr = primal;
47984817 fwd_expr = fwd;
4818+ rws_expr = rws;
47994819 }
48004820
48014821 traverse: (override inout this, binexpr: meta::multiplicative_expression) = {
@@ -4919,17 +4939,20 @@ autodiff_expression_handler: type = {
49194939 {
49204940 if primary.is_identifier() {
49214941 primal_expr = primary.to_string();
4922- fwd_expr = add_suffix_if_not_wildcard(primal_expr);
4942+ fwd_expr = add_suffix_if_not_wildcard(primal_expr, ctx*.fwd_suffix);
4943+ rws_expr = add_suffix_if_not_wildcard(primal_expr, ctx*.rws_suffix);
49234944
49244945 decl := ctx*.lookup_variable_declaration(primal_expr);
49254946 if decl.is_member {
49264947 fwd_expr = "this(ctx*.fwd_suffix)$." + fwd_expr;
4948+ rws_expr = "this(ctx*.rws_suffix)$." + rws_expr;
49274949 }
49284950 }
49294951 else if primary.is_expression_list() {
49304952 if primary.as_expression_list().is_empty() {
49314953 primal_expr = "()";
49324954 fwd_expr = "()";
4955+ rws_expr = "()"; // TODO: Check for reverse
49334956 }
49344957 else {
49354958 primary.error("AD: Do not know how to handle non empty expression list inside of primary_expression: (primary.to_string())$");
@@ -4938,6 +4961,7 @@ autodiff_expression_handler: type = {
49384961 else if primary.is_literal() {
49394962 primal_expr = primary.to_string();
49404963 fwd_expr = "()";
4964+ rws_expr = "()"; // TODO: Check for reverse
49414965 }
49424966 else if primary.is_declaration() {
49434967 primary.error("AD: Do not know how to handle declaration inside of primary_expression: (primary.to_string())$");
@@ -5114,7 +5138,7 @@ autodiff_stmt_handler: type = {
51145138
51155139 h: autodiff_expression_handler = (ctx);
51165140 h.pre_traverse(assignment_terms[1].get_term());
5117- h.gen_assignment(h_lhs.primal_expr, h_lhs.fwd_expr);
5141+ h.gen_assignment(h_lhs.primal_expr, h_lhs.fwd_expr, h_lhs.rws_expr );
51185142 append(h);
51195143 }
51205144
@@ -5234,7 +5258,7 @@ autodiff_declaration_handler: type = {
52345258 ctx*.enter_function();
52355259
52365260 diff.add_forward(" (f.name())$(ctx*.fwd_suffix)$: (");
5237- diff.add_reverse (" (f.name())$(ctx*.rws_suffix)$: (");
5261+ diff.add_reverse_primal (" (f.name())$(ctx*.rws_suffix)$: (");
52385262
52395263 // 1. Generate the modified signature
52405264 // a) Parameters
@@ -5252,22 +5276,22 @@ autodiff_declaration_handler: type = {
52525276 diff.add_forward("(fwd_pass_style)$ (name)$, ");
52535277 diff.add_forward("(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$: (fwd_ad_type)$, ");
52545278
5255- diff.add_reverse ("(fwd_pass_style)$ (name)$, ");
5279+ diff.add_reverse_primal ("(fwd_pass_style)$ (name)$, ");
52565280 if ctx*.is_taylor() { // Add forward type for higher order
5257- diff.add_reverse ("(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$: (fwd_ad_type)$, ");
5281+ diff.add_reverse_primal ("(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$: (fwd_ad_type)$, ");
52585282 }
5259- diff.add_reverse ("(rws_pass_style)$ (name)$(ctx*.rws_suffix)$: (rws_ad_type)$, ");
5283+ diff.add_reverse_primal ("(rws_pass_style)$ (name)$(ctx*.rws_suffix)$: (rws_ad_type)$, ");
52605284 }
52615285 else {
52625286 type := param.get_declaration().type();
52635287 diff.add_forward("(fwd_pass_style)$ (name)$ : (type)$, ");
52645288 diff.add_forward("(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$ : (ctx*.get_fwd_ad_type(type))$, ");
52655289
5266- diff.add_reverse ("(fwd_pass_style)$ (name)$ : (type)$, ");
5290+ diff.add_reverse_primal ("(fwd_pass_style)$ (name)$ : (type)$, ");
52675291 if ctx*.is_taylor() {
5268- diff.add_reverse ("(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$ : (ctx*.get_fwd_ad_type(type))$, ");
5292+ diff.add_reverse_primal ("(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$ : (ctx*.get_fwd_ad_type(type))$, ");
52695293 }
5270- diff.add_reverse ("(rws_pass_style)$ (name)$(ctx*.rws_suffix)$ : (ctx*.get_rws_ad_type(type))$, ");
5294+ diff.add_reverse_primal ("(rws_pass_style)$ (name)$(ctx*.rws_suffix)$ : (ctx*.get_rws_ad_type(type))$, ");
52715295
52725296
52735297 ctx*.add_variable_declaration(name, type);
@@ -5280,10 +5304,10 @@ autodiff_declaration_handler: type = {
52805304 // TODO: check if name "r" is available. (Also needs inspection of functions at call sides.)
52815305 if f.has_deduced_return_type() {
52825306 // TODO: Take care of initialization order error.
5283- diff.add_reverse ("inout r(ctx*.rws_suffix)$, ");
5307+ diff.add_reverse_primal ("inout r(ctx*.rws_suffix)$, ");
52845308 }
52855309 else {
5286- diff.add_reverse ("inout r(ctx*.rws_suffix)$: (ctx*.get_rws_ad_type(f.get_unnamed_return_type()))$, ");
5310+ diff.add_reverse_primal ("inout r(ctx*.rws_suffix)$: (ctx*.get_rws_ad_type(f.get_unnamed_return_type()))$, ");
52875311 }
52885312 }
52895313 else {
@@ -5292,12 +5316,12 @@ autodiff_declaration_handler: type = {
52925316 type := param.get_declaration().type();
52935317
52945318 rws_pass_style := to_string_view(ctx*.get_reverse_passing_style(param.get_passing_style()));
5295- diff.add_reverse ("(rws_pass_style)$ (name)$(ctx*.rws_suffix)$ : (ctx*.get_rws_ad_type(type))$ , ");
5319+ diff.add_reverse_primal ("(rws_pass_style)$ (name)$(ctx*.rws_suffix)$ : (ctx*.get_rws_ad_type(type))$ , ");
52965320 }
52975321 }
52985322
52995323 diff.add_forward(") -> (");
5300- diff.add_reverse (") -> (");
5324+ diff.add_reverse_primal (") -> (");
53015325
53025326 // c) Returns
53035327
@@ -5306,16 +5330,16 @@ autodiff_declaration_handler: type = {
53065330 if f.has_deduced_return_type() {
53075331 // TODO: Take care of initialization order error.
53085332 diff.add_forward("r, r(ctx*.fwd_suffix)$, ");
5309- diff.add_reverse ("r, ");
5333+ diff.add_reverse_primal ("r, ");
53105334 if ctx*.is_taylor() {
5311- diff.add_reverse ("r(ctx*.fwd_suffix)$,");
5335+ diff.add_reverse_primal ("r(ctx*.fwd_suffix)$,");
53125336 }
53135337 }
53145338 else {
53155339 diff.add_forward("r: (f.get_unnamed_return_type())$ = (), r(ctx*.fwd_suffix)$: (ctx*.get_fwd_ad_type(f.get_unnamed_return_type()))$ = (), ");
5316- diff.add_reverse ("r: (f.get_unnamed_return_type())$ = (), ");
5340+ diff.add_reverse_primal ("r: (f.get_unnamed_return_type())$ = (), ");
53175341 if ctx*.is_taylor() {
5318- diff.add_reverse ("r(ctx*.fwd_suffix)$: (ctx*.get_fwd_ad_type(f.get_unnamed_return_type()))$ = (), ");
5342+ diff.add_reverse_primal ("r(ctx*.fwd_suffix)$: (ctx*.get_fwd_ad_type(f.get_unnamed_return_type()))$ = (), ");
53195343 }
53205344 }
53215345 }
@@ -5329,17 +5353,17 @@ autodiff_declaration_handler: type = {
53295353 diff.add_forward("(fwd_pass_style)$ (name)$ : (param.get_declaration().type())$ = 0.0, ");
53305354 diff.add_forward("(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$ : (ctx*.get_fwd_ad_type(type))$ = 0.0, ");
53315355
5332- diff.add_reverse ("(fwd_pass_style)$ (name)$ : (param.get_declaration().type())$ = 0.0, ");
5356+ diff.add_reverse_primal ("(fwd_pass_style)$ (name)$ : (param.get_declaration().type())$ = 0.0, ");
53335357 if ctx*.is_taylor() {
5334- diff.add_reverse ("(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$ : (ctx*.get_fwd_ad_type(type))$ = 0.0, ");
5358+ diff.add_reverse_primal ("(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$ : (ctx*.get_fwd_ad_type(type))$ = 0.0, ");
53355359 }
53365360
53375361 ctx*.add_variable_declaration("(name)$", "(type)$");
53385362 }
53395363 }
53405364
53415365 diff.add_forward(") = {");
5342- diff.add_reverse (") = {");
5366+ diff.add_reverse_primal (") = {");
53435367
53445368 // Generate the body
53455369
@@ -5357,18 +5381,19 @@ autodiff_declaration_handler: type = {
53575381 ad_impl..pre_traverse(stmt);
53585382 }
53595383 diff.add_forward(ad_impl.diff.fwd);
5360- diff.add_reverse(ad_impl.diff.rws);
5384+ diff.add_reverse_primal(ad_impl.diff.rws_primal);
5385+ diff.add_reverse_primal(ad_impl.diff.rws_backprop);
53615386
53625387 diff.add_forward("}");
5363- diff.add_reverse ("}");
5388+ diff.add_reverse_primal ("}");
53645389
53655390 ctx*.leave_function();
53665391
53675392 if ctx*.is_forward() {
53685393 decl.add_member( diff.fwd );
53695394 }
53705395 if ctx*.is_reverse() {
5371- decl.add_member( diff.rws );
5396+ decl.add_member( diff.rws_primal );
53725397 }
53735398 diff.reset();
53745399
0 commit comments