@@ -698,6 +698,7 @@ pub(crate) unsafe fn extract_return_type<'a>(
698
698
pub ( crate ) unsafe fn enzyme_ad (
699
699
llmod : & llvm:: Module ,
700
700
llcx : & llvm:: Context ,
701
+ diag_handler : & rustc_errors:: Handler ,
701
702
item : AutoDiffItem ,
702
703
) -> Result < ( ) , FatalError > {
703
704
let autodiff_mode = item. attrs . mode ;
@@ -710,8 +711,28 @@ pub(crate) unsafe fn enzyme_ad(
710
711
// get target and source function
711
712
let name = CString :: new ( rust_name. to_owned ( ) ) . unwrap ( ) ;
712
713
let name2 = CString :: new ( rust_name2. clone ( ) ) . unwrap ( ) ;
713
- let src_fnc = llvm:: LLVMGetNamedFunction ( llmod, name. as_c_str ( ) . as_ptr ( ) ) . unwrap ( ) ;
714
- let target_fnc = llvm:: LLVMGetNamedFunction ( llmod, name2. as_ptr ( ) ) . unwrap ( ) ;
714
+ let src_fnc_opt = llvm:: LLVMGetNamedFunction ( llmod, name. as_c_str ( ) . as_ptr ( ) ) ;
715
+ let src_fnc = match src_fnc_opt {
716
+ Some ( x) => x,
717
+ None => {
718
+ return Err ( llvm_err ( diag_handler, LlvmError :: PrepareAutoDiff {
719
+ src : rust_name. to_owned ( ) ,
720
+ target : rust_name2. to_owned ( ) ,
721
+ error : "could not find src function" . to_owned ( ) ,
722
+ } ) ) ;
723
+ }
724
+ } ;
725
+ let target_fnc_opt = llvm:: LLVMGetNamedFunction ( llmod, name2. as_ptr ( ) ) ;
726
+ let target_fnc = match target_fnc_opt {
727
+ Some ( x) => x,
728
+ None => {
729
+ return Err ( llvm_err ( diag_handler, LlvmError :: PrepareAutoDiff {
730
+ src : rust_name. to_owned ( ) ,
731
+ target : rust_name2. to_owned ( ) ,
732
+ error : "could not find target function" . to_owned ( ) ,
733
+ } ) ) ;
734
+ }
735
+ } ;
715
736
let src_num_args = llvm:: LLVMCountParams ( src_fnc) ;
716
737
let target_num_args = llvm:: LLVMCountParams ( target_fnc) ;
717
738
assert ! ( src_num_args <= target_num_args) ;
@@ -791,13 +812,14 @@ pub(crate) unsafe fn enzyme_ad(
791
812
792
813
pub ( crate ) unsafe fn differentiate (
793
814
module : & ModuleCodegen < ModuleLlvm > ,
794
- _cgcx : & CodegenContext < LlvmCodegenBackend > ,
815
+ cgcx : & CodegenContext < LlvmCodegenBackend > ,
795
816
diff_items : Vec < AutoDiffItem > ,
796
817
_typetrees : FxHashMap < String , DiffTypeTree > ,
797
818
_config : & ModuleConfig ,
798
819
) -> Result < ( ) , FatalError > {
799
820
let llmod = module. module_llvm . llmod ( ) ;
800
821
let llcx = & module. module_llvm . llcx ;
822
+ let diag_handler = cgcx. create_diag_handler ( ) ;
801
823
802
824
llvm:: EnzymeSetCLBool ( std:: ptr:: addr_of_mut!( llvm:: EnzymeStrictAliasing ) , 0 ) ;
803
825
@@ -818,7 +840,7 @@ pub(crate) unsafe fn differentiate(
818
840
}
819
841
820
842
for item in diff_items {
821
- let res = enzyme_ad ( llmod, llcx, item) ;
843
+ let res = enzyme_ad ( llmod, llcx, & diag_handler , item) ;
822
844
assert ! ( res. is_ok( ) ) ;
823
845
}
824
846
0 commit comments