Skip to content

Commit 4889ad3

Browse files
committed
fix mem leak, fix logic bug, cleanup
1 parent 293ab97 commit 4889ad3

File tree

2 files changed

+43
-47
lines changed

2 files changed

+43
-47
lines changed

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::errors::{
1111
};
1212
use crate::llvm::{self, DiagnosticInfo, PassManager};
1313
use crate::llvm::{
14-
enzyme_rust_forward_diff, enzyme_rust_reverse_diff, AttributeKind, BasicBlock,
14+
enzyme_rust_forward_diff, enzyme_rust_reverse_diff, AttributeKind, BasicBlock, FreeTypeAnalysis,
1515
CreateEnzymeLogic, CreateTypeAnalysis, EnzymeLogicRef, EnzymeTypeAnalysisRef, LLVMAddFunction,
1616
LLVMAppendBasicBlockInContext, LLVMBuildCall2, LLVMBuildExtractValue, LLVMBuildRet,
1717
LLVMCountParams, LLVMCountStructElementTypes, LLVMCreateBuilderInContext,
@@ -809,6 +809,8 @@ pub(crate) unsafe fn enzyme_ad(
809809
LLVMSetValueName2(res, name2.as_ptr(), rust_name2.len());
810810
LLVMReplaceAllUsesWith(target_fnc, res);
811811
LLVMDeleteFunction(target_fnc);
812+
// TODO: implement drop for wrapper type?
813+
FreeTypeAnalysis(type_analysis);
812814

813815
Ok(())
814816
}

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 40 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -925,13 +925,17 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
925925
input_tts: Vec<TypeTree>,
926926
output_tt: TypeTree,
927927
) -> &Value {
928-
let (primary_ret, diff_ret, ret_activity) = match ret_activity {
929-
DiffActivity::Const => (true, false, CDIFFE_TYPE::DFT_CONSTANT),
930-
DiffActivity::Active => (true, true, CDIFFE_TYPE::DFT_DUP_ARG),
931-
DiffActivity::ActiveOnly => (false, true, CDIFFE_TYPE::DFT_DUP_NONEED),
932-
DiffActivity::None => (false, false, CDIFFE_TYPE::DFT_CONSTANT),
928+
let (primary_ret, ret_activity) = match ret_activity {
929+
DiffActivity::Const => (true, CDIFFE_TYPE::DFT_CONSTANT),
930+
DiffActivity::Active => (true, CDIFFE_TYPE::DFT_DUP_ARG),
931+
DiffActivity::None => (false, CDIFFE_TYPE::DFT_CONSTANT),
933932
_ => panic!("Invalid return activity"),
934933
};
934+
// This only is needed for split-mode AD, which we don't support.
935+
// See Julia:
936+
// https://github.com/EnzymeAD/Enzyme.jl/blob/a511e4e6979d6161699f5c9919d49801c0764a09/src/compiler.jl#L3132
937+
// https://github.com/EnzymeAD/Enzyme.jl/blob/a511e4e6979d6161699f5c9919d49801c0764a09/src/compiler.jl#L3092
938+
let diff_ret = false;
935939

936940
let input_activity: Vec<CDIFFE_TYPE> = input_activity.iter().map(|&x| cdiffe_from(x)).collect();
937941

@@ -2690,7 +2694,7 @@ extern "C" {
26902694
numRules: size_t,
26912695
) -> EnzymeTypeAnalysisRef;
26922696
}
2693-
pub fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef) { unimplemented!() }
2697+
//pub fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef) { unimplemented!() }
26942698
pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef) { unimplemented!() }
26952699
pub fn CreateEnzymeLogic(PostOpt: u8) -> EnzymeLogicRef { unimplemented!() }
26962700
pub fn ClearEnzymeLogic(arg1: EnzymeLogicRef) { unimplemented!() }
@@ -2936,25 +2940,12 @@ pub use self::Enzyme_AD::*;
29362940
pub mod Enzyme_AD {
29372941
use super::*;
29382942

2939-
use super::debuginfo::{
2940-
DIArray, DIBasicType, DIBuilder, DICompositeType, DIDerivedType, DIDescriptor, DIEnumerator,
2941-
DIFile, DIFlags, DIGlobalVariableExpression, DILexicalBlock, DILocation, DINameSpace,
2942-
DISPFlags, DIScope, DISubprogram, DISubrange, DITemplateTypeParameter, DIType, DIVariable,
2943-
DebugEmissionKind, DebugNameTableKind,
2944-
};
2945-
2946-
use libc::{c_char, c_int, c_uint, size_t};
2947-
use libc::{c_ulonglong, c_void};
2948-
2949-
use std::marker::PhantomData;
2950-
2951-
use super::RustString;
2952-
use core::fmt;
2953-
use std::ffi::{CStr, CString};
2943+
use libc::{c_char, size_t};
2944+
use libc::c_void;
29542945

29552946
extern "C" {
2956-
fn EnzymeNewTypeTree() -> CTypeTreeRef;
2957-
fn EnzymeFreeTypeTree(CTT: CTypeTreeRef);
2947+
pub fn EnzymeNewTypeTree() -> CTypeTreeRef;
2948+
pub fn EnzymeFreeTypeTree(CTT: CTypeTreeRef);
29582949
pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
29592950
pub fn EnzymeSetCLInteger(arg1: *mut ::std::os::raw::c_void, arg2: i64);
29602951
}
@@ -2971,43 +2962,46 @@ extern "C" {
29712962
static mut EnzymeStrictAliasing: c_void;
29722963
}
29732964
pub fn set_max_int_offset(offset: u64) {
2965+
let offset = offset.try_into().unwrap();
29742966
unsafe {
2975-
EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::MaxIntOffset), offset);
2967+
EnzymeSetCLInteger(std::ptr::addr_of_mut!(MaxIntOffset), offset);
29762968
}
29772969
}
29782970
pub fn set_max_type_offset(offset: u64) {
2971+
let offset = offset.try_into().unwrap();
29792972
unsafe {
2980-
EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::MaxTypeOffset), offset);
2973+
EnzymeSetCLInteger(std::ptr::addr_of_mut!(MaxTypeOffset), offset);
29812974
}
29822975
}
29832976
pub fn set_max_type_depth(depth: u64) {
2977+
let depth = depth.try_into().unwrap();
29842978
unsafe {
2985-
EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::EnzymeMaxTypeDepth), depth);
2979+
EnzymeSetCLInteger(std::ptr::addr_of_mut!(EnzymeMaxTypeDepth), depth);
29862980
}
29872981
}
29882982
pub fn set_print_perf(print: bool) {
29892983
unsafe {
2990-
EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintPerf), print as u8);
2984+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintPerf), print as u8);
29912985
}
29922986
}
29932987
pub fn set_print_activity(print: bool) {
29942988
unsafe {
2995-
EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintActivity), print as u8);
2989+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintActivity), print as u8);
29962990
}
29972991
}
29982992
pub fn set_print_type(print: bool) {
29992993
unsafe {
3000-
EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintType), print as u8);
2994+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8);
30012995
}
30022996
}
30032997
pub fn set_print(print: bool) {
30042998
unsafe {
3005-
EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrint), print as u8);
2999+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8);
30063000
}
30073001
}
30083002
pub fn set_strict_aliasing(strict: bool) {
30093003
unsafe {
3010-
EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymeStrictAliasing), strict as u8);
3004+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeStrictAliasing), strict as u8);
30113005
}
30123006
}
30133007
extern "C" {
@@ -3074,28 +3068,28 @@ extern "C" {
30743068
numRules: size_t,
30753069
) -> EnzymeTypeAnalysisRef;
30763070
}
3077-
//extern "C" {
3078-
// pub fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef);
3079-
// pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef);
3080-
// pub fn CreateEnzymeLogic(PostOpt: u8) -> EnzymeLogicRef;
3081-
// pub fn ClearEnzymeLogic(arg1: EnzymeLogicRef);
3082-
// pub fn FreeEnzymeLogic(arg1: EnzymeLogicRef);
3083-
//}
3071+
extern "C" {
3072+
//pub(super) fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef);
3073+
pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef);
3074+
pub fn CreateEnzymeLogic(PostOpt: u8) -> EnzymeLogicRef;
3075+
pub fn ClearEnzymeLogic(arg1: EnzymeLogicRef);
3076+
pub fn FreeEnzymeLogic(arg1: EnzymeLogicRef);
3077+
}
30843078

30853079
extern "C" {
3086-
fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef;
3087-
fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef;
3088-
fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool;
3089-
fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64);
3090-
fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef);
3091-
fn EnzymeTypeTreeShiftIndiciesEq(
3080+
pub(super) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef;
3081+
pub(super) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef;
3082+
pub(super) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool;
3083+
pub(super) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64);
3084+
pub(super) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef);
3085+
pub(super) fn EnzymeTypeTreeShiftIndiciesEq(
30923086
arg1: CTypeTreeRef,
30933087
data_layout: *const c_char,
30943088
offset: i64,
30953089
max_size: i64,
30963090
add_offset: u64,
30973091
);
3098-
fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
3099-
fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
3092+
pub(super) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
3093+
pub(super) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
31003094
}
31013095
}

0 commit comments

Comments
 (0)