Skip to content

Commit 84968bf

Browse files
committed
fix monomorphization, extra dbg output
1 parent de61995 commit 84968bf

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,9 @@ pub(crate) unsafe fn enzyme_ad(
712712
let name2 = CString::new(rust_name2.clone()).unwrap();
713713
let src_fnc = llvm::LLVMGetNamedFunction(llmod, name.as_c_str().as_ptr()).unwrap();
714714
let target_fnc = llvm::LLVMGetNamedFunction(llmod, name2.as_ptr()).unwrap();
715+
let src_num_args = llvm::LLVMCountParams(src_fnc);
716+
let target_num_args = llvm::LLVMCountParams(target_fnc);
717+
assert!(src_num_args <= target_num_args);
715718

716719
// create enzyme typetrees
717720
let llvm_data_layout = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,8 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
914914
assert!(ret_activity == CDIFFE_TYPE::DFT_CONSTANT || ret_activity == CDIFFE_TYPE::DFT_OUT_DIFF);
915915
let input_activity: Vec<CDIFFE_TYPE> = input_activity.iter().map(|&x| cdiffe_from(x)).collect();
916916

917+
dbg!(&fnc);
918+
917919
if ret_activity == CDIFFE_TYPE::DFT_DUP_ARG {
918920
if ret_primary_ret != true {
919921
dbg!("overwriting ret_primary_ret!");
@@ -931,6 +933,11 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
931933
// We don't support volatile / extern / (global?) values.
932934
// Just because I didn't had time to test them, and it seems less urgent.
933935
let args_uncacheable = vec![0; input_tts.len()];
936+
assert!(args_uncacheable.len() == input_activity.len());
937+
let num_fnc_args = LLVMCountParams(fnc);
938+
println!("num_fnc_args: {}", num_fnc_args);
939+
println!("input_activity.len(): {}", input_activity.len());
940+
assert!(num_fnc_args == input_activity.len() as u32);
934941
let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 };
935942

936943

@@ -942,7 +949,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
942949
KnownValues: known_values.as_mut_ptr(),
943950
};
944951

945-
EnzymeCreatePrimalAndGradient(
952+
let res = EnzymeCreatePrimalAndGradient(
946953
logic_ref, // Logic
947954
std::ptr::null(),
948955
std::ptr::null(),
@@ -963,7 +970,9 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
963970
args_uncacheable.len(), // uncacheable arguments
964971
std::ptr::null_mut(), // write augmented function to this
965972
0,
966-
)
973+
);
974+
dbg!(&res);
975+
res
967976
}
968977
pub type GetSymbolsCallback = unsafe extern "C" fn(*mut c_void, *const c_char) -> *mut c_void;
969978
pub type GetSymbolsErrorCallback = unsafe extern "C" fn(*const c_char) -> *mut c_void;

compiler/rustc_monomorphize/src/partitioning.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,14 @@ where
247247
&mut can_be_internalized,
248248
export_generics,
249249
);
250-
if visibility == Visibility::Hidden && can_be_internalized {
250+
//if visibility == Visibility::Hidden && can_be_internalized {
251+
252+
//dbg!(&characteristic_def_id);
253+
let autodiff_active = characteristic_def_id
254+
.map(|x| cx.tcx.autodiff_attrs(x).is_active())
255+
.unwrap_or(false);
256+
257+
if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized {
251258
internalization_candidates.insert(mono_item);
252259
}
253260
let size_estimate = mono_item.size_estimate(cx.tcx);

0 commit comments

Comments
 (0)