Skip to content

Commit 9a411dc

Browse files
authored
handle DualOnly ret more reliably (#113)
1 parent 17c772f commit 9a411dc

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -461,12 +461,20 @@ fn gen_enzyme_body(
461461
};
462462
}
463463

464-
let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, exprs);
465-
let ret = ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_tuple]);
466-
if d_sig.decl.output.has_ret() {
467-
// If we return (), we don't have to match the return type.
468-
body.stmts.push(ecx.stmt_expr(ret));
464+
let ret : P<ast::Expr>;
465+
if exprs.len() > 1 {
466+
let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, exprs);
467+
ret = ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_tuple]);
468+
} else if exprs.len() == 1 {
469+
let ret_scal = exprs.pop().unwrap();
470+
ret = ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_scal]);
471+
} else {
472+
assert!(!d_sig.decl.output.has_ret());
473+
// We don't have to match the return type.
474+
return body;
469475
}
476+
assert!(d_sig.decl.output.has_ret());
477+
body.stmts.push(ecx.stmt_expr(ret));
470478

471479
body
472480
}

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2762,6 +2762,7 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<Diff
27622762
// We care about safety checks, if an argument get's duplicated and we write into the
27632763
// shadow. That's equivalent to Duplicated or DuplicatedOnly.
27642764
let safety = if !da.is_empty() {
2765+
assert!(da.len() == x.inputs().len(), "{:?} != {:?}", da.len(), x.inputs().len());
27652766
// If we have Activities, we also have spans
27662767
assert!(span.is_some());
27672768
match da[i] {

0 commit comments

Comments
 (0)