@@ -851,10 +851,10 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
851
851
fnc : & Value ,
852
852
input_diffactivity : Vec < DiffActivity > ,
853
853
ret_diffactivity : DiffActivity ,
854
- mut ret_primary_ret : bool ,
855
854
input_tts : Vec < TypeTree > ,
856
855
output_tt : TypeTree ,
857
856
) -> & Value {
857
+ let mut ret_primary_ret = false ;
858
858
let ret_activity = cdiffe_from ( ret_diffactivity) ;
859
859
assert ! ( ret_activity != CDIFFE_TYPE :: DFT_OUT_DIFF ) ;
860
860
let mut input_activity: Vec < CDIFFE_TYPE > = vec ! [ ] ;
@@ -925,29 +925,22 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
925
925
fnc : & Value ,
926
926
input_activity : Vec < DiffActivity > ,
927
927
ret_activity : DiffActivity ,
928
- mut ret_primary_ret : bool ,
929
- diff_primary_ret : bool ,
930
928
input_tts : Vec < TypeTree > ,
931
929
output_tt : TypeTree ,
932
930
) -> & Value {
933
- let ret_activity = cdiffe_from ( ret_activity) ;
934
- assert ! ( ret_activity == CDIFFE_TYPE :: DFT_CONSTANT || ret_activity == CDIFFE_TYPE :: DFT_OUT_DIFF ) ;
931
+ let ( primary_ret, diff_ret, ret_activity) = match ret_activity {
932
+ DiffActivity :: Const => ( true , false , CDIFFE_TYPE :: DFT_CONSTANT ) ,
933
+ DiffActivity :: Active => ( true , true , CDIFFE_TYPE :: DFT_DUP_ARG ) ,
934
+ DiffActivity :: ActiveOnly => ( false , true , CDIFFE_TYPE :: DFT_DUP_NONEED ) ,
935
+ DiffActivity :: None => ( false , false , CDIFFE_TYPE :: DFT_CONSTANT ) ,
936
+ _ => panic ! ( "Invalid return activity" ) ,
937
+ } ;
938
+
939
+ //assert!(ret_activity == CDIFFE_TYPE::DFT_CONSTANT || ret_activity == CDIFFE_TYPE::DFT_OUT_DIFF);
935
940
let input_activity: Vec < CDIFFE_TYPE > = input_activity. iter ( ) . map ( |& x| cdiffe_from ( x) ) . collect ( ) ;
936
941
937
942
dbg ! ( & fnc) ;
938
943
939
- if ret_activity == CDIFFE_TYPE :: DFT_DUP_ARG {
940
- if ret_primary_ret != true {
941
- dbg ! ( "overwriting ret_primary_ret!" ) ;
942
- }
943
- ret_primary_ret = true ;
944
- } else if ret_activity == CDIFFE_TYPE :: DFT_DUP_NONEED {
945
- if ret_primary_ret != false {
946
- dbg ! ( "overwriting ret_primary_ret!" ) ;
947
- }
948
- ret_primary_ret = false ;
949
- }
950
-
951
944
let mut args_tree = input_tts. iter ( ) . map ( |x| x. inner ) . collect :: < Vec < _ > > ( ) ;
952
945
953
946
// We don't support volatile / extern / (global?) values.
@@ -977,8 +970,8 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
977
970
input_activity. as_ptr ( ) ,
978
971
input_activity. len ( ) , // constant arguments
979
972
type_analysis, // type analysis struct
980
- ret_primary_ret as u8 ,
981
- diff_primary_ret as u8 , //0
973
+ primary_ret as u8 ,
974
+ diff_ret as u8 , //0
982
975
CDerivativeMode :: DEM_ReverseModeCombined , // return value, dret_used, top_level which was 1
983
976
1 , // vector mode width
984
977
1 , // free memory
@@ -2704,12 +2697,13 @@ pub enum CDIFFE_TYPE {
2704
2697
fn cdiffe_from ( act : DiffActivity ) -> CDIFFE_TYPE {
2705
2698
return match act {
2706
2699
DiffActivity :: None => CDIFFE_TYPE :: DFT_CONSTANT ,
2707
- DiffActivity :: Active => CDIFFE_TYPE :: DFT_OUT_DIFF ,
2708
2700
DiffActivity :: Const => CDIFFE_TYPE :: DFT_CONSTANT ,
2701
+ DiffActivity :: Active => CDIFFE_TYPE :: DFT_OUT_DIFF ,
2702
+ DiffActivity :: ActiveOnly => CDIFFE_TYPE :: DFT_OUT_DIFF ,
2709
2703
DiffActivity :: Dual => CDIFFE_TYPE :: DFT_DUP_ARG ,
2710
- DiffActivity :: DualNoNeed => CDIFFE_TYPE :: DFT_DUP_NONEED ,
2704
+ DiffActivity :: DualOnly => CDIFFE_TYPE :: DFT_DUP_NONEED ,
2711
2705
DiffActivity :: Duplicated => CDIFFE_TYPE :: DFT_DUP_ARG ,
2712
- DiffActivity :: DuplicatedNoNeed => CDIFFE_TYPE :: DFT_DUP_NONEED ,
2706
+ DiffActivity :: DuplicatedOnly => CDIFFE_TYPE :: DFT_DUP_NONEED ,
2713
2707
} ;
2714
2708
}
2715
2709
0 commit comments