@@ -111,7 +111,6 @@ pub fn expand(
111
111
( sig. clone ( ) , true )
112
112
} ,
113
113
_ => {
114
- dbg ! ( & item) ;
115
114
ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
116
115
return vec ! [ item] ;
117
116
}
@@ -280,7 +279,6 @@ pub fn expand(
280
279
d_fn. vis = vis;
281
280
Annotatable :: Item ( d_fn)
282
281
} ;
283
- trace ! ( "Generated function: {:?}" , d_annotatable) ;
284
282
285
283
return vec ! [ orig_annotatable, d_annotatable] ;
286
284
}
@@ -371,7 +369,9 @@ fn gen_enzyme_body(
371
369
return body;
372
370
}
373
371
374
- let primal_ret = sig. decl . output . has_ret ( ) ;
372
+ // having an active-only return means we'll drop the original return type.
373
+ // So that can be treated identical to not having one in the first place.
374
+ let primal_ret = sig. decl . output . has_ret ( ) && !x. has_active_only_ret ( ) ;
375
375
376
376
if primal_ret && n_active == 0 && is_rev ( x. mode ) {
377
377
// We only have the primal ret.
@@ -405,16 +405,26 @@ fn gen_enzyme_body(
405
405
406
406
// Now construct default placeholder for each active float.
407
407
// Is there something nicer than f32::default() and f64::default()?
408
- let mut d_ret_ty = match d_sig. decl . output {
408
+ let d_ret_ty = match d_sig. decl . output {
409
409
FnRetTy :: Ty ( ref ty) => ty. clone ( ) ,
410
410
FnRetTy :: Default ( span) => {
411
411
panic ! ( "Did not expect Default ret ty: {:?}" , span) ;
412
412
}
413
413
} ;
414
- let mut d_ret_ty = match d_ret_ty. kind {
415
- TyKind :: Tup ( ref mut tys) => {
414
+ let mut d_ret_ty = match d_ret_ty. kind . clone ( ) {
415
+ TyKind :: Tup ( ref tys) => {
416
416
tys. clone ( )
417
417
}
418
+ TyKind :: Path ( _, rustc_ast:: Path { segments, .. } ) => {
419
+ if segments. len ( ) == 1 && segments[ 0 ] . args . is_none ( ) {
420
+ let id = vec ! [ segments[ 0 ] . ident] ;
421
+ let kind = TyKind :: Path ( None , ecx. path ( span, id) ) ;
422
+ let ty = P ( rustc_ast:: Ty { kind, id : ast:: DUMMY_NODE_ID , span, tokens : None } ) ;
423
+ thin_vec ! [ ty]
424
+ } else {
425
+ panic ! ( "Expected tuple or simple path return type" ) ;
426
+ }
427
+ }
418
428
_ => {
419
429
// We messed up construction of d_sig
420
430
panic ! ( "Did not expect non-tuple ret ty: {:?}" , d_ret_ty) ;
@@ -585,33 +595,41 @@ fn gen_enzyme_decl(
585
595
}
586
596
}
587
597
598
+ let active_only_ret = x. ret_activity == DiffActivity :: ActiveOnly ;
599
+ if active_only_ret {
600
+ assert ! ( is_rev( x. mode) ) ;
601
+ }
602
+
588
603
// If we return a scalar in the primal and the scalar is active,
589
604
// then add it as last arg to the inputs.
590
605
if is_rev ( x. mode ) {
591
- if let DiffActivity :: Active = x. ret_activity {
592
- let ty = match d_decl. output {
593
- FnRetTy :: Ty ( ref ty) => ty. clone ( ) ,
594
- FnRetTy :: Default ( span) => {
595
- panic ! ( "Did not expect Default ret ty: {:?}" , span) ;
596
- }
597
- } ;
598
- let name = "dret" . to_string ( ) ;
599
- let ident = Ident :: from_str_and_span ( & name, ty. span ) ;
600
- let shadow_arg = ast:: Param {
601
- attrs : ThinVec :: new ( ) ,
602
- ty : ty. clone ( ) ,
603
- pat : P ( ast:: Pat {
606
+ match x. ret_activity {
607
+ DiffActivity :: Active | DiffActivity :: ActiveOnly => {
608
+ let ty = match d_decl. output {
609
+ FnRetTy :: Ty ( ref ty) => ty. clone ( ) ,
610
+ FnRetTy :: Default ( span) => {
611
+ panic ! ( "Did not expect Default ret ty: {:?}" , span) ;
612
+ }
613
+ } ;
614
+ let name = "dret" . to_string ( ) ;
615
+ let ident = Ident :: from_str_and_span ( & name, ty. span ) ;
616
+ let shadow_arg = ast:: Param {
617
+ attrs : ThinVec :: new ( ) ,
618
+ ty : ty. clone ( ) ,
619
+ pat : P ( ast:: Pat {
620
+ id : ast:: DUMMY_NODE_ID ,
621
+ kind : PatKind :: Ident ( BindingAnnotation :: NONE , ident, None ) ,
622
+ span : ty. span ,
623
+ tokens : None ,
624
+ } ) ,
604
625
id : ast:: DUMMY_NODE_ID ,
605
- kind : PatKind :: Ident ( BindingAnnotation :: NONE , ident, None ) ,
606
626
span : ty. span ,
607
- tokens : None ,
608
- } ) ,
609
- id : ast:: DUMMY_NODE_ID ,
610
- span : ty. span ,
611
- is_placeholder : false ,
612
- } ;
613
- d_inputs. push ( shadow_arg) ;
614
- new_inputs. push ( name) ;
627
+ is_placeholder : false ,
628
+ } ;
629
+ d_inputs. push ( shadow_arg) ;
630
+ new_inputs. push ( name) ;
631
+ }
632
+ _ => { }
615
633
}
616
634
}
617
635
d_decl. inputs = d_inputs. into ( ) ;
@@ -630,15 +648,31 @@ fn gen_enzyme_decl(
630
648
let ty = P ( rustc_ast:: Ty { kind, id : ty. id , span : ty. span , tokens : None } ) ;
631
649
d_decl. output = FnRetTy :: Ty ( ty) ;
632
650
}
651
+ if let DiffActivity :: DualOnly = x. ret_activity {
652
+ // No need to change the return type,
653
+ // we will just return the shadow in place
654
+ // of the primal return.
655
+ }
633
656
}
634
657
658
+ // If we use ActiveOnly, drop the original return value.
659
+ d_decl. output = if active_only_ret {
660
+ FnRetTy :: Default ( span)
661
+ } else {
662
+ d_decl. output . clone ( )
663
+ } ;
664
+
665
+ trace ! ( "act_ret: {:?}" , act_ret) ;
666
+
635
667
// If we have an active input scalar, add it's gradient to the
636
668
// return type. This might require changing the return type to a
637
669
// tuple.
638
670
if act_ret. len ( ) > 0 {
639
671
let ret_ty = match d_decl. output {
640
672
FnRetTy :: Ty ( ref ty) => {
641
- act_ret. insert ( 0 , ty. clone ( ) ) ;
673
+ if !active_only_ret {
674
+ act_ret. insert ( 0 , ty. clone ( ) ) ;
675
+ }
642
676
let kind = TyKind :: Tup ( act_ret) ;
643
677
P ( rustc_ast:: Ty { kind, id : ty. id , span : ty. span , tokens : None } )
644
678
}
@@ -655,5 +689,6 @@ fn gen_enzyme_decl(
655
689
}
656
690
657
691
let d_sig = FnSig { header : sig. header . clone ( ) , decl : d_decl, span } ;
692
+ trace ! ( "Generated signature: {:?}" , d_sig) ;
658
693
( d_sig, new_inputs, idents)
659
694
}
0 commit comments