Skip to content

Commit 123b0e8

Browse files
authored
simplify ffi wrappers (#169)
1 parent 1a36626 commit 123b0e8

File tree

3 files changed

+2
-39
lines changed

3 files changed

+2
-39
lines changed

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,16 +1019,6 @@ pub(crate) unsafe fn enzyme_ad(
10191019
// A really simple check
10201020
assert!(src_num_args <= target_num_args);
10211021

1022-
// create enzyme typetrees
1023-
let llvm_data_layout = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
1024-
let llvm_data_layout =
1025-
std::str::from_utf8(unsafe { CStr::from_ptr(llvm_data_layout) }.to_bytes())
1026-
.expect("got a non-UTF8 data-layout from LLVM");
1027-
1028-
let input_tts =
1029-
item.inputs.into_iter().map(|x| to_enzyme_typetree(x, llvm_data_layout, llcx)).collect();
1030-
let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx);
1031-
10321022
let type_analysis: EnzymeTypeAnalysisRef =
10331023
unsafe {CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0)};
10341024

@@ -1066,8 +1056,6 @@ pub(crate) unsafe fn enzyme_ad(
10661056
src_fnc,
10671057
args_activity,
10681058
ret_activity,
1069-
input_tts,
1070-
output_tt,
10711059
void_ret,
10721060
),
10731061
DiffMode::Reverse => enzyme_rust_reverse_diff(
@@ -1076,8 +1064,6 @@ pub(crate) unsafe fn enzyme_ad(
10761064
src_fnc,
10771065
args_activity,
10781066
ret_activity,
1079-
input_tts,
1080-
output_tt,
10811067
),
10821068
_ => unreachable!(),
10831069
};

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,6 @@ pub fn add_tt2<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def:
8787

8888
#[allow(unused)]
8989
pub fn add_opt_dbg_helper<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, val: &'ll Value, attrs: AutoDiffAttrs, i: usize) {
90-
//pub mode: DiffMode,
91-
//pub ret_activity: DiffActivity,
92-
//pub input_activity: Vec<DiffActivity>,
9390
let inputs = attrs.input_activity;
9491
let outputs = attrs.ret_activity;
9592
let ad_name = match attrs.mode {
@@ -136,7 +133,6 @@ pub fn add_opt_dbg_helper<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Contex
136133
let num_args = llvm::LLVMCountParams(wrapper_fn);
137134
let mut args = Vec::with_capacity(num_args as usize + 1);
138135
args.push(val);
139-
// metadata !"enzyme_const"
140136
let enzyme_const = llvm::LLVMMDStringInContext(llcx, "enzyme_const".as_ptr() as *const c_char, 12);
141137
let enzyme_out = llvm::LLVMMDStringInContext(llcx, "enzyme_out".as_ptr() as *const c_char, 10);
142138
let enzyme_dup = llvm::LLVMMDStringInContext(llcx, "enzyme_dup".as_ptr() as *const c_char, 10);

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -845,15 +845,12 @@ pub enum LLVMVerifierFailureAction {
845845
LLVMReturnStatusAction,
846846
}
847847

848-
#[allow(dead_code)]
849848
pub(crate) unsafe fn enzyme_rust_forward_diff(
850849
logic_ref: EnzymeLogicRef,
851850
type_analysis: EnzymeTypeAnalysisRef,
852851
fnc: &Value,
853852
input_diffactivity: Vec<DiffActivity>,
854853
ret_diffactivity: DiffActivity,
855-
_input_tts: Vec<TypeTree>,
856-
_output_tt: TypeTree,
857854
void_ret: bool,
858855
) -> (&Value, Vec<usize>) {
859856
let ret_activity = cdiffe_from(ret_diffactivity);
@@ -882,9 +879,6 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
882879
};
883880
trace!("ret_primary_ret: {}", &ret_primary_ret);
884881

885-
//let mut args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
886-
//let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()];
887-
888882
// We don't support volatile / extern / (global?) values.
889883
// Just because I didn't had time to test them, and it seems less urgent.
890884
let args_uncacheable = vec![0; input_activity.len()];
@@ -900,9 +894,6 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
900894
let tree_tmp = TypeTree::new();
901895
let mut args_tree = vec![tree_tmp.inner; input_activity.len()];
902896

903-
//let mut args_tree = vec![std::ptr::null_mut(); input_activity.len()];
904-
//let ret_tt = std::ptr::null_mut();
905-
//let mut args_tree = vec![TypeTree::new().inner; input_tts.len()];
906897
let ret_tt = TypeTree::new();
907898
let dummy_type = CFnTypeInfo {
908899
Arguments: args_tree.as_mut_ptr(),
@@ -944,8 +935,6 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
944935
fnc: &Value,
945936
rust_input_activity: Vec<DiffActivity>,
946937
ret_activity: DiffActivity,
947-
input_tts: Vec<TypeTree>,
948-
_output_tt: TypeTree,
949938
) -> (&Value, Vec<usize>) {
950939
let (primary_ret, ret_activity) = match ret_activity {
951940
DiffActivity::Const => (true, CDIFFE_TYPE::DFT_CONSTANT),
@@ -971,8 +960,6 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
971960
input_activity.push(cdiffe_from(x));
972961
}
973962

974-
//let args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
975-
976963
// We don't support volatile / extern / (global?) values.
977964
// Just because I didn't had time to test them, and it seems less urgent.
978965
let args_uncacheable = vec![0; input_activity.len()];
@@ -982,14 +969,11 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
982969
assert!(num_fnc_args == input_activity.len() as u32);
983970
let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 };
984971

985-
let mut known_values = vec![kv_tmp; input_tts.len()];
972+
let mut known_values = vec![kv_tmp; input_activity.len()];
986973

987974
let tree_tmp = TypeTree::new();
988-
let mut args_tree = vec![tree_tmp.inner; input_tts.len()];
989-
//let mut args_tree = vec![TypeTree::new().inner; input_tts.len()];
975+
let mut args_tree = vec![tree_tmp.inner; input_activity.len()];
990976
let ret_tt = TypeTree::new();
991-
//let mut args_tree = vec![std::ptr::null_mut(); input_tts.len()];
992-
//let ret_tt = std::ptr::null_mut();
993977
let dummy_type = CFnTypeInfo {
994978
Arguments: args_tree.as_mut_ptr(),
995979
Return: ret_tt.inner,
@@ -1029,9 +1013,6 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
10291013
}
10301014

10311015
extern "C" {
1032-
// TODO: can I just ignore the non void return
1033-
// EraseFromParent doesn't exist :(
1034-
//pub fn LLVMEraseFromParent(BB: &BasicBlock) -> &Value;
10351016
// Enzyme
10361017
pub fn LLVMRustAddFncParamAttr<'a>(
10371018
F: &'a Value,

0 commit comments

Comments
 (0)