Skip to content

Commit 17c772f

Browse files
authored
unbreak ActiveOnly return (#112)
1 parent ffdbffb commit 17c772f

File tree

3 files changed

+72
-29
lines changed

3 files changed

+72
-29
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ impl FromStr for DiffActivity {
167167
match s {
168168
"None" => Ok(DiffActivity::None),
169169
"Active" => Ok(DiffActivity::Active),
170+
"ActiveOnly" => Ok(DiffActivity::ActiveOnly),
170171
"Const" => Ok(DiffActivity::Const),
171172
"Dual" => Ok(DiffActivity::Dual),
172173
"DualOnly" => Ok(DiffActivity::DualOnly),
@@ -192,6 +193,12 @@ impl AutoDiffAttrs {
192193
_ => true,
193194
}
194195
}
196+
pub fn has_active_only_ret(&self) -> bool {
197+
match self.ret_activity {
198+
DiffActivity::ActiveOnly => true,
199+
_ => false,
200+
}
201+
}
195202
}
196203

197204
impl AutoDiffAttrs {

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 64 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,6 @@ pub fn expand(
111111
(sig.clone(), true)
112112
},
113113
_ => {
114-
dbg!(&item);
115114
ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
116115
return vec![item];
117116
}
@@ -280,7 +279,6 @@ pub fn expand(
280279
d_fn.vis = vis;
281280
Annotatable::Item(d_fn)
282281
};
283-
trace!("Generated function: {:?}", d_annotatable);
284282

285283
return vec![orig_annotatable, d_annotatable];
286284
}
@@ -371,7 +369,9 @@ fn gen_enzyme_body(
371369
return body;
372370
}
373371

374-
let primal_ret = sig.decl.output.has_ret();
372+
// having an active-only return means we'll drop the original return type.
373+
// So that can be treated identical to not having one in the first place.
374+
let primal_ret = sig.decl.output.has_ret() && !x.has_active_only_ret();
375375

376376
if primal_ret && n_active == 0 && is_rev(x.mode) {
377377
// We only have the primal ret.
@@ -405,16 +405,26 @@ fn gen_enzyme_body(
405405

406406
// Now construct default placeholder for each active float.
407407
// Is there something nicer than f32::default() and f64::default()?
408-
let mut d_ret_ty = match d_sig.decl.output {
408+
let d_ret_ty = match d_sig.decl.output {
409409
FnRetTy::Ty(ref ty) => ty.clone(),
410410
FnRetTy::Default(span) => {
411411
panic!("Did not expect Default ret ty: {:?}", span);
412412
}
413413
};
414-
let mut d_ret_ty = match d_ret_ty.kind {
415-
TyKind::Tup(ref mut tys) => {
414+
let mut d_ret_ty = match d_ret_ty.kind.clone() {
415+
TyKind::Tup(ref tys) => {
416416
tys.clone()
417417
}
418+
TyKind::Path(_, rustc_ast::Path { segments, .. }) => {
419+
if segments.len() == 1 && segments[0].args.is_none() {
420+
let id = vec![segments[0].ident];
421+
let kind = TyKind::Path(None, ecx.path(span, id));
422+
let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
423+
thin_vec![ty]
424+
} else {
425+
panic!("Expected tuple or simple path return type");
426+
}
427+
}
418428
_ => {
419429
// We messed up construction of d_sig
420430
panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty);
@@ -585,33 +595,41 @@ fn gen_enzyme_decl(
585595
}
586596
}
587597

598+
let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
599+
if active_only_ret {
600+
assert!(is_rev(x.mode));
601+
}
602+
588603
// If we return a scalar in the primal and the scalar is active,
589604
// then add it as last arg to the inputs.
590605
if is_rev(x.mode) {
591-
if let DiffActivity::Active = x.ret_activity {
592-
let ty = match d_decl.output {
593-
FnRetTy::Ty(ref ty) => ty.clone(),
594-
FnRetTy::Default(span) => {
595-
panic!("Did not expect Default ret ty: {:?}", span);
596-
}
597-
};
598-
let name = "dret".to_string();
599-
let ident = Ident::from_str_and_span(&name, ty.span);
600-
let shadow_arg = ast::Param {
601-
attrs: ThinVec::new(),
602-
ty: ty.clone(),
603-
pat: P(ast::Pat {
606+
match x.ret_activity {
607+
DiffActivity::Active | DiffActivity::ActiveOnly => {
608+
let ty = match d_decl.output {
609+
FnRetTy::Ty(ref ty) => ty.clone(),
610+
FnRetTy::Default(span) => {
611+
panic!("Did not expect Default ret ty: {:?}", span);
612+
}
613+
};
614+
let name = "dret".to_string();
615+
let ident = Ident::from_str_and_span(&name, ty.span);
616+
let shadow_arg = ast::Param {
617+
attrs: ThinVec::new(),
618+
ty: ty.clone(),
619+
pat: P(ast::Pat {
620+
id: ast::DUMMY_NODE_ID,
621+
kind: PatKind::Ident(BindingAnnotation::NONE, ident, None),
622+
span: ty.span,
623+
tokens: None,
624+
}),
604625
id: ast::DUMMY_NODE_ID,
605-
kind: PatKind::Ident(BindingAnnotation::NONE, ident, None),
606626
span: ty.span,
607-
tokens: None,
608-
}),
609-
id: ast::DUMMY_NODE_ID,
610-
span: ty.span,
611-
is_placeholder: false,
612-
};
613-
d_inputs.push(shadow_arg);
614-
new_inputs.push(name);
627+
is_placeholder: false,
628+
};
629+
d_inputs.push(shadow_arg);
630+
new_inputs.push(name);
631+
}
632+
_ => {}
615633
}
616634
}
617635
d_decl.inputs = d_inputs.into();
@@ -630,15 +648,31 @@ fn gen_enzyme_decl(
630648
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
631649
d_decl.output = FnRetTy::Ty(ty);
632650
}
651+
if let DiffActivity::DualOnly = x.ret_activity {
652+
// No need to change the return type,
653+
// we will just return the shadow in place
654+
// of the primal return.
655+
}
633656
}
634657

658+
// If we use ActiveOnly, drop the original return value.
659+
d_decl.output = if active_only_ret {
660+
FnRetTy::Default(span)
661+
} else {
662+
d_decl.output.clone()
663+
};
664+
665+
trace!("act_ret: {:?}", act_ret);
666+
635667
// If we have an active input scalar, add it's gradient to the
636668
// return type. This might require changing the return type to a
637669
// tuple.
638670
if act_ret.len() > 0 {
639671
let ret_ty = match d_decl.output {
640672
FnRetTy::Ty(ref ty) => {
641-
act_ret.insert(0, ty.clone());
673+
if !active_only_ret {
674+
act_ret.insert(0, ty.clone());
675+
}
642676
let kind = TyKind::Tup(act_ret);
643677
P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
644678
}
@@ -655,5 +689,6 @@ fn gen_enzyme_decl(
655689
}
656690

657691
let d_sig = FnSig { header: sig.header.clone(), decl: d_decl, span };
692+
trace!("Generated signature: {:?}", d_sig);
658693
(d_sig, new_inputs, idents)
659694
}

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -940,6 +940,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
940940
let (primary_ret, ret_activity) = match ret_activity {
941941
DiffActivity::Const => (true, CDIFFE_TYPE::DFT_CONSTANT),
942942
DiffActivity::Active => (true, CDIFFE_TYPE::DFT_OUT_DIFF),
943+
DiffActivity::ActiveOnly => (false, CDIFFE_TYPE::DFT_OUT_DIFF),
943944
DiffActivity::None => (false, CDIFFE_TYPE::DFT_CONSTANT),
944945
_ => panic!("Invalid return activity"),
945946
};

0 commit comments

Comments
 (0)