@@ -914,6 +914,8 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
914
914
assert ! ( ret_activity == CDIFFE_TYPE :: DFT_CONSTANT || ret_activity == CDIFFE_TYPE :: DFT_OUT_DIFF ) ;
915
915
let input_activity: Vec < CDIFFE_TYPE > = input_activity. iter ( ) . map ( |& x| cdiffe_from ( x) ) . collect ( ) ;
916
916
917
+ dbg ! ( & fnc) ;
918
+
917
919
if ret_activity == CDIFFE_TYPE :: DFT_DUP_ARG {
918
920
if ret_primary_ret != true {
919
921
dbg ! ( "overwriting ret_primary_ret!" ) ;
@@ -931,6 +933,11 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
931
933
// We don't support volatile / extern / (global?) values.
932
934
// Just because I didn't had time to test them, and it seems less urgent.
933
935
let args_uncacheable = vec ! [ 0 ; input_tts. len( ) ] ;
936
+ assert ! ( args_uncacheable. len( ) == input_activity. len( ) ) ;
937
+ let num_fnc_args = LLVMCountParams ( fnc) ;
938
+ println ! ( "num_fnc_args: {}" , num_fnc_args) ;
939
+ println ! ( "input_activity.len(): {}" , input_activity. len( ) ) ;
940
+ assert ! ( num_fnc_args == input_activity. len( ) as u32 ) ;
934
941
let kv_tmp = IntList { data : std:: ptr:: null_mut ( ) , size : 0 } ;
935
942
936
943
@@ -942,7 +949,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
942
949
KnownValues : known_values. as_mut_ptr ( ) ,
943
950
} ;
944
951
945
- EnzymeCreatePrimalAndGradient (
952
+ let res = EnzymeCreatePrimalAndGradient (
946
953
logic_ref, // Logic
947
954
std:: ptr:: null ( ) ,
948
955
std:: ptr:: null ( ) ,
@@ -963,7 +970,9 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
963
970
args_uncacheable. len ( ) , // uncacheable arguments
964
971
std:: ptr:: null_mut ( ) , // write augmented function to this
965
972
0 ,
966
- )
973
+ ) ;
974
+ dbg ! ( & res) ;
975
+ res
967
976
}
968
977
pub type GetSymbolsCallback = unsafe extern "C" fn ( * mut c_void , * const c_char ) -> * mut c_void ;
969
978
pub type GetSymbolsErrorCallback = unsafe extern "C" fn ( * const c_char ) -> * mut c_void ;
0 commit comments