Skip to content

Commit 9cbdf68

Browse files
committed
More precise Enzyme settings
1 parent e4b7ef1 commit 9cbdf68

File tree

4 files changed

+23
-33
lines changed

4 files changed

+23
-33
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@ pub enum DiffMode {
1717
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
1818
pub enum DiffActivity {
1919
None,
20-
Active,
2120
Const,
21+
Active,
22+
ActiveOnly,
2223
Dual,
23-
DualNoNeed,
24+
DualOnly,
2425
Duplicated,
25-
DuplicatedNoNeed,
26+
DuplicatedOnly,
2627
}
2728

2829
impl FromStr for DiffMode {
@@ -47,9 +48,9 @@ impl FromStr for DiffActivity {
4748
"Active" => Ok(DiffActivity::Active),
4849
"Const" => Ok(DiffActivity::Const),
4950
"Dual" => Ok(DiffActivity::Dual),
50-
"DualNoNeed" => Ok(DiffActivity::DualNoNeed),
51+
"DualOnly" => Ok(DiffActivity::DualOnly),
5152
"Duplicated" => Ok(DiffActivity::Duplicated),
52-
"DuplicatedNoNeed" => Ok(DiffActivity::DuplicatedNoNeed),
53+
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
5354
_ => Err(()),
5455
}
5556
}

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -756,8 +756,6 @@ pub(crate) unsafe fn enzyme_ad(
756756
let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx);
757757

758758
let opt = 1;
759-
let ret_primary_ret = false;
760-
let diff_primary_ret = false;
761759
let logic_ref: EnzymeLogicRef = CreateEnzymeLogic(opt as u8);
762760
let type_analysis: EnzymeTypeAnalysisRef =
763761
CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0);
@@ -784,7 +782,6 @@ pub(crate) unsafe fn enzyme_ad(
784782
src_fnc,
785783
args_activity,
786784
ret_activity,
787-
ret_primary_ret,
788785
input_tts,
789786
output_tt,
790787
),
@@ -794,8 +791,6 @@ pub(crate) unsafe fn enzyme_ad(
794791
src_fnc,
795792
args_activity,
796793
ret_activity,
797-
ret_primary_ret,
798-
diff_primary_ret,
799794
input_tts,
800795
output_tt,
801796
),

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -851,10 +851,10 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
851851
fnc: &Value,
852852
input_diffactivity: Vec<DiffActivity>,
853853
ret_diffactivity: DiffActivity,
854-
mut ret_primary_ret: bool,
855854
input_tts: Vec<TypeTree>,
856855
output_tt: TypeTree,
857856
) -> &Value {
857+
let mut ret_primary_ret = false;
858858
let ret_activity = cdiffe_from(ret_diffactivity);
859859
assert!(ret_activity != CDIFFE_TYPE::DFT_OUT_DIFF);
860860
let mut input_activity: Vec<CDIFFE_TYPE> = vec![];
@@ -925,29 +925,22 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
925925
fnc: &Value,
926926
input_activity: Vec<DiffActivity>,
927927
ret_activity: DiffActivity,
928-
mut ret_primary_ret: bool,
929-
diff_primary_ret: bool,
930928
input_tts: Vec<TypeTree>,
931929
output_tt: TypeTree,
932930
) -> &Value {
933-
let ret_activity = cdiffe_from(ret_activity);
934-
assert!(ret_activity == CDIFFE_TYPE::DFT_CONSTANT || ret_activity == CDIFFE_TYPE::DFT_OUT_DIFF);
931+
let (primary_ret, diff_ret, ret_activity) = match ret_activity {
932+
DiffActivity::Const => (true, false, CDIFFE_TYPE::DFT_CONSTANT),
933+
DiffActivity::Active => (true, true, CDIFFE_TYPE::DFT_DUP_ARG),
934+
DiffActivity::ActiveOnly => (false, true, CDIFFE_TYPE::DFT_DUP_NONEED),
935+
DiffActivity::None => (false, false, CDIFFE_TYPE::DFT_CONSTANT),
936+
_ => panic!("Invalid return activity"),
937+
};
938+
939+
//assert!(ret_activity == CDIFFE_TYPE::DFT_CONSTANT || ret_activity == CDIFFE_TYPE::DFT_OUT_DIFF);
935940
let input_activity: Vec<CDIFFE_TYPE> = input_activity.iter().map(|&x| cdiffe_from(x)).collect();
936941

937942
dbg!(&fnc);
938943

939-
if ret_activity == CDIFFE_TYPE::DFT_DUP_ARG {
940-
if ret_primary_ret != true {
941-
dbg!("overwriting ret_primary_ret!");
942-
}
943-
ret_primary_ret = true;
944-
} else if ret_activity == CDIFFE_TYPE::DFT_DUP_NONEED {
945-
if ret_primary_ret != false {
946-
dbg!("overwriting ret_primary_ret!");
947-
}
948-
ret_primary_ret = false;
949-
}
950-
951944
let mut args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
952945

953946
// We don't support volatile / extern / (global?) values.
@@ -977,8 +970,8 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
977970
input_activity.as_ptr(),
978971
input_activity.len(), // constant arguments
979972
type_analysis, // type analysis struct
980-
ret_primary_ret as u8,
981-
diff_primary_ret as u8, //0
973+
primary_ret as u8,
974+
diff_ret as u8, //0
982975
CDerivativeMode::DEM_ReverseModeCombined, // return value, dret_used, top_level which was 1
983976
1, // vector mode width
984977
1, // free memory
@@ -2704,12 +2697,13 @@ pub enum CDIFFE_TYPE {
27042697
fn cdiffe_from(act: DiffActivity) -> CDIFFE_TYPE {
27052698
return match act {
27062699
DiffActivity::None => CDIFFE_TYPE::DFT_CONSTANT,
2707-
DiffActivity::Active => CDIFFE_TYPE::DFT_OUT_DIFF,
27082700
DiffActivity::Const => CDIFFE_TYPE::DFT_CONSTANT,
2701+
DiffActivity::Active => CDIFFE_TYPE::DFT_OUT_DIFF,
2702+
DiffActivity::ActiveOnly => CDIFFE_TYPE::DFT_OUT_DIFF,
27092703
DiffActivity::Dual => CDIFFE_TYPE::DFT_DUP_ARG,
2710-
DiffActivity::DualNoNeed => CDIFFE_TYPE::DFT_DUP_NONEED,
2704+
DiffActivity::DualOnly => CDIFFE_TYPE::DFT_DUP_NONEED,
27112705
DiffActivity::Duplicated => CDIFFE_TYPE::DFT_DUP_ARG,
2712-
DiffActivity::DuplicatedNoNeed => CDIFFE_TYPE::DFT_DUP_NONEED,
2706+
DiffActivity::DuplicatedOnly => CDIFFE_TYPE::DFT_DUP_NONEED,
27132707
};
27142708
}
27152709

compiler/rustc_codegen_ssa/src/codegen_attrs.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -839,7 +839,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
839839

840840
if mode == DiffMode::Reverse {
841841
if ret_activity == DiffActivity::Duplicated
842-
|| ret_activity == DiffActivity::DuplicatedNoNeed
842+
|| ret_activity == DiffActivity::DuplicatedOnly
843843
{
844844
dbg!("ret_activity = {:?}", ret_activity);
845845
tcx.sess

0 commit comments

Comments
 (0)