Skip to content

Commit c415723

Browse files
committed
add unsafe in unsafe fixes
1 parent 4082272 commit c415723

File tree

4 files changed

+91
-82
lines changed

4 files changed

+91
-82
lines changed

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 82 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,13 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode};
4040
use rustc_ast::expand::typetree::FncTree;
4141
use rustc_data_structures::fx::FxHashMap;
4242

43-
4443
use std::ffi::{CStr, CString};
4544
use std::io::{self, Write};
4645
use std::path::{Path, PathBuf};
4746
use std::sync::Arc;
4847
use std::{fs, slice, str};
4948

50-
use libc::{c_char, c_int, c_void, size_t};
49+
use libc::{c_char, c_int, c_uint, c_void, size_t};
5150
use llvm::{
5251
LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols,
5352
};
@@ -86,6 +85,8 @@ use crate::llvm::{self, DiagnosticInfo, PassManager};
8685
use crate::type_::Type;
8786
use crate::{base, common, llvm_util, LlvmCodegenBackend, ModuleLlvm};
8887

88+
use tracing::trace;
89+
8990
pub fn llvm_err<'a>(dcx: DiagCtxtHandle<'_>, err: LlvmError<'a>) -> FatalError {
9091
match llvm::last_error() {
9192
Some(llvm_err) => dcx.emit_almost_fatal(WithLlvmError(err, llvm_err)),
@@ -723,6 +724,7 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
723724

724725
unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
725726
llmod: &'a llvm::Module, llcx: &llvm::Context, size_positions: &[usize], ad: &[AutoDiff]) {
727+
unsafe {
726728

727729
// first, remove all calls from fnc
728730
let bb = LLVMGetFirstBasicBlock(tgt);
@@ -890,14 +892,15 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
890892
LLVMDisposeBuilder(builder);
891893
let _fnc_ok =
892894
LLVMVerifyFunction(tgt, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction);
895+
}
893896
}
894897

895898
unsafe fn get_panic_name(llmod: &llvm::Module) -> CString {
896899
// The names are mangled and their ending changes based on a hash, so just take whichever.
897-
let mut f = LLVMGetFirstFunction(llmod);
900+
let mut f = unsafe{LLVMGetFirstFunction(llmod)};
898901
loop {
899902
if let Some(lf) = f {
900-
f = LLVMGetNextFunction(lf);
903+
f = unsafe {LLVMGetNextFunction(lf)};
901904
let fnc_name = llvm::get_value_name(lf);
902905
let fnc_name: String = String::from_utf8(fnc_name.to_vec()).unwrap();
903906
if fnc_name.starts_with("_ZN4core9panicking14panic_explicit") {
@@ -919,6 +922,7 @@ unsafe fn get_panic_name(llmod: &llvm::Module) -> CString {
919922
// TODO: Pick a panic function which allows displaying an errormessage.
920923
// TODO: We probably want to keep a handle at higher level and pass it down instead of searching.
921924
unsafe fn add_panic_msg_to_global<'a>(llmod: &'a llvm::Module, llcx: &'a llvm::Context) -> &'a llvm::Value {
925+
unsafe {
922926
use llvm::*;
923927

924928
// Convert the message to a CString
@@ -957,7 +961,9 @@ unsafe fn add_panic_msg_to_global<'a>(llmod: &'a llvm::Module, llcx: &'a llvm::C
957961
LLVMSetInitializer(global_var, struct_initializer);
958962

959963
global_var
964+
}
960965
}
966+
use rustc_errors::DiagCtxt;
961967

962968
// As unsafe as it can be.
963969
#[allow(unused_variables)]
@@ -980,12 +986,12 @@ pub(crate) unsafe fn enzyme_ad(
980986
// get target and source function
981987
let name = CString::new(rust_name.to_owned()).unwrap();
982988
let name2 = CString::new(rust_name2.clone()).unwrap();
983-
let src_fnc_opt = llvm::LLVMGetNamedFunction(llmod, name.as_c_str().as_ptr());
989+
let src_fnc_opt = unsafe {llvm::LLVMGetNamedFunction(llmod, name.as_c_str().as_ptr())};
984990
let src_fnc = match src_fnc_opt {
985991
Some(x) => x,
986992
None => {
987993
return Err(llvm_err(
988-
diag_handler,
994+
diag_handler.handle(),
989995
LlvmError::PrepareAutoDiff {
990996
src: rust_name.to_owned(),
991997
target: rust_name2.to_owned(),
@@ -994,12 +1000,12 @@ pub(crate) unsafe fn enzyme_ad(
9941000
));
9951001
}
9961002
};
997-
let target_fnc_opt = llvm::LLVMGetNamedFunction(llmod, name2.as_ptr());
1003+
let target_fnc_opt = unsafe {llvm::LLVMGetNamedFunction(llmod, name2.as_ptr())};
9981004
let target_fnc = match target_fnc_opt {
9991005
Some(x) => x,
10001006
None => {
10011007
return Err(llvm_err(
1002-
diag_handler,
1008+
diag_handler.handle(),
10031009
LlvmError::PrepareAutoDiff {
10041010
src: rust_name.to_owned(),
10051011
target: rust_name2.to_owned(),
@@ -1008,8 +1014,8 @@ pub(crate) unsafe fn enzyme_ad(
10081014
));
10091015
}
10101016
};
1011-
let src_num_args = llvm::LLVMCountParams(src_fnc);
1012-
let target_num_args = llvm::LLVMCountParams(target_fnc);
1017+
let src_num_args = unsafe {llvm::LLVMCountParams(src_fnc)};
1018+
let target_num_args = unsafe {llvm::LLVMCountParams(target_fnc)};
10131019
// A really simple check
10141020
assert!(src_num_args <= target_num_args);
10151021

@@ -1024,7 +1030,7 @@ pub(crate) unsafe fn enzyme_ad(
10241030
let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx);
10251031

10261032
let type_analysis: EnzymeTypeAnalysisRef =
1027-
CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0);
1033+
unsafe {CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0)};
10281034

10291035
llvm::set_strict_aliasing(false);
10301036

@@ -1049,40 +1055,42 @@ pub(crate) unsafe fn enzyme_ad(
10491055
_ => unreachable!(),
10501056
};
10511057

1052-
let void_type = LLVMVoidTypeInContext(llcx);
1053-
let return_type = LLVMGetReturnType(LLVMGlobalGetValueType(src_fnc));
1054-
let void_ret = void_type == return_type;
1055-
let mut tmp = match mode {
1056-
DiffMode::Forward => enzyme_rust_forward_diff(
1057-
logic_ref,
1058-
type_analysis,
1059-
src_fnc,
1060-
args_activity,
1061-
ret_activity,
1062-
input_tts,
1063-
output_tt,
1064-
void_ret,
1065-
),
1066-
DiffMode::Reverse => enzyme_rust_reverse_diff(
1067-
logic_ref,
1068-
type_analysis,
1069-
src_fnc,
1070-
args_activity,
1071-
ret_activity,
1072-
input_tts,
1073-
output_tt,
1074-
),
1075-
_ => unreachable!(),
1076-
};
1077-
let mut res: &Value = tmp.0;
1078-
let size_positions: Vec<usize> = tmp.1;
1058+
unsafe {
1059+
let void_type = LLVMVoidTypeInContext(llcx);
1060+
let return_type = LLVMGetReturnType(LLVMGlobalGetValueType(src_fnc));
1061+
let void_ret = void_type == return_type;
1062+
let mut tmp = match mode {
1063+
DiffMode::Forward => enzyme_rust_forward_diff(
1064+
logic_ref,
1065+
type_analysis,
1066+
src_fnc,
1067+
args_activity,
1068+
ret_activity,
1069+
input_tts,
1070+
output_tt,
1071+
void_ret,
1072+
),
1073+
DiffMode::Reverse => enzyme_rust_reverse_diff(
1074+
logic_ref,
1075+
type_analysis,
1076+
src_fnc,
1077+
args_activity,
1078+
ret_activity,
1079+
input_tts,
1080+
output_tt,
1081+
),
1082+
_ => unreachable!(),
1083+
};
1084+
let mut res: &Value = tmp.0;
1085+
let size_positions: Vec<usize> = tmp.1;
10791086

1080-
let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(res));
1087+
let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(res));
10811088

1082-
let rev_mode = item.attrs.mode == DiffMode::Reverse;
1083-
create_call(target_fnc, res, rev_mode, llmod, llcx, &size_positions, ad);
1084-
// TODO: implement drop for wrapper type?
1085-
FreeTypeAnalysis(type_analysis);
1089+
let rev_mode = item.attrs.mode == DiffMode::Reverse;
1090+
create_call(target_fnc, res, rev_mode, llmod, llcx, &size_positions, ad);
1091+
// TODO: implement drop for wrapper type?
1092+
FreeTypeAnalysis(type_analysis);
1093+
}
10861094

10871095
Ok(())
10881096
}
@@ -1122,7 +1130,7 @@ pub(crate) unsafe fn differentiate(
11221130
ret: item.output.clone(),
11231131
};
11241132
let name = CString::new(item.source.clone()).unwrap();
1125-
let fn_def: &llvm::Value = llvm::LLVMGetNamedFunction(llmod, name.as_ptr()).unwrap();
1133+
let fn_def: &llvm::Value = unsafe {llvm::LLVMGetNamedFunction(llmod, name.as_ptr()).unwrap()};
11261134
crate::builder::add_tt2(llmod, llcx, fn_def, tt);
11271135

11281136
// Before dumping the module, we also might want to add dummy functions, which will
@@ -1182,10 +1190,10 @@ pub(crate) unsafe fn differentiate(
11821190

11831191
// If a function is a base for some higher order ad, always optimize
11841192
let fnc_opt_base = true;
1185-
let logic_ref_opt: EnzymeLogicRef = CreateEnzymeLogic(fnc_opt_base as u8);
1193+
let logic_ref_opt: EnzymeLogicRef = unsafe {CreateEnzymeLogic(fnc_opt_base as u8)};
11861194

11871195
for item in first_order_items {
1188-
let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref_opt, ad);
1196+
let res = unsafe{enzyme_ad(llmod, llcx, &diag_handler.handle(), item, logic_ref_opt, ad)};
11891197
assert!(res.is_ok());
11901198
}
11911199

@@ -1196,44 +1204,44 @@ pub(crate) unsafe fn differentiate(
11961204
dbg!("Enable extra optimizations for Enzyme");
11971205
logic_ref_opt
11981206
}
1199-
false => CreateEnzymeLogic(fnc_opt as u8),
1207+
false => unsafe {CreateEnzymeLogic(fnc_opt as u8)},
12001208
};
12011209
for item in higher_order_items {
1202-
let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref, ad);
1210+
let res = unsafe{ enzyme_ad(llmod, llcx, &diag_handler.handle(), item, logic_ref, ad)};
12031211
assert!(res.is_ok());
12041212
}
12051213

1206-
let mut f = LLVMGetFirstFunction(llmod);
1207-
loop {
1208-
if let Some(lf) = f {
1209-
f = LLVMGetNextFunction(lf);
1210-
let myhwattr = "enzyme_hw";
1211-
let attr = LLVMGetStringAttributeAtIndex(
1212-
lf,
1213-
c_uint::MAX,
1214-
myhwattr.as_ptr() as *const c_char,
1215-
myhwattr.as_bytes().len() as c_uint,
1216-
);
1217-
if LLVMIsStringAttribute(attr) {
1218-
LLVMRemoveStringAttributeAtIndex(
1214+
unsafe {
1215+
let mut f = LLVMGetFirstFunction(llmod);
1216+
loop {
1217+
if let Some(lf) = f {
1218+
f = LLVMGetNextFunction(lf);
1219+
let myhwattr = "enzyme_hw";
1220+
let attr = LLVMGetStringAttributeAtIndex(
12191221
lf,
12201222
c_uint::MAX,
12211223
myhwattr.as_ptr() as *const c_char,
12221224
myhwattr.as_bytes().len() as c_uint,
12231225
);
1226+
if LLVMIsStringAttribute(attr) {
1227+
LLVMRemoveStringAttributeAtIndex(
1228+
lf,
1229+
c_uint::MAX,
1230+
myhwattr.as_ptr() as *const c_char,
1231+
myhwattr.as_bytes().len() as c_uint,
1232+
);
1233+
} else {
1234+
LLVMRustRemoveEnumAttributeAtIndex(
1235+
lf,
1236+
c_uint::MAX,
1237+
AttributeKind::SanitizeHWAddress,
1238+
);
1239+
}
12241240
} else {
1225-
LLVMRustRemoveEnumAttributeAtIndex(
1226-
lf,
1227-
c_uint::MAX,
1228-
AttributeKind::SanitizeHWAddress,
1229-
);
1241+
break;
12301242
}
1231-
} else {
1232-
break;
12331243
}
1234-
}
1235-
if ad.contains(&AutoDiff::PrintModAfterEnzyme) {
1236-
unsafe {
1244+
if ad.contains(&AutoDiff::PrintModAfterEnzyme) {
12371245
LLVMDumpModule(llmod);
12381246
}
12391247
}
@@ -1260,7 +1268,7 @@ pub(crate) unsafe fn differentiate(
12601268
first_run = true;
12611269
}
12621270
let noop = false;
1263-
llvm_optimize(cgcx, &diag_handler, module, config, opt_level, opt_stage, first_run, noop)?;
1271+
unsafe {llvm_optimize(cgcx, diag_handler.handle(), module, config, opt_level, opt_stage, first_run, noop)?};
12641272
}
12651273
if ad.contains(&AutoDiff::AltPipeline) {
12661274
dbg!("Running Second postAD optimization");
@@ -1278,7 +1286,7 @@ pub(crate) unsafe fn differentiate(
12781286
first_run = false;
12791287
}
12801288
let noop = false;
1281-
llvm_optimize(cgcx, &diag_handler, module, config, opt_level, opt_stage, first_run, noop)?;
1289+
unsafe {llvm_optimize(cgcx, diag_handler.handle(), module, config, opt_level, opt_stage, first_run, noop)?};
12821290
}
12831291
}
12841292
}
@@ -1320,7 +1328,7 @@ pub(crate) unsafe fn optimize(
13201328
// different code sections. We remove this attribute after Enzyme is done, to not affect the
13211329
// rest of the compilation.
13221330
// TODO: only enable this code when at least one function gets differentiated.
1323-
{
1331+
unsafe {
13241332
let mut f = LLVMGetFirstFunction(llmod);
13251333
loop {
13261334
if let Some(lf) = f {

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ use crate::type_of::LayoutLlvmExt;
3939
use crate::value::Value;
4040
use crate::{attributes, llvm_util};
4141

42+
use tracing::trace;
4243

4344
pub fn add_tt2<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def: &'ll Value, tt: FncTree) {
4445
let inputs = tt.args;

compiler/rustc_codegen_llvm/src/lib.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ use std::mem::ManuallyDrop;
3030

3131
use back::owned_target_machine::OwnedTargetMachine;
3232
use back::write::{create_informational_target_machine, create_target_machine};
33-
use errors::ParseTargetMachineConfig;
3433
pub use llvm_util::target_features;
3534
use rustc_ast::expand::allocator::AllocatorKind;
3635
use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule};
@@ -265,7 +264,7 @@ impl WriteBackendMethods for LlvmCodegenBackend {
265264
) -> Result<(), FatalError> {
266265
if cgcx.lto != Lto::Fat {
267266
let dcx = cgcx.create_dcx();
268-
return Err(dcx.emit_almost_fatal(AutoDiffWithoutLTO{}));
267+
return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutLTO{}));
269268
}
270269
unsafe { back::write::differentiate(module, cgcx, diff_fncs, typetrees, config) }
271270
}

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#![allow(non_camel_case_types)]
22
#![allow(non_upper_case_globals)]
33

4+
use tracing::trace;
45
use std::marker::PhantomData;
56

67
use libc::{c_char, c_int, c_uint, c_ulonglong, c_void, size_t};
@@ -887,7 +888,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
887888
// We don't support volatile / extern / (global?) values.
888889
// Just because I didn't had time to test them, and it seems less urgent.
889890
let args_uncacheable = vec![0; input_activity.len()];
890-
let num_fnc_args = LLVMCountParams(fnc);
891+
let num_fnc_args = unsafe{LLVMCountParams(fnc)};
891892
trace!("num_fnc_args: {}", num_fnc_args);
892893
trace!("input_activity.len(): {}", input_activity.len());
893894
assert!(num_fnc_args == input_activity.len() as u32);
@@ -914,7 +915,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
914915
trace!("input_activity i: {}", &i);
915916
}
916917
trace!("before calling Enzyme");
917-
let res = EnzymeCreateForwardDiff(
918+
let res = unsafe {EnzymeCreateForwardDiff(
918919
logic_ref, // Logic
919920
std::ptr::null(),
920921
std::ptr::null(),
@@ -932,7 +933,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
932933
args_uncacheable.as_ptr(),
933934
args_uncacheable.len(), // uncacheable arguments
934935
std::ptr::null_mut(), // write augmented function to this
935-
);
936+
)};
936937
trace!("after calling Enzyme");
937938
(res, vec![])
938939
}
@@ -975,7 +976,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
975976
// We don't support volatile / extern / (global?) values.
976977
// Just because I didn't had time to test them, and it seems less urgent.
977978
let args_uncacheable = vec![0; input_activity.len()];
978-
let num_fnc_args = LLVMCountParams(fnc);
979+
let num_fnc_args = unsafe {LLVMCountParams(fnc)};
979980
println!("num_fnc_args: {}", num_fnc_args);
980981
println!("input_activity.len(): {}", input_activity.len());
981982
assert!(num_fnc_args == input_activity.len() as u32);
@@ -1001,7 +1002,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
10011002
trace!("input_activity i: {}", &i);
10021003
}
10031004
trace!("before calling Enzyme");
1004-
let res = EnzymeCreatePrimalAndGradient(
1005+
let res = unsafe {EnzymeCreatePrimalAndGradient(
10051006
logic_ref, // Logic
10061007
std::ptr::null(),
10071008
std::ptr::null(),
@@ -1022,7 +1023,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
10221023
args_uncacheable.len(), // uncacheable arguments
10231024
std::ptr::null_mut(), // write augmented function to this
10241025
0,
1025-
);
1026+
)};
10261027
trace!("after calling Enzyme");
10271028
(res, primal_sizes)
10281029
}

0 commit comments

Comments
 (0)