Skip to content

Commit 4c2a59f

Browse files
committed
fix reverse_mode setting, various cleanups
1 parent 84c418c commit 4c2a59f

File tree

9 files changed

+75
-50
lines changed

9 files changed

+75
-50
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ pub fn expand(
9191
new_decl_span,
9292
idents,
9393
);
94-
let d_ident = meta_item_vec[0].meta_item().unwrap().path.segments[0].ident;
94+
let d_ident = first_ident(&meta_item_vec[0]);
9595

9696
// The first element of it is the name of the function to be generated
9797
let asdf = ItemKind::Fn(Box::new(ast::Fn {
@@ -102,11 +102,12 @@ pub fn expand(
102102
}));
103103
let mut rustc_ad_attr =
104104
P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
105-
let attr: ast::Attribute = ast::Attribute {
105+
let mut attr: ast::Attribute = ast::Attribute {
106106
kind: ast::AttrKind::Normal(rustc_ad_attr.clone()),
107-
id: ast::AttrId::from_u32(0),
107+
//id: ast::DUMMY_TR_ID,
108+
id: ast::AttrId::from_u32(12341), // TODO: fix
108109
style: ast::AttrStyle::Outer,
109-
span: span,
110+
span,
110111
};
111112
orig_item.attrs.push(attr.clone());
112113

@@ -116,21 +117,15 @@ pub fn expand(
116117
delim: rustc_ast::token::Delimiter::Parenthesis,
117118
tokens: ts,
118119
});
119-
let attr2: ast::Attribute = ast::Attribute {
120-
kind: ast::AttrKind::Normal(rustc_ad_attr),
121-
id: ast::AttrId::from_u32(0),
122-
style: ast::AttrStyle::Outer,
123-
span: span,
124-
};
125-
let attr_vec: rustc_ast::AttrVec = thin_vec![attr2];
126-
let d_fn = ecx.item(span, d_ident, attr_vec, asdf);
120+
attr.kind = ast::AttrKind::Normal(rustc_ad_attr);
121+
let d_fn = ecx.item(span, d_ident, thin_vec![attr], asdf);
127122

128-
let orig_annotatable = Annotatable::Item(orig_item.clone());
123+
let orig_annotatable = Annotatable::Item(orig_item);
129124
let d_annotatable = Annotatable::Item(d_fn);
130125
return vec![orig_annotatable, d_annotatable];
131126
}
132127

133-
// shadow arguments must be mutable references or ptrs, because Enzyme will write into them.
128+
// shadow arguments in reverse mode must be mutable references or ptrs, because Enzyme will write into them.
134129
#[cfg(llvm_enzyme)]
135130
fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
136131
let mut ty = ty.clone();
@@ -165,6 +160,25 @@ fn gen_enzyme_body(
165160
) -> P<ast::Block> {
166161
let blackbox_path = ecx.std_path(&[Symbol::intern("hint"), Symbol::intern("black_box")]);
167162
let empty_loop_block = ecx.block(span, ThinVec::new());
163+
let noop = ast::InlineAsm {
164+
template: vec![ast::InlineAsmTemplatePiece::String("NOP".to_string())],
165+
template_strs: Box::new([]),
166+
operands: vec![],
167+
clobber_abis: vec![],
168+
options: ast::InlineAsmOptions::PURE & ast::InlineAsmOptions::NOMEM,
169+
line_spans: vec![],
170+
};
171+
let noop_expr = ecx.expr_asm(span, P(noop));
172+
let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated);
173+
let unsf_block = ast::Block {
174+
stmts: thin_vec![ecx.stmt_semi(noop_expr)],
175+
id: ast::DUMMY_NODE_ID,
176+
tokens: None,
177+
rules: unsf,
178+
span,
179+
could_be_bare_literal: false,
180+
};
181+
let unsf_expr = ecx.expr_block(P(unsf_block));
168182
let loop_expr = ecx.expr_loop(span, empty_loop_block);
169183
let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
170184
let primal_call = gen_primal_call(ecx, span, primal, idents);
@@ -185,7 +199,7 @@ fn gen_enzyme_body(
185199
);
186200

187201
let mut body = ecx.block(span, ThinVec::new());
188-
body.stmts.push(ecx.stmt_semi(primal_call));
202+
body.stmts.push(ecx.stmt_semi(unsf_expr));
189203
body.stmts.push(ecx.stmt_semi(black_box_primal_call));
190204
body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
191205
body.stmts.push(ecx.stmt_expr(loop_expr));
@@ -234,15 +248,18 @@ fn gen_enzyme_decl(
234248
}
235249
DiffActivity::Duplicated | DiffActivity::Dual => {
236250
let mut shadow_arg = arg.clone();
237-
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
251+
// We += into the shadow in reverse mode.
252+
// Otherwise copy mutability of the original argument.
253+
if activity == &DiffActivity::Duplicated {
254+
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
255+
}
238256
// adjust name depending on mode
239257
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
240258
ident.name
241259
} else {
242260
dbg!(&shadow_arg.pat);
243261
panic!("not an ident?");
244262
};
245-
//old_names.push(old_name.to_string());
246263
let name: String = match x.mode {
247264
DiffMode::Reverse => format!("d{}", old_name),
248265
DiffMode::Forward => format!("b{}", old_name),

compiler/rustc_codegen_llvm/src/attributes.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! Set and unset common attributes on LLVM values.
22
3+
use rustc_ast::expand::autodiff_attrs::AutoDiffAttrs;
34
use rustc_codegen_ssa::traits::*;
45
use rustc_hir::def_id::DefId;
56
use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags;
@@ -294,7 +295,7 @@ pub fn from_fn_attrs<'ll, 'tcx>(
294295
instance: ty::Instance<'tcx>,
295296
) {
296297
let codegen_fn_attrs = cx.tcx.codegen_fn_attrs(instance.def_id());
297-
let autodiff_attrs = cx.tcx.autodiff_attrs(instance.def_id());
298+
let autodiff_attrs: &AutoDiffAttrs = cx.tcx.autodiff_attrs(instance.def_id());
298299

299300
let mut to_add = SmallVec::<[_; 16]>::new();
300301

@@ -313,6 +314,8 @@ pub fn from_fn_attrs<'ll, 'tcx>(
313314
if codegen_fn_attrs.inline == InlineAttr::None && instance.def.requires_inline(cx.tcx) {
314315
InlineAttr::Hint
315316
} else if autodiff_attrs.is_active() {
317+
dbg!("autodiff_attrs.is_active()");
318+
dbg!(&autodiff_attrs);
316319
InlineAttr::Never
317320
} else {
318321
codegen_fn_attrs.inline

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,7 @@ pub(crate) unsafe fn enzyme_ad(
702702
diag_handler: &DiagCtxt,
703703
item: AutoDiffItem,
704704
) -> Result<(), FatalError> {
705+
dbg!("\n\n\n\n\n\n AUTO DIFF \n");
705706
let autodiff_mode = item.attrs.mode;
706707
let rust_name = item.source;
707708
let rust_name2 = &item.target;

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,6 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
852852
input_tts: Vec<TypeTree>,
853853
output_tt: TypeTree,
854854
) -> &Value {
855-
let mut ret_primary_ret = false;
856855
let ret_activity = cdiffe_from(ret_diffactivity);
857856
assert!(ret_activity != CDIFFE_TYPE::DFT_OUT_DIFF);
858857
let mut input_activity: Vec<CDIFFE_TYPE> = vec![];
@@ -866,17 +865,12 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
866865
input_activity.push(act);
867866
}
868867

869-
if ret_activity == CDIFFE_TYPE::DFT_DUP_ARG {
870-
if ret_primary_ret != true {
871-
dbg!("overwriting ret_primary_ret!");
872-
}
873-
ret_primary_ret = true;
874-
} else if ret_activity == CDIFFE_TYPE::DFT_DUP_NONEED {
875-
if ret_primary_ret != false {
876-
dbg!("overwriting ret_primary_ret!");
877-
}
878-
ret_primary_ret = false;
879-
}
868+
let ret_primary_ret = match ret_activity {
869+
CDIFFE_TYPE::DFT_CONSTANT => true,
870+
CDIFFE_TYPE::DFT_DUP_ARG => true,
871+
CDIFFE_TYPE::DFT_DUP_NONEED => false,
872+
_ => panic!("Implementation error in enzyme_rust_forward_diff."),
873+
};
880874

881875
let mut args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
882876
//let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()];
@@ -916,7 +910,6 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
916910
)
917911
}
918912

919-
#[allow(dead_code)]
920913
pub(crate) unsafe fn enzyme_rust_reverse_diff(
921914
logic_ref: EnzymeLogicRef,
922915
type_analysis: EnzymeTypeAnalysisRef,
@@ -928,7 +921,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
928921
) -> &Value {
929922
let (primary_ret, ret_activity) = match ret_activity {
930923
DiffActivity::Const => (true, CDIFFE_TYPE::DFT_CONSTANT),
931-
DiffActivity::Active => (true, CDIFFE_TYPE::DFT_DUP_ARG),
924+
DiffActivity::Active => (true, CDIFFE_TYPE::DFT_OUT_DIFF),
932925
DiffActivity::None => (false, CDIFFE_TYPE::DFT_CONSTANT),
933926
_ => panic!("Invalid return activity"),
934927
};
@@ -962,6 +955,9 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
962955
KnownValues: known_values.as_mut_ptr(),
963956
};
964957

958+
dbg!(&primary_ret);
959+
dbg!(&ret_activity);
960+
dbg!(&input_activity);
965961
let res = EnzymeCreatePrimalAndGradient(
966962
logic_ref, // Logic
967963
std::ptr::null(),

compiler/rustc_codegen_ssa/src/back/lto.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ impl<B: WriteBackendMethods> LtoModuleCodegen<B> {
9090
LtoModuleCodegen::Fat { ref module, .. } => {
9191
B::autodiff(cgcx, &module, diff_fncs, typetrees, config)?;
9292
}
93-
_ => {}
93+
_ => panic!("autodiff called with non-fat LTO module"),
9494
}
9595

9696
Ok(self)

compiler/rustc_codegen_ssa/src/back/write.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,9 @@ fn generate_lto_work<B: ExtraBackendMethods>(
382382
import_only_modules: Vec<(SerializedModule<B::ModuleBuffer>, WorkProduct)>,
383383
) -> Vec<(WorkItem<B>, u64)> {
384384
let _prof_timer = cgcx.prof.generic_activity("codegen_generate_lto_work");
385-
dbg!("Differentiating {} functions", autodiff.len());
385+
//let error_msg = format!("Found {} Functions, but {} TypeTrees", autodiff.len(), typetrees.len());
386+
// Don't assert yet, bc. apparently we add them later.
387+
//assert!(autodiff.len() == typetrees.len(), "{}", error_msg);
386388

387389
if !needs_fat_lto.is_empty() {
388390
assert!(needs_thin_lto.is_empty());

compiler/rustc_codegen_ssa/src/codegen_attrs.rs

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -701,33 +701,30 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
701701
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);
702702

703703
let attrs = attrs
704-
.into_iter()
705704
.filter(|attr| attr.name_or_empty() == sym::rustc_autodiff)
706705
.collect::<Vec<_>>();
707706

708-
if !attrs.is_empty() {
709-
dbg!("autodiff_attrs amount = {}", attrs.len());
710-
}
711-
712707
// check for exactly one autodiff attribute on extern block
713-
let msg_once = "autodiff attribute can only be applied once";
714-
let attr = match &attrs[..] {
715-
&[] => return AutoDiffAttrs::inactive(),
716-
&[elm] => elm,
717-
x => {
708+
let msg_once = "cg_ssa: autodiff attribute can only be applied once";
709+
let attr = match attrs.len() {
710+
0 => return AutoDiffAttrs::inactive(),
711+
1 => attrs.get(0).unwrap(),
712+
_ => {
718713
tcx.sess
719-
.struct_span_err(x[1].span, msg_once)
720-
.span_label(x[1].span, "more than one")
714+
.struct_span_err(attrs[1].span, msg_once)
715+
.span_label(attrs[1].span, "more than one")
721716
.emit();
722-
723717
return AutoDiffAttrs::inactive();
724718
}
725719
};
720+
dbg!("autodiff_attr = {:?}", &attr);
726721

727722
let list = attr.meta_item_list().unwrap_or_default();
723+
dbg!("autodiff_attrs list = {:?}", &list);
728724

729725
// empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions
730726
if list.len() == 0 {
727+
dbg!("autodiff_attrs: source");
731728
return AutoDiffAttrs::source();
732729
}
733730

compiler/rustc_expand/src/build.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,9 @@ impl<'a> ExtCtxt<'a> {
412412
pub fn expr_loop(&self, sp: Span, block: P<ast::Block>) -> P<ast::Expr> {
413413
self.expr(sp, ast::ExprKind::Loop(block, None, sp))
414414
}
415+
pub fn expr_asm(&self, sp: Span, expr: P<ast::InlineAsm>) -> P<ast::Expr> {
416+
self.expr(sp, ast::ExprKind::InlineAsm(expr))
417+
}
415418

416419
pub fn expr_fail(&self, span: Span, msg: Symbol) -> P<ast::Expr> {
417420
self.expr_call_global(

compiler/rustc_monomorphize/src/partitioning.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,9 +255,11 @@ where
255255
export_generics,
256256
);
257257

258-
//if visibility == Visibility::Hidden && can_be_internalized {
259-
let autodiff_active =
260-
characteristic_def_id.map(|x| cx.tcx.autodiff_attrs(x).is_active()).unwrap_or(false);
258+
// We can't differentiate something that got inlined.
259+
let autodiff_active = match characteristic_def_id {
260+
Some(def_id) => cx.tcx.autodiff_attrs(def_id).is_active(),
261+
None => false,
262+
};
261263

262264
if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized {
263265
internalization_candidates.insert(mono_item);
@@ -1166,6 +1168,10 @@ fn collect_and_partition_mono_items(
11661168
.filter_map(|(item, instance)| {
11671169
let target_id = instance.def_id();
11681170
let target_attrs = tcx.autodiff_attrs(target_id);
1171+
if target_attrs.is_source() {
1172+
dbg!("source");
1173+
dbg!(&target_attrs);
1174+
}
11691175
if !target_attrs.apply_autodiff() {
11701176
return None;
11711177
}

0 commit comments

Comments
 (0)