Skip to content

Commit 1e32009

Browse files
authored
Implement new modes for higher-order AD (#106)
* Implement new modes (ForwardFirst, ReverseFirst) for higher-order AD * Fix the forward-mode void-return case
1 parent 1218cb2 commit 1e32009

File tree

5 files changed

+106
-47
lines changed

5 files changed

+106
-47
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,21 @@ pub enum DiffMode {
1111
Source,
1212
Forward,
1313
Reverse,
14+
ForwardFirst,
15+
ReverseFirst,
16+
}
17+
18+
pub fn is_rev(mode: DiffMode) -> bool {
19+
match mode {
20+
DiffMode::Reverse | DiffMode::ReverseFirst => true,
21+
_ => false,
22+
}
23+
}
24+
pub fn is_fwd(mode: DiffMode) -> bool {
25+
match mode {
26+
DiffMode::Forward | DiffMode::ForwardFirst => true,
27+
_ => false,
28+
}
1429
}
1530

1631
impl Display for DiffMode {
@@ -20,6 +35,8 @@ impl Display for DiffMode {
2035
DiffMode::Source => write!(f, "Source"),
2136
DiffMode::Forward => write!(f, "Forward"),
2237
DiffMode::Reverse => write!(f, "Reverse"),
38+
DiffMode::ForwardFirst => write!(f, "ForwardFirst"),
39+
DiffMode::ReverseFirst => write!(f, "ReverseFirst"),
2340
}
2441
}
2542
}
@@ -32,12 +49,12 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
3249
match mode {
3350
DiffMode::Inactive => false,
3451
DiffMode::Source => false,
35-
DiffMode::Forward => {
52+
DiffMode::Forward | DiffMode::ForwardFirst => {
3653
activity == DiffActivity::Dual ||
3754
activity == DiffActivity::DualOnly ||
3855
activity == DiffActivity::Const
3956
}
40-
DiffMode::Reverse => {
57+
DiffMode::Reverse | DiffMode::ReverseFirst => {
4158
activity == DiffActivity::Const ||
4259
activity == DiffActivity::Active ||
4360
activity == DiffActivity::ActiveOnly
@@ -73,13 +90,13 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
7390
return match mode {
7491
DiffMode::Inactive => false,
7592
DiffMode::Source => false,
76-
DiffMode::Forward => {
93+
DiffMode::Forward | DiffMode::ForwardFirst => {
7794
// These are the only valid cases
7895
activity == DiffActivity::Dual ||
7996
activity == DiffActivity::DualOnly ||
8097
activity == DiffActivity::Const
8198
}
82-
DiffMode::Reverse => {
99+
DiffMode::Reverse | DiffMode::ReverseFirst => {
83100
// These are the only valid cases
84101
activity == DiffActivity::Active ||
85102
activity == DiffActivity::ActiveOnly ||
@@ -137,6 +154,8 @@ impl FromStr for DiffMode {
137154
"Source" => Ok(DiffMode::Source),
138155
"Forward" => Ok(DiffMode::Forward),
139156
"Reverse" => Ok(DiffMode::Reverse),
157+
"ForwardFirst" => Ok(DiffMode::ForwardFirst),
158+
"ReverseFirst" => Ok(DiffMode::ReverseFirst),
140159
_ => Err(()),
141160
}
142161
}

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
//use crate::util::check_autodiff;
44

55
use crate::errors;
6-
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ty_for_activity};
6+
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode, is_fwd, is_rev, valid_input_activity, valid_ty_for_activity};
77
use rustc_ast::ptr::P;
88
use rustc_ast::token::{Token, TokenKind};
99
use rustc_ast::tokenstream::*;
@@ -308,7 +308,7 @@ fn gen_enzyme_body(
308308

309309
let primal_ret = sig.decl.output.has_ret();
310310

311-
if primal_ret && n_active == 0 && x.mode == DiffMode::Reverse {
311+
if primal_ret && n_active == 0 && is_rev(x.mode) {
312312
// We only have the primal ret.
313313
body.stmts.push(ecx.stmt_expr(black_box_primal_call.clone()));
314314
return body;
@@ -355,7 +355,7 @@ fn gen_enzyme_body(
355355
panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty);
356356
}
357357
};
358-
if x.mode == DiffMode::Forward {
358+
if is_fwd(x.mode) {
359359
if x.ret_activity == DiffActivity::Dual {
360360
assert!(d_ret_ty.len() == 2);
361361
// both should be identical, by construction
@@ -369,7 +369,7 @@ fn gen_enzyme_body(
369369
exprs.push(default_call_expr);
370370
}
371371
} else {
372-
assert!(x.mode == DiffMode::Reverse);
372+
assert!(is_rev(x.mode));
373373

374374
if primal_ret {
375375
// We have extra handling above for the primal ret
@@ -508,7 +508,7 @@ fn gen_enzyme_decl(
508508

509509
// If we return a scalar in the primal and the scalar is active,
510510
// then add it as last arg to the inputs.
511-
if let DiffMode::Reverse = x.mode {
511+
if is_rev(x.mode) {
512512
if let DiffActivity::Active = x.ret_activity {
513513
let ty = match d_decl.output {
514514
FnRetTy::Ty(ref ty) => ty.clone(),
@@ -537,7 +537,7 @@ fn gen_enzyme_decl(
537537
}
538538
d_decl.inputs = d_inputs.into();
539539

540-
if let DiffMode::Forward = x.mode {
540+
if is_fwd(x.mode) {
541541
if let DiffActivity::Dual = x.ret_activity {
542542
let ty = match d_decl.output {
543543
FnRetTy::Ty(ref ty) => ty.clone(),

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,7 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
703703

704704
unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
705705
llmod: &'a llvm::Module, llcx: &llvm::Context, size_positions: &[usize]) {
706-
dbg!("size_positions: {:?}", size_positions);
706+
707707
// first, remove all calls from fnc
708708
let bb = LLVMGetFirstBasicBlock(tgt);
709709
let br = LLVMRustGetTerminator(bb);
@@ -843,7 +843,6 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
843843

844844
// Now clean up placeholder code.
845845
LLVMRustEraseInstBefore(bb, last_inst);
846-
//dbg!(&tgt);
847846

848847
let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(src));
849848
let void_type = LLVMVoidTypeInContext(llcx);
@@ -865,6 +864,7 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
865864
let _fnc_ok =
866865
LLVMVerifyFunction(tgt, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction);
867866
}
867+
868868
unsafe fn get_panic_name(llmod: &llvm::Module) -> CString {
869869
// The names are mangled and their ending changes based on a hash, so just take whichever.
870870
let mut f = LLVMGetFirstFunction(llmod);
@@ -922,21 +922,7 @@ unsafe fn add_panic_msg_to_global<'a>(llmod: &'a llvm::Module, llcx: &'a llvm::C
922922
LLVMRustSetLinkage(global_var, Linkage::PrivateLinkage);
923923
LLVMSetInitializer(global_var, struct_initializer);
924924

925-
//let msg_global_name = "ad_safety_msg".to_string();
926-
//let cmsg_global_name = CString::new(msg_global_name).unwrap();
927-
//let msg = "autodiff safety check failed!";
928-
//let cmsg = CString::new(msg).unwrap();
929-
//let msg_len = msg.len();
930-
//let i8_array_type = llvm::LLVMRustArrayType(llvm::LLVMInt8TypeInContext(llcx), msg_len as u64);
931-
//let global_type = llvm::LLVMStructTypeInContext(llcx, [i8_array_type].as_mut_ptr(), 1, 0);
932-
//let string_const_val = llvm::LLVMConstStringInContext(llcx, cmsg.as_ptr() as *const c_char, msg_len as u32, 0);
933-
//let initializer = llvm::LLVMConstStructInContext(llcx, [string_const_val].as_mut_ptr(), 1, 0);
934-
//let global = llvm::LLVMAddGlobal(llmod, global_type, cmsg_global_name.as_ptr() as *const c_char);
935-
//llvm::LLVMRustSetLinkage(global, llvm::Linkage::PrivateLinkage);
936-
//llvm::LLVMSetInitializer(global, initializer);
937-
//llvm::LLVMSetUnnamedAddress(global, llvm::UnnamedAddr::Global);
938-
939-
global_var
925+
global_var
940926
}
941927

942928
// As unsafe as it can be.
@@ -947,6 +933,7 @@ pub(crate) unsafe fn enzyme_ad(
947933
llcx: &llvm::Context,
948934
diag_handler: &DiagCtxt,
949935
item: AutoDiffItem,
936+
logic_ref: EnzymeLogicRef,
950937
) -> Result<(), FatalError> {
951938
let autodiff_mode = item.attrs.mode;
952939
let rust_name = item.source;
@@ -1001,13 +988,6 @@ pub(crate) unsafe fn enzyme_ad(
1001988
item.inputs.into_iter().map(|x| to_enzyme_typetree(x, llvm_data_layout, llcx)).collect();
1002989
let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx);
1003990

1004-
let mut fnc_opt = false;
1005-
if std::env::var("ENZYME_ENABLE_FNC_OPT").is_ok() {
1006-
dbg!("Disabling optimizations for Enzyme");
1007-
fnc_opt = true;
1008-
}
1009-
1010-
let logic_ref: EnzymeLogicRef = CreateEnzymeLogic(fnc_opt as u8);
1011991
let type_analysis: EnzymeTypeAnalysisRef =
1012992
CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0);
1013993

@@ -1026,7 +1006,18 @@ pub(crate) unsafe fn enzyme_ad(
10261006
llvm::set_print(true);
10271007
}
10281008

1029-
let mut tmp = match item.attrs.mode {
1009+
let mode = match autodiff_mode {
1010+
DiffMode::Forward => DiffMode::Forward,
1011+
DiffMode::Reverse => DiffMode::Reverse,
1012+
DiffMode::ForwardFirst => DiffMode::Forward,
1013+
DiffMode::ReverseFirst => DiffMode::Reverse,
1014+
_ => unreachable!(),
1015+
};
1016+
1017+
let void_type = LLVMVoidTypeInContext(llcx);
1018+
let return_type = LLVMGetReturnType(LLVMGlobalGetValueType(src_fnc));
1019+
let void_ret = void_type == return_type;
1020+
let mut tmp = match mode {
10301021
DiffMode::Forward => enzyme_rust_forward_diff(
10311022
logic_ref,
10321023
type_analysis,
@@ -1035,6 +1026,7 @@ pub(crate) unsafe fn enzyme_ad(
10351026
ret_activity,
10361027
input_tts,
10371028
output_tt,
1029+
void_ret,
10381030
),
10391031
DiffMode::Reverse => enzyme_rust_reverse_diff(
10401032
logic_ref,
@@ -1052,7 +1044,6 @@ pub(crate) unsafe fn enzyme_ad(
10521044

10531045
let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(res));
10541046

1055-
let void_type = LLVMVoidTypeInContext(llcx);
10561047
let rev_mode = item.attrs.mode == DiffMode::Reverse;
10571048
create_call(target_fnc, res, rev_mode, llmod, llcx, &size_positions);
10581049
// TODO: implement drop for wrapper type?
@@ -1114,8 +1105,40 @@ pub(crate) unsafe fn differentiate(
11141105
}
11151106

11161107
let differentiate = !diff_items.is_empty();
1108+
let mut first_order_items: Vec<AutoDiffItem> = vec![];
1109+
let mut higher_order_items: Vec<AutoDiffItem> = vec![];
11171110
for item in diff_items {
1118-
let res = enzyme_ad(llmod, llcx, &diag_handler, item);
1111+
if item.attrs.mode == DiffMode::ForwardFirst || item.attrs.mode == DiffMode::ReverseFirst{
1112+
first_order_items.push(item);
1113+
} else {
1114+
// default
1115+
higher_order_items.push(item);
1116+
}
1117+
}
1118+
1119+
let mut fnc_opt = false;
1120+
if std::env::var("ENZYME_ENABLE_FNC_OPT").is_ok() {
1121+
dbg!("Enable extra optimizations for Enzyme");
1122+
fnc_opt = true;
1123+
}
1124+
1125+
// If a function is a base for some higher order ad, always optimize
1126+
let fnc_opt_base = true;
1127+
let logic_ref_opt: EnzymeLogicRef = CreateEnzymeLogic(fnc_opt_base as u8);
1128+
1129+
for item in first_order_items {
1130+
let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref_opt);
1131+
assert!(res.is_ok());
1132+
}
1133+
1134+
// For the rest, follow the user choice on debug vs release.
1135+
// Reuse the opt one if possible for better compile time (Enzyme internal caching).
1136+
let logic_ref = match fnc_opt {
1137+
true => logic_ref_opt,
1138+
false => CreateEnzymeLogic(fnc_opt as u8),
1139+
};
1140+
for item in higher_order_items {
1141+
let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref);
11191142
assert!(res.is_ok());
11201143
}
11211144

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
850850
ret_diffactivity: DiffActivity,
851851
input_tts: Vec<TypeTree>,
852852
output_tt: TypeTree,
853+
void_ret: bool,
853854
) -> (&Value, Vec<usize>) {
854855
let ret_activity = cdiffe_from(ret_diffactivity);
855856
assert!(ret_activity != CDIFFE_TYPE::DFT_OUT_DIFF);
@@ -864,12 +865,18 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
864865
input_activity.push(act);
865866
}
866867

867-
let ret_primary_ret = match ret_activity {
868-
CDIFFE_TYPE::DFT_CONSTANT => true,
869-
CDIFFE_TYPE::DFT_DUP_ARG => true,
870-
CDIFFE_TYPE::DFT_DUP_NONEED => false,
871-
_ => panic!("Implementation error in enzyme_rust_forward_diff."),
868+
// if we have void ret, this must be false;
869+
let ret_primary_ret = if void_ret {
870+
false
871+
} else {
872+
match ret_activity {
873+
CDIFFE_TYPE::DFT_CONSTANT => true,
874+
CDIFFE_TYPE::DFT_DUP_ARG => true,
875+
CDIFFE_TYPE::DFT_DUP_NONEED => false,
876+
_ => panic!("Implementation error in enzyme_rust_forward_diff."),
877+
}
872878
};
879+
trace!("ret_primary_ret: {}", &ret_primary_ret);
873880

874881
let mut args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
875882
//let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()];
@@ -879,8 +886,8 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
879886
let args_uncacheable = vec![0; input_tts.len()];
880887
assert!(args_uncacheable.len() == input_activity.len());
881888
let num_fnc_args = LLVMCountParams(fnc);
882-
println!("num_fnc_args: {}", num_fnc_args);
883-
println!("input_activity.len(): {}", input_activity.len());
889+
trace!("num_fnc_args: {}", num_fnc_args);
890+
trace!("input_activity.len(): {}", input_activity.len());
884891
assert!(num_fnc_args == input_activity.len() as u32);
885892

886893
let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 };
@@ -893,6 +900,11 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
893900
KnownValues: known_values.as_mut_ptr(),
894901
};
895902

903+
trace!("ret_activity: {}", &ret_activity);
904+
for i in &input_activity {
905+
trace!("input_activity i: {}", &i);
906+
}
907+
trace!("before calling Enzyme");
896908
let res = EnzymeCreateForwardDiff(
897909
logic_ref, // Logic
898910
std::ptr::null(),
@@ -912,6 +924,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
912924
args_uncacheable.len(), // uncacheable arguments
913925
std::ptr::null_mut(), // write augmented function to this
914926
);
927+
trace!("after calling Enzyme");
915928
(res, vec![])
916929
}
917930

@@ -971,11 +984,12 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
971984
KnownValues: known_values.as_mut_ptr(),
972985
};
973986

974-
trace!("{}", &primary_ret);
975-
trace!("{}", &ret_activity);
987+
trace!("primary_ret: {}", &primary_ret);
988+
trace!("ret_activity: {}", &ret_activity);
976989
for i in &input_activity {
977-
trace!("{}", &i);
990+
trace!("input_activity i: {}", &i);
978991
}
992+
trace!("before calling Enzyme");
979993
let res = EnzymeCreatePrimalAndGradient(
980994
logic_ref, // Logic
981995
std::ptr::null(),
@@ -998,6 +1012,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
9981012
std::ptr::null_mut(), // write augmented function to this
9991013
0,
10001014
);
1015+
trace!("after calling Enzyme");
10011016
(res, primal_sizes)
10021017
}
10031018

compiler/rustc_codegen_ssa/src/codegen_attrs.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,8 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
749749
let mode = match mode.as_str() {
750750
"Forward" => DiffMode::Forward,
751751
"Reverse" => DiffMode::Reverse,
752+
"ForwardFirst" => DiffMode::ForwardFirst,
753+
"ReverseFirst" => DiffMode::ReverseFirst,
752754
_ => {
753755
tcx.sess
754756
.struct_span_err(attr.span, msg_mode)

0 commit comments

Comments
 (0)