@@ -4592,6 +4592,7 @@ autodiff_diff_code: type = {
45924592 operator=:(out this, ctx_: *autodiff_context) = {
45934593 ctx = ctx_;
45944594 }
4595+ operator=:(out this, that) = {}
45954596
45964597 add_forward : (inout this, v: std::string) = { if ctx*.is_forward() { fwd += v; }}
45974598 add_reverse_primal : (inout this, v: std::string) = { if ctx*.is_reverse() { rws_primal += v; }}
@@ -4777,6 +4778,7 @@ autodiff_expression_handler: type = {
47774778
47784779 return r;
47794780 }
4781+ prepare_backprop: (this, rhs_b: std::string, lhs: std::string) -> std::string = prepare_backprop(rhs_b, lhs, lhs + ctx*.fwd_suffix, lhs + ctx*.rws_suffix);
47804782
47814783 gen_assignment: (inout this, lhs: std::string, lhs_d: std::string, lhs_b: std::string, rhs: std::string, rhs_d: std::string, rhs_b: std::string) = {
47824784 diff.add_forward("(lhs_d)$ = (rhs_d)$;\n");
@@ -5386,18 +5388,21 @@ autodiff_stmt_handler: type = {
53865388 mf: meta::function_declaration;
53875389
53885390 last_params: std::vector<meta::parameter_declaration> = ();
5391+ overwritten: std::vector<std::string> = ();
5392+
5393+ overwrite_push_pop: bool = false;
53895394
53905395 operator=: (out this, ctx_: *autodiff_context, mf_: meta::function_declaration) = {
53915396 autodiff_handler_base = (ctx_);
53925397 mf = mf_;
53935398 }
53945399
5395- handle_stmt_parameters: (inout this, params: std::vector<parameter_declaration>, leave_open: bool) = {
5400+ handle_stmt_parameters: (inout this, params: std::vector<parameter_declaration>) -> autodiff_diff_code = {
5401+ r : autodiff_diff_code = (ctx);
53965402 if params.empty() {
5397- return;
5403+ return r ;
53985404 }
53995405
5400- fwd: std::string = "(";
54015406 for params do (param) {
54025407 name: std::string = param.get_declaration().name();
54035408 type: std::string = param.get_declaration().type();
@@ -5409,6 +5414,7 @@ autodiff_stmt_handler: type = {
54095414
54105415 init : std::string = "";
54115416 init_d: std::string = "";
5417+ // TODO: Add handling for reverse expressions
54125418
54135419 if param.get_declaration().has_initializer() {
54145420 ad: autodiff_expression_handler = (ctx);
@@ -5421,19 +5427,16 @@ autodiff_stmt_handler: type = {
54215427 }
54225428
54235429
5424- fwd += "(fwd_pass_style)$ (name)$ : (type)$(init)$, ";
5430+ r.add_forward("(fwd_pass_style)$ (name)$ : (type)$(init)$, ");
5431+ r.add_reverse_primal("(fwd_pass_style)$ (name)$ : (type)$(init)$, ");
54255432 if ada.active {
5426- fwd += "(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$ : (ctx*.get_fwd_ad_type(type))$(init_d)$, ";
5433+ r.add_forward( "(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$ : (ctx*.get_fwd_ad_type(type))$(init_d)$, ") ;
54275434 }
54285435
54295436 ctx*.add_variable_declaration(name, type, ada.active);
54305437 }
54315438
5432- if !leave_open {
5433- fwd += ")";
5434- }
5435-
5436- diff += fwd;
5439+ return r;
54375440 }
54385441
54395442 traverse: (override inout this, decl: meta::declaration) = {
@@ -5465,9 +5468,11 @@ autodiff_stmt_handler: type = {
54655468 if active {
54665469
54675470 fwd_ad_type : = ctx*.get_fwd_ad_type(type);
5471+ rws_ad_type : = ctx*.get_rws_ad_type(type);
54685472
54695473 prim_init: std::string = "";
54705474 fwd_init : std::string = "";
5475+ rws_init : std::string = "";
54715476
54725477 if o.has_initializer() {
54735478 ad: autodiff_expression_handler = (ctx);
@@ -5476,14 +5481,23 @@ autodiff_stmt_handler: type = {
54765481
54775482 prim_init = " = " + ad.primal_expr;
54785483 fwd_init = " = " + ad.fwd_expr;
5484+ rws_init = " = ()"; // TODO: Proper initialization.
5485+
5486+ if ad.rws_expr != "()" {
5487+ diff.add_reverse_backprop(ad.prepare_backprop(ad.rws_expr, lhs));
5488+ }
54795489
54805490 if type == "_" && ad.fwd_expr == "()" {
54815491 // Special handling for auto initialization from a literal.
54825492 fwd_init = " = " + ctx*.get_fwd_ad_type("double") + "()";
54835493 }
54845494 }
5485- diff += "(lhs)$(ctx*.fwd_suffix)$ : (fwd_ad_type)$(fwd_init)$;\n";
5486- diff += "(lhs)$ : (type)$(prim_init)$;\n";
5495+
5496+ diff.add_forward("(lhs)$(ctx*.fwd_suffix)$ : (fwd_ad_type)$(fwd_init)$;\n");
5497+ diff.add_forward("(lhs)$ : (type)$(prim_init)$;\n");
5498+
5499+ diff.add_reverse_primal("(lhs)$(ctx*.rws_suffix)$ : (rws_ad_type)$(rws_init)$;\n");
5500+ diff.add_reverse_primal("(lhs)$ : (type)$(prim_init)$;\n");
54875501 }
54885502 else {
54895503 diff += "(lhs)$: (type)$";
@@ -5515,9 +5529,37 @@ autodiff_stmt_handler: type = {
55155529
55165530
55175531 traverse: (override inout this, stmt: meta::compound_statement) = {
5518- diff += "{\n";
5519- base::traverse(stmt);
5520- diff += "}\n";
5532+ ad : autodiff_stmt_handler = (ctx, mf);
5533+ ad_push_pop: autodiff_stmt_handler = (ctx, mf);
5534+ ad_push_pop.overwrite_push_pop = true;
5535+
5536+ diff.add_forward("{\n");
5537+ diff.add_reverse_primal("{\n");
5538+ diff.add_reverse_backprop("}\n");
5539+
5540+ for stmt.get_statements() do (cur) {
5541+ ad.pre_traverse(cur);
5542+ ad_push_pop.pre_traverse(cur);
5543+ }
5544+
5545+ for ad.overwritten do (cur) {
5546+ r := ctx*.lookup_variable_declaration(cur);
5547+ diff.add_reverse_primal("cpp2::ad_stack::push<(r.decl)$>((cur)$);");
5548+ }
5549+
5550+ diff.add_forward(ad.diff.fwd);
5551+ diff.add_reverse_primal(ad.diff.rws_primal);
5552+ diff.add_reverse_backprop(ad_push_pop.diff.rws_backprop);
5553+ diff.add_reverse_backprop(ad_push_pop.diff.rws_primal);
5554+
5555+ for ad.overwritten do (cur) {
5556+ r := ctx*.lookup_variable_declaration(cur);
5557+ diff.add_reverse_backprop("(cur)$ = cpp2::ad_stack::pop<(r.decl)$>();");
5558+ }
5559+
5560+ diff.add_forward("}\n");
5561+ diff.add_reverse_primal("}\n");
5562+ diff.add_reverse_backprop("{\n");
55215563 }
55225564
55235565
@@ -5537,13 +5579,32 @@ autodiff_stmt_handler: type = {
55375579 }
55385580 }
55395581
5582+ reverse_next: (this, expr: std::string) -> std::string = {
5583+ if expr.contains("+=") {
5584+ return string_util::replace_all(expr, "+=", "-=");
5585+ }
5586+ else if expr.contains("-=") {
5587+ return string_util::replace_all(expr, "-=", "+=");
5588+ }
5589+
5590+ mf.error("AD: Do not know how to reverse: (expr)$");
5591+
5592+ return "Error";
5593+
5594+ }
5595+
55405596
55415597 traverse: (override inout this, stmt: meta::iteration_statement) = {
5542- if !last_params.empty() {
5543- handle_stmt_parameters(last_params, stmt.is_for());
5598+ diff_params := handle_stmt_parameters(last_params);
5599+
5600+ if ctx*.is_reverse() && (stmt.is_while() || stmt.is_do()) {
5601+ stmt.error("AD: Alpha limitiation now reverse mode for while or do while.");
55445602 }
55455603
55465604 if stmt.is_while() {
5605+ if !last_params.empty() {
5606+ diff.add_forward("(" + diff_params.fwd + ")");
5607+ }
55475608 // TODO: Assumption is here that nothing is in the condition
55485609 diff += "while (stmt.get_do_while_condition().to_string())$ ";
55495610 if stmt.has_next() {
@@ -5554,6 +5615,10 @@ autodiff_stmt_handler: type = {
55545615 pre_traverse(stmt.get_do_while_body());
55555616 }
55565617 else if stmt.is_do() {
5618+ if !last_params.empty() {
5619+ diff.add_forward("(" + diff_params.fwd + ")");
5620+ }
5621+
55575622 // TODO: Assumption is here that nothing is in the condition
55585623 diff += "do ";
55595624 pre_traverse(stmt.get_do_while_body());
@@ -5574,23 +5639,55 @@ autodiff_stmt_handler: type = {
55745639 param := stmt.get_for_parameter();
55755640 param_style := to_string_view(param.get_passing_style());
55765641 param_decl := param.get_declaration();
5577- if last_params.empty() {
5578- diff += "("; // Open statment parameter scope. If the loop has parameters, they are alrady handled and the brace is left open.
5642+
5643+ rws : std::string = "(";
5644+ rws_restore: std::string = "";
5645+ diff.add_forward("("); // Open statment parameter scope. If the loop has parameters, they are alrady handled and the brace is left open.
5646+ diff.add_reverse_primal("{\n");
5647+ if !last_params.empty() {
5648+ for last_params do (cur) {
5649+ if cur.get_declaration().has_initializer() {
5650+ // TODO: Handle no type and no initializer. Handle passing style.
5651+ diff.add_reverse_primal("(cur.get_declaration().name())$: (cur.get_declaration().type())$ = (cur.get_declaration().get_initializer().to_string())$;\n");
5652+ rws_restore += "cpp2::ad_stack::push<(cur.get_declaration().type())$>((cur.get_declaration().name())$);\n";
5653+ rws += "(to_string_view(cur.get_passing_style()))$ (cur.get_declaration().name())$: (cur.get_declaration().type())$ = cpp2::ad_stack::pop<(cur.get_declaration().type())$>(), ";
5654+ }
5655+ }
5656+ diff.add_forward(diff_params.fwd);
55795657 }
5580- diff += "copy (param_decl.name())$_d_iter := (range)$(ctx*.fwd_suffix)$.begin())\n";
5581- diff += "for (range)$ next (";
5658+ diff.add_forward("copy (param_decl.name())$(ctx*.fwd_suffix)$_iter := (range)$(ctx*.fwd_suffix)$.begin())\n");
5659+ diff.add_forward("for (range)$ next (");
5660+
5661+ rws += "copy (param_decl.name())$(ctx*.rws_suffix)$_iter := (range)$(ctx*.rws_suffix)$.rbegin())\n";
5662+ rws += "for std::ranges::reverse_view((range)$) next (";
5663+ diff.add_reverse_primal("for (range)$ next (");
55825664 if stmt.has_next() {
55835665 // TODO: Assumption is here that nothing is in the next expression
5584- diff += "(stmt.get_next_expression().to_string())$, ";
5666+ diff.add_forward("(stmt.get_next_expression().to_string())$, ");
5667+ diff.add_reverse_primal("(stmt.get_next_expression().to_string())$, ");
5668+ rws += "(reverse_next(stmt.get_next_expression().to_string()))$, ";
55855669 }
5586- diff += "(param_decl.name())$_d_iter++";
5587- diff += ") do ((param_style)$ (param_decl.name())$: (param_decl.type())$) {\n";
5588- diff += "((param_style)$ (param_decl.name())$(ctx*.fwd_suffix)$: (param_decl.type())$ = (param_decl.name())$_d_iter*)";
5670+ diff.add_forward("(param_decl.name())$(ctx*.fwd_suffix)$_iter++");
5671+ diff.add_forward(") do ((param_style)$ (param_decl.name())$: (param_decl.type())$) {\n");
5672+ rws += "(param_decl.name())$(ctx*.rws_suffix)$_iter++";
5673+ rws += ") do ((param_style)$ (param_decl.name())$: (param_decl.type())$) {\n";
5674+ rws += "(inout (param_decl.name())$(ctx*.rws_suffix)$ := (param_decl.name())$(ctx*.rws_suffix)$_iter*)\n";
5675+
5676+ diff.add_reverse_primal(") do ((param_style)$ (param_decl.name())$: (param_decl.type())$)");
5677+ diff.add_forward("((param_style)$ (param_decl.name())$(ctx*.fwd_suffix)$: (param_decl.type())$ = (param_decl.name())$(ctx*.fwd_suffix)$_iter*)");
55895678
55905679 ctx*.add_variable_declaration("(param_decl.name())$", "(param_decl.type())$", true); // TODO: Handle loop/compound context variable declarations.
5680+ diff.add_reverse_backprop("}\n");
55915681
55925682 pre_traverse(stmt.get_for_body());
5593- diff += "}\n";
5683+ diff.add_forward("}\n");
5684+
5685+ if stmt.has_next() {
5686+ diff.add_reverse_primal("(reverse_next(stmt.get_next_expression().to_string()))$;\n");
5687+ }
5688+ diff.add_reverse_primal(rws_restore);
5689+ diff.add_reverse_primal("}\n");
5690+ diff.add_reverse_backprop(rws);
55945691 }
55955692 }
55965693
@@ -5626,8 +5723,34 @@ autodiff_stmt_handler: type = {
56265723
56275724 h: autodiff_expression_handler = (ctx);
56285725 h.pre_traverse(assignment_terms[1].get_term());
5629- h.gen_assignment(h_lhs.primal_expr, h_lhs.fwd_expr, h_lhs.rws_expr);
5630- append(h);
5726+
5727+ is_overwrite := h.primal_expr.contains(h_lhs.primal_expr);
5728+ if overwrite_push_pop && is_overwrite {
5729+ r := ctx*.lookup_variable_declaration(h_lhs.primal_expr);
5730+ diff.add_reverse_primal("cpp2::ad_stack::push<(r.decl)$>((h_lhs.primal_expr)$);");
5731+ }
5732+
5733+ if is_overwrite && ctx*.is_reverse() {
5734+ t_b := ctx*.gen_temporary() + ctx*.rws_suffix;
5735+ h.gen_assignment(h_lhs.primal_expr, h_lhs.fwd_expr, t_b);
5736+ append(h);
5737+ diff.add_reverse_backprop("(h_lhs.rws_expr)$ = 0.0;\n");
5738+ diff.add_reverse_backprop("(t_b)$ := (h_lhs.rws_expr)$;\n");
5739+ }
5740+ else {
5741+ h.gen_assignment(h_lhs.primal_expr, h_lhs.fwd_expr, h_lhs.rws_expr);
5742+ append(h);
5743+ }
5744+
5745+ if overwrite_push_pop && is_overwrite {
5746+ r := ctx*.lookup_variable_declaration(h_lhs.primal_expr);
5747+ diff.add_reverse_backprop("(h_lhs.primal_expr)$ = cpp2::ad_stack::pop<(r.decl)$>();");
5748+ }
5749+
5750+ // Simple overwrite check
5751+ if is_overwrite {
5752+ overwritten.push_back(h_lhs.primal_expr);
5753+ }
56315754 }
56325755 else {
56335756 diff.add_forward(binexpr.to_string() + ";\n");
@@ -6046,6 +6169,9 @@ autodiff: (inout t: meta::type_declaration) =
60466169 if 1 != order {
60476170 t.add_runtime_support_include( "cpp2taylor.h" );
60486171 }
6172+ if reverse {
6173+ t.add_runtime_support_include( "cpp2ad_stack.h" );
6174+ }
60496175
60506176 ad_ctx.finish();
60516177
0 commit comments