@@ -925,13 +925,17 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
925
925
input_tts : Vec < TypeTree > ,
926
926
output_tt : TypeTree ,
927
927
) -> & Value {
928
- let ( primary_ret, diff_ret, ret_activity) = match ret_activity {
929
- DiffActivity :: Const => ( true , false , CDIFFE_TYPE :: DFT_CONSTANT ) ,
930
- DiffActivity :: Active => ( true , true , CDIFFE_TYPE :: DFT_DUP_ARG ) ,
931
- DiffActivity :: ActiveOnly => ( false , true , CDIFFE_TYPE :: DFT_DUP_NONEED ) ,
932
- DiffActivity :: None => ( false , false , CDIFFE_TYPE :: DFT_CONSTANT ) ,
928
+ let ( primary_ret, ret_activity) = match ret_activity {
929
+ DiffActivity :: Const => ( true , CDIFFE_TYPE :: DFT_CONSTANT ) ,
930
+ DiffActivity :: Active => ( true , CDIFFE_TYPE :: DFT_DUP_ARG ) ,
931
+ DiffActivity :: None => ( false , CDIFFE_TYPE :: DFT_CONSTANT ) ,
933
932
_ => panic ! ( "Invalid return activity" ) ,
934
933
} ;
934
+ // This only is needed for split-mode AD, which we don't support.
935
+ // See Julia:
936
+ // https://github.com/EnzymeAD/Enzyme.jl/blob/a511e4e6979d6161699f5c9919d49801c0764a09/src/compiler.jl#L3132
937
+ // https://github.com/EnzymeAD/Enzyme.jl/blob/a511e4e6979d6161699f5c9919d49801c0764a09/src/compiler.jl#L3092
938
+ let diff_ret = false ;
935
939
936
940
let input_activity: Vec < CDIFFE_TYPE > = input_activity. iter ( ) . map ( |& x| cdiffe_from ( x) ) . collect ( ) ;
937
941
@@ -2690,7 +2694,7 @@ extern "C" {
2690
2694
numRules : size_t ,
2691
2695
) -> EnzymeTypeAnalysisRef ;
2692
2696
}
2693
- pub fn ClearTypeAnalysis ( arg1 : EnzymeTypeAnalysisRef ) { unimplemented ! ( ) }
2697
+ // pub fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef) { unimplemented!() }
2694
2698
pub fn FreeTypeAnalysis ( arg1 : EnzymeTypeAnalysisRef ) { unimplemented ! ( ) }
2695
2699
pub fn CreateEnzymeLogic ( PostOpt : u8 ) -> EnzymeLogicRef { unimplemented ! ( ) }
2696
2700
pub fn ClearEnzymeLogic ( arg1 : EnzymeLogicRef ) { unimplemented ! ( ) }
@@ -2936,25 +2940,12 @@ pub use self::Enzyme_AD::*;
2936
2940
pub mod Enzyme_AD {
2937
2941
use super :: * ;
2938
2942
2939
- use super :: debuginfo:: {
2940
- DIArray , DIBasicType , DIBuilder , DICompositeType , DIDerivedType , DIDescriptor , DIEnumerator ,
2941
- DIFile , DIFlags , DIGlobalVariableExpression , DILexicalBlock , DILocation , DINameSpace ,
2942
- DISPFlags , DIScope , DISubprogram , DISubrange , DITemplateTypeParameter , DIType , DIVariable ,
2943
- DebugEmissionKind , DebugNameTableKind ,
2944
- } ;
2945
-
2946
- use libc:: { c_char, c_int, c_uint, size_t} ;
2947
- use libc:: { c_ulonglong, c_void} ;
2948
-
2949
- use std:: marker:: PhantomData ;
2950
-
2951
- use super :: RustString ;
2952
- use core:: fmt;
2953
- use std:: ffi:: { CStr , CString } ;
2943
+ use libc:: { c_char, size_t} ;
2944
+ use libc:: c_void;
2954
2945
2955
2946
extern "C" {
2956
- fn EnzymeNewTypeTree ( ) -> CTypeTreeRef ;
2957
- fn EnzymeFreeTypeTree ( CTT : CTypeTreeRef ) ;
2947
+ pub fn EnzymeNewTypeTree ( ) -> CTypeTreeRef ;
2948
+ pub fn EnzymeFreeTypeTree ( CTT : CTypeTreeRef ) ;
2958
2949
pub fn EnzymeSetCLBool ( arg1 : * mut :: std:: os:: raw:: c_void , arg2 : u8 ) ;
2959
2950
pub fn EnzymeSetCLInteger ( arg1 : * mut :: std:: os:: raw:: c_void , arg2 : i64 ) ;
2960
2951
}
@@ -2971,43 +2962,46 @@ extern "C" {
2971
2962
static mut EnzymeStrictAliasing : c_void ;
2972
2963
}
2973
2964
pub fn set_max_int_offset ( offset : u64 ) {
2965
+ let offset = offset. try_into ( ) . unwrap ( ) ;
2974
2966
unsafe {
2975
- EnzymeSetCLInteger ( std:: ptr:: addr_of_mut!( llvm :: MaxIntOffset ) , offset) ;
2967
+ EnzymeSetCLInteger ( std:: ptr:: addr_of_mut!( MaxIntOffset ) , offset) ;
2976
2968
}
2977
2969
}
2978
2970
pub fn set_max_type_offset ( offset : u64 ) {
2971
+ let offset = offset. try_into ( ) . unwrap ( ) ;
2979
2972
unsafe {
2980
- EnzymeSetCLInteger ( std:: ptr:: addr_of_mut!( llvm :: MaxTypeOffset ) , offset) ;
2973
+ EnzymeSetCLInteger ( std:: ptr:: addr_of_mut!( MaxTypeOffset ) , offset) ;
2981
2974
}
2982
2975
}
2983
2976
pub fn set_max_type_depth ( depth : u64 ) {
2977
+ let depth = depth. try_into ( ) . unwrap ( ) ;
2984
2978
unsafe {
2985
- EnzymeSetCLInteger ( std:: ptr:: addr_of_mut!( llvm :: EnzymeMaxTypeDepth ) , depth) ;
2979
+ EnzymeSetCLInteger ( std:: ptr:: addr_of_mut!( EnzymeMaxTypeDepth ) , depth) ;
2986
2980
}
2987
2981
}
2988
2982
pub fn set_print_perf ( print : bool ) {
2989
2983
unsafe {
2990
- EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( llvm :: EnzymePrintPerf ) , print as u8 ) ;
2984
+ EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( EnzymePrintPerf ) , print as u8 ) ;
2991
2985
}
2992
2986
}
2993
2987
pub fn set_print_activity ( print : bool ) {
2994
2988
unsafe {
2995
- EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( llvm :: EnzymePrintActivity ) , print as u8 ) ;
2989
+ EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( EnzymePrintActivity ) , print as u8 ) ;
2996
2990
}
2997
2991
}
2998
2992
pub fn set_print_type ( print : bool ) {
2999
2993
unsafe {
3000
- EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( llvm :: EnzymePrintType ) , print as u8 ) ;
2994
+ EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( EnzymePrintType ) , print as u8 ) ;
3001
2995
}
3002
2996
}
3003
2997
pub fn set_print ( print : bool ) {
3004
2998
unsafe {
3005
- EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( llvm :: EnzymePrint ) , print as u8 ) ;
2999
+ EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( EnzymePrint ) , print as u8 ) ;
3006
3000
}
3007
3001
}
3008
3002
pub fn set_strict_aliasing ( strict : bool ) {
3009
3003
unsafe {
3010
- EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( llvm :: EnzymeStrictAliasing ) , strict as u8 ) ;
3004
+ EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( EnzymeStrictAliasing ) , strict as u8 ) ;
3011
3005
}
3012
3006
}
3013
3007
extern "C" {
@@ -3074,28 +3068,28 @@ extern "C" {
3074
3068
numRules : size_t ,
3075
3069
) -> EnzymeTypeAnalysisRef ;
3076
3070
}
3077
- // extern "C" {
3078
- // pub fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef);
3079
- // pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef);
3080
- // pub fn CreateEnzymeLogic(PostOpt: u8) -> EnzymeLogicRef;
3081
- // pub fn ClearEnzymeLogic(arg1: EnzymeLogicRef);
3082
- // pub fn FreeEnzymeLogic(arg1: EnzymeLogicRef);
3083
- // }
3071
+ extern "C" {
3072
+ // pub(super) fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef);
3073
+ pub fn FreeTypeAnalysis ( arg1 : EnzymeTypeAnalysisRef ) ;
3074
+ pub fn CreateEnzymeLogic ( PostOpt : u8 ) -> EnzymeLogicRef ;
3075
+ pub fn ClearEnzymeLogic ( arg1 : EnzymeLogicRef ) ;
3076
+ pub fn FreeEnzymeLogic ( arg1 : EnzymeLogicRef ) ;
3077
+ }
3084
3078
3085
3079
extern "C" {
3086
- fn EnzymeNewTypeTreeCT ( arg1 : CConcreteType , ctx : & Context ) -> CTypeTreeRef ;
3087
- fn EnzymeNewTypeTreeTR ( arg1 : CTypeTreeRef ) -> CTypeTreeRef ;
3088
- fn EnzymeMergeTypeTree ( arg1 : CTypeTreeRef , arg2 : CTypeTreeRef ) -> bool ;
3089
- fn EnzymeTypeTreeOnlyEq ( arg1 : CTypeTreeRef , pos : i64 ) ;
3090
- fn EnzymeTypeTreeData0Eq ( arg1 : CTypeTreeRef ) ;
3091
- fn EnzymeTypeTreeShiftIndiciesEq (
3080
+ pub ( super ) fn EnzymeNewTypeTreeCT ( arg1 : CConcreteType , ctx : & Context ) -> CTypeTreeRef ;
3081
+ pub ( super ) fn EnzymeNewTypeTreeTR ( arg1 : CTypeTreeRef ) -> CTypeTreeRef ;
3082
+ pub ( super ) fn EnzymeMergeTypeTree ( arg1 : CTypeTreeRef , arg2 : CTypeTreeRef ) -> bool ;
3083
+ pub ( super ) fn EnzymeTypeTreeOnlyEq ( arg1 : CTypeTreeRef , pos : i64 ) ;
3084
+ pub ( super ) fn EnzymeTypeTreeData0Eq ( arg1 : CTypeTreeRef ) ;
3085
+ pub ( super ) fn EnzymeTypeTreeShiftIndiciesEq (
3092
3086
arg1 : CTypeTreeRef ,
3093
3087
data_layout : * const c_char ,
3094
3088
offset : i64 ,
3095
3089
max_size : i64 ,
3096
3090
add_offset : u64 ,
3097
3091
) ;
3098
- fn EnzymeTypeTreeToStringFree ( arg1 : * const c_char ) ;
3099
- fn EnzymeTypeTreeToString ( arg1 : CTypeTreeRef ) -> * const c_char ;
3092
+ pub ( super ) fn EnzymeTypeTreeToStringFree ( arg1 : * const c_char ) ;
3093
+ pub ( super ) fn EnzymeTypeTreeToString ( arg1 : CTypeTreeRef ) -> * const c_char ;
3100
3094
}
3101
3095
}
0 commit comments