@@ -40,14 +40,13 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode};
40
40
use rustc_ast:: expand:: typetree:: FncTree ;
41
41
use rustc_data_structures:: fx:: FxHashMap ;
42
42
43
-
44
43
use std:: ffi:: { CStr , CString } ;
45
44
use std:: io:: { self , Write } ;
46
45
use std:: path:: { Path , PathBuf } ;
47
46
use std:: sync:: Arc ;
48
47
use std:: { fs, slice, str} ;
49
48
50
- use libc:: { c_char, c_int, c_void, size_t} ;
49
+ use libc:: { c_char, c_int, c_uint , c_void, size_t} ;
51
50
use llvm:: {
52
51
LLVMRustLLVMHasZlibCompressionForDebugSymbols , LLVMRustLLVMHasZstdCompressionForDebugSymbols ,
53
52
} ;
@@ -86,6 +85,8 @@ use crate::llvm::{self, DiagnosticInfo, PassManager};
86
85
use crate :: type_:: Type ;
87
86
use crate :: { base, common, llvm_util, LlvmCodegenBackend , ModuleLlvm } ;
88
87
88
+ use tracing:: trace;
89
+
89
90
pub fn llvm_err < ' a > ( dcx : DiagCtxtHandle < ' _ > , err : LlvmError < ' a > ) -> FatalError {
90
91
match llvm:: last_error ( ) {
91
92
Some ( llvm_err) => dcx. emit_almost_fatal ( WithLlvmError ( err, llvm_err) ) ,
@@ -723,6 +724,7 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
723
724
724
725
unsafe fn create_call < ' a > ( tgt : & ' a Value , src : & ' a Value , rev_mode : bool ,
725
726
llmod : & ' a llvm:: Module , llcx : & llvm:: Context , size_positions : & [ usize ] , ad : & [ AutoDiff ] ) {
727
+ unsafe {
726
728
727
729
// first, remove all calls from fnc
728
730
let bb = LLVMGetFirstBasicBlock ( tgt) ;
@@ -890,14 +892,15 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
890
892
LLVMDisposeBuilder ( builder) ;
891
893
let _fnc_ok =
892
894
LLVMVerifyFunction ( tgt, llvm:: LLVMVerifierFailureAction :: LLVMAbortProcessAction ) ;
895
+ }
893
896
}
894
897
895
898
unsafe fn get_panic_name ( llmod : & llvm:: Module ) -> CString {
896
899
// The names are mangled and their ending changes based on a hash, so just take whichever.
897
- let mut f = LLVMGetFirstFunction ( llmod) ;
900
+ let mut f = unsafe { LLVMGetFirstFunction ( llmod) } ;
898
901
loop {
899
902
if let Some ( lf) = f {
900
- f = LLVMGetNextFunction ( lf) ;
903
+ f = unsafe { LLVMGetNextFunction ( lf) } ;
901
904
let fnc_name = llvm:: get_value_name ( lf) ;
902
905
let fnc_name: String = String :: from_utf8 ( fnc_name. to_vec ( ) ) . unwrap ( ) ;
903
906
if fnc_name. starts_with ( "_ZN4core9panicking14panic_explicit" ) {
@@ -919,6 +922,7 @@ unsafe fn get_panic_name(llmod: &llvm::Module) -> CString {
919
922
// TODO: Pick a panic function which allows displaying an errormessage.
920
923
// TODO: We probably want to keep a handle at higher level and pass it down instead of searching.
921
924
unsafe fn add_panic_msg_to_global < ' a > ( llmod : & ' a llvm:: Module , llcx : & ' a llvm:: Context ) -> & ' a llvm:: Value {
925
+ unsafe {
922
926
use llvm:: * ;
923
927
924
928
// Convert the message to a CString
@@ -957,7 +961,9 @@ unsafe fn add_panic_msg_to_global<'a>(llmod: &'a llvm::Module, llcx: &'a llvm::C
957
961
LLVMSetInitializer ( global_var, struct_initializer) ;
958
962
959
963
global_var
964
+ }
960
965
}
966
+ use rustc_errors:: DiagCtxt ;
961
967
962
968
// As unsafe as it can be.
963
969
#[ allow( unused_variables) ]
@@ -980,12 +986,12 @@ pub(crate) unsafe fn enzyme_ad(
980
986
// get target and source function
981
987
let name = CString :: new ( rust_name. to_owned ( ) ) . unwrap ( ) ;
982
988
let name2 = CString :: new ( rust_name2. clone ( ) ) . unwrap ( ) ;
983
- let src_fnc_opt = llvm:: LLVMGetNamedFunction ( llmod, name. as_c_str ( ) . as_ptr ( ) ) ;
989
+ let src_fnc_opt = unsafe { llvm:: LLVMGetNamedFunction ( llmod, name. as_c_str ( ) . as_ptr ( ) ) } ;
984
990
let src_fnc = match src_fnc_opt {
985
991
Some ( x) => x,
986
992
None => {
987
993
return Err ( llvm_err (
988
- diag_handler,
994
+ diag_handler. handle ( ) ,
989
995
LlvmError :: PrepareAutoDiff {
990
996
src : rust_name. to_owned ( ) ,
991
997
target : rust_name2. to_owned ( ) ,
@@ -994,12 +1000,12 @@ pub(crate) unsafe fn enzyme_ad(
994
1000
) ) ;
995
1001
}
996
1002
} ;
997
- let target_fnc_opt = llvm:: LLVMGetNamedFunction ( llmod, name2. as_ptr ( ) ) ;
1003
+ let target_fnc_opt = unsafe { llvm:: LLVMGetNamedFunction ( llmod, name2. as_ptr ( ) ) } ;
998
1004
let target_fnc = match target_fnc_opt {
999
1005
Some ( x) => x,
1000
1006
None => {
1001
1007
return Err ( llvm_err (
1002
- diag_handler,
1008
+ diag_handler. handle ( ) ,
1003
1009
LlvmError :: PrepareAutoDiff {
1004
1010
src : rust_name. to_owned ( ) ,
1005
1011
target : rust_name2. to_owned ( ) ,
@@ -1008,8 +1014,8 @@ pub(crate) unsafe fn enzyme_ad(
1008
1014
) ) ;
1009
1015
}
1010
1016
} ;
1011
- let src_num_args = llvm:: LLVMCountParams ( src_fnc) ;
1012
- let target_num_args = llvm:: LLVMCountParams ( target_fnc) ;
1017
+ let src_num_args = unsafe { llvm:: LLVMCountParams ( src_fnc) } ;
1018
+ let target_num_args = unsafe { llvm:: LLVMCountParams ( target_fnc) } ;
1013
1019
// A really simple check
1014
1020
assert ! ( src_num_args <= target_num_args) ;
1015
1021
@@ -1024,7 +1030,7 @@ pub(crate) unsafe fn enzyme_ad(
1024
1030
let output_tt = to_enzyme_typetree ( item. output , llvm_data_layout, llcx) ;
1025
1031
1026
1032
let type_analysis: EnzymeTypeAnalysisRef =
1027
- CreateTypeAnalysis ( logic_ref, std:: ptr:: null_mut ( ) , std:: ptr:: null_mut ( ) , 0 ) ;
1033
+ unsafe { CreateTypeAnalysis ( logic_ref, std:: ptr:: null_mut ( ) , std:: ptr:: null_mut ( ) , 0 ) } ;
1028
1034
1029
1035
llvm:: set_strict_aliasing ( false ) ;
1030
1036
@@ -1049,40 +1055,42 @@ pub(crate) unsafe fn enzyme_ad(
1049
1055
_ => unreachable ! ( ) ,
1050
1056
} ;
1051
1057
1052
- let void_type = LLVMVoidTypeInContext ( llcx) ;
1053
- let return_type = LLVMGetReturnType ( LLVMGlobalGetValueType ( src_fnc) ) ;
1054
- let void_ret = void_type == return_type;
1055
- let mut tmp = match mode {
1056
- DiffMode :: Forward => enzyme_rust_forward_diff (
1057
- logic_ref,
1058
- type_analysis,
1059
- src_fnc,
1060
- args_activity,
1061
- ret_activity,
1062
- input_tts,
1063
- output_tt,
1064
- void_ret,
1065
- ) ,
1066
- DiffMode :: Reverse => enzyme_rust_reverse_diff (
1067
- logic_ref,
1068
- type_analysis,
1069
- src_fnc,
1070
- args_activity,
1071
- ret_activity,
1072
- input_tts,
1073
- output_tt,
1074
- ) ,
1075
- _ => unreachable ! ( ) ,
1076
- } ;
1077
- let mut res: & Value = tmp. 0 ;
1078
- let size_positions: Vec < usize > = tmp. 1 ;
1058
+ unsafe {
1059
+ let void_type = LLVMVoidTypeInContext ( llcx) ;
1060
+ let return_type = LLVMGetReturnType ( LLVMGlobalGetValueType ( src_fnc) ) ;
1061
+ let void_ret = void_type == return_type;
1062
+ let mut tmp = match mode {
1063
+ DiffMode :: Forward => enzyme_rust_forward_diff (
1064
+ logic_ref,
1065
+ type_analysis,
1066
+ src_fnc,
1067
+ args_activity,
1068
+ ret_activity,
1069
+ input_tts,
1070
+ output_tt,
1071
+ void_ret,
1072
+ ) ,
1073
+ DiffMode :: Reverse => enzyme_rust_reverse_diff (
1074
+ logic_ref,
1075
+ type_analysis,
1076
+ src_fnc,
1077
+ args_activity,
1078
+ ret_activity,
1079
+ input_tts,
1080
+ output_tt,
1081
+ ) ,
1082
+ _ => unreachable ! ( ) ,
1083
+ } ;
1084
+ let mut res: & Value = tmp. 0 ;
1085
+ let size_positions: Vec < usize > = tmp. 1 ;
1079
1086
1080
- let f_return_type = LLVMGetReturnType ( LLVMGlobalGetValueType ( res) ) ;
1087
+ let f_return_type = LLVMGetReturnType ( LLVMGlobalGetValueType ( res) ) ;
1081
1088
1082
- let rev_mode = item. attrs . mode == DiffMode :: Reverse ;
1083
- create_call ( target_fnc, res, rev_mode, llmod, llcx, & size_positions, ad) ;
1084
- // TODO: implement drop for wrapper type?
1085
- FreeTypeAnalysis ( type_analysis) ;
1089
+ let rev_mode = item. attrs . mode == DiffMode :: Reverse ;
1090
+ create_call ( target_fnc, res, rev_mode, llmod, llcx, & size_positions, ad) ;
1091
+ // TODO: implement drop for wrapper type?
1092
+ FreeTypeAnalysis ( type_analysis) ;
1093
+ }
1086
1094
1087
1095
Ok ( ( ) )
1088
1096
}
@@ -1122,7 +1130,7 @@ pub(crate) unsafe fn differentiate(
1122
1130
ret : item. output . clone ( ) ,
1123
1131
} ;
1124
1132
let name = CString :: new ( item. source . clone ( ) ) . unwrap ( ) ;
1125
- let fn_def: & llvm:: Value = llvm:: LLVMGetNamedFunction ( llmod, name. as_ptr ( ) ) . unwrap ( ) ;
1133
+ let fn_def: & llvm:: Value = unsafe { llvm:: LLVMGetNamedFunction ( llmod, name. as_ptr ( ) ) . unwrap ( ) } ;
1126
1134
crate :: builder:: add_tt2 ( llmod, llcx, fn_def, tt) ;
1127
1135
1128
1136
// Before dumping the module, we also might want to add dummy functions, which will
@@ -1182,10 +1190,10 @@ pub(crate) unsafe fn differentiate(
1182
1190
1183
1191
// If a function is a base for some higher order ad, always optimize
1184
1192
let fnc_opt_base = true ;
1185
- let logic_ref_opt: EnzymeLogicRef = CreateEnzymeLogic ( fnc_opt_base as u8 ) ;
1193
+ let logic_ref_opt: EnzymeLogicRef = unsafe { CreateEnzymeLogic ( fnc_opt_base as u8 ) } ;
1186
1194
1187
1195
for item in first_order_items {
1188
- let res = enzyme_ad ( llmod, llcx, & diag_handler, item, logic_ref_opt, ad) ;
1196
+ let res = unsafe { enzyme_ad ( llmod, llcx, & diag_handler. handle ( ) , item, logic_ref_opt, ad) } ;
1189
1197
assert ! ( res. is_ok( ) ) ;
1190
1198
}
1191
1199
@@ -1196,44 +1204,44 @@ pub(crate) unsafe fn differentiate(
1196
1204
dbg ! ( "Enable extra optimizations for Enzyme" ) ;
1197
1205
logic_ref_opt
1198
1206
}
1199
- false => CreateEnzymeLogic ( fnc_opt as u8 ) ,
1207
+ false => unsafe { CreateEnzymeLogic ( fnc_opt as u8 ) } ,
1200
1208
} ;
1201
1209
for item in higher_order_items {
1202
- let res = enzyme_ad ( llmod, llcx, & diag_handler, item, logic_ref, ad) ;
1210
+ let res = unsafe { enzyme_ad ( llmod, llcx, & diag_handler. handle ( ) , item, logic_ref, ad) } ;
1203
1211
assert ! ( res. is_ok( ) ) ;
1204
1212
}
1205
1213
1206
- let mut f = LLVMGetFirstFunction ( llmod) ;
1207
- loop {
1208
- if let Some ( lf) = f {
1209
- f = LLVMGetNextFunction ( lf) ;
1210
- let myhwattr = "enzyme_hw" ;
1211
- let attr = LLVMGetStringAttributeAtIndex (
1212
- lf,
1213
- c_uint:: MAX ,
1214
- myhwattr. as_ptr ( ) as * const c_char ,
1215
- myhwattr. as_bytes ( ) . len ( ) as c_uint ,
1216
- ) ;
1217
- if LLVMIsStringAttribute ( attr) {
1218
- LLVMRemoveStringAttributeAtIndex (
1214
+ unsafe {
1215
+ let mut f = LLVMGetFirstFunction ( llmod) ;
1216
+ loop {
1217
+ if let Some ( lf) = f {
1218
+ f = LLVMGetNextFunction ( lf) ;
1219
+ let myhwattr = "enzyme_hw" ;
1220
+ let attr = LLVMGetStringAttributeAtIndex (
1219
1221
lf,
1220
1222
c_uint:: MAX ,
1221
1223
myhwattr. as_ptr ( ) as * const c_char ,
1222
1224
myhwattr. as_bytes ( ) . len ( ) as c_uint ,
1223
1225
) ;
1226
+ if LLVMIsStringAttribute ( attr) {
1227
+ LLVMRemoveStringAttributeAtIndex (
1228
+ lf,
1229
+ c_uint:: MAX ,
1230
+ myhwattr. as_ptr ( ) as * const c_char ,
1231
+ myhwattr. as_bytes ( ) . len ( ) as c_uint ,
1232
+ ) ;
1233
+ } else {
1234
+ LLVMRustRemoveEnumAttributeAtIndex (
1235
+ lf,
1236
+ c_uint:: MAX ,
1237
+ AttributeKind :: SanitizeHWAddress ,
1238
+ ) ;
1239
+ }
1224
1240
} else {
1225
- LLVMRustRemoveEnumAttributeAtIndex (
1226
- lf,
1227
- c_uint:: MAX ,
1228
- AttributeKind :: SanitizeHWAddress ,
1229
- ) ;
1241
+ break ;
1230
1242
}
1231
- } else {
1232
- break ;
1233
1243
}
1234
- }
1235
- if ad. contains ( & AutoDiff :: PrintModAfterEnzyme ) {
1236
- unsafe {
1244
+ if ad. contains ( & AutoDiff :: PrintModAfterEnzyme ) {
1237
1245
LLVMDumpModule ( llmod) ;
1238
1246
}
1239
1247
}
@@ -1260,7 +1268,7 @@ pub(crate) unsafe fn differentiate(
1260
1268
first_run = true ;
1261
1269
}
1262
1270
let noop = false ;
1263
- llvm_optimize ( cgcx, & diag_handler, module, config, opt_level, opt_stage, first_run, noop) ?;
1271
+ unsafe { llvm_optimize ( cgcx, diag_handler. handle ( ) , module, config, opt_level, opt_stage, first_run, noop) ?} ;
1264
1272
}
1265
1273
if ad. contains ( & AutoDiff :: AltPipeline ) {
1266
1274
dbg ! ( "Running Second postAD optimization" ) ;
@@ -1278,7 +1286,7 @@ pub(crate) unsafe fn differentiate(
1278
1286
first_run = false ;
1279
1287
}
1280
1288
let noop = false ;
1281
- llvm_optimize ( cgcx, & diag_handler, module, config, opt_level, opt_stage, first_run, noop) ?;
1289
+ unsafe { llvm_optimize ( cgcx, diag_handler. handle ( ) , module, config, opt_level, opt_stage, first_run, noop) ?} ;
1282
1290
}
1283
1291
}
1284
1292
}
@@ -1320,7 +1328,7 @@ pub(crate) unsafe fn optimize(
1320
1328
// different code sections. We remove this attribute after Enzyme is done, to not affect the
1321
1329
// rest of the compilation.
1322
1330
// TODO: only enable this code when at least one function gets differentiated.
1323
- {
1331
+ unsafe {
1324
1332
let mut f = LLVMGetFirstFunction ( llmod) ;
1325
1333
loop {
1326
1334
if let Some ( lf) = f {
0 commit comments