Skip to content

Commit 12b180e

Browse files
committed
make d_fnc decl a bit more precise
1 parent 59bc30f commit 12b180e

File tree

1 file changed

+54
-29
lines changed

1 file changed

+54
-29
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 54 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -90,23 +90,23 @@ pub fn expand(
9090
generics: Generics::default(),
9191
body: Some(d_body),
9292
}));
93-
let mut tmp = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
93+
let mut rustc_ad_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
9494
let mut attr: ast::Attribute = ast::Attribute {
95-
kind: ast::AttrKind::Normal(tmp.clone()),
95+
kind: ast::AttrKind::Normal(rustc_ad_attr.clone()),
9696
id: ast::AttrId::from_u32(0),
9797
style: ast::AttrStyle::Outer,
9898
span: span,
9999
};
100-
orig_item.attrs.push(attr);
100+
orig_item.attrs.push(attr.clone());
101101

102102
// Now update for d_fn
103-
tmp.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
103+
rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
104104
dspan: DelimSpan::dummy(),
105105
delim: rustc_ast::token::Delimiter::Parenthesis,
106106
tokens: ts,
107107
});
108108
let mut attr2: ast::Attribute = ast::Attribute {
109-
kind: ast::AttrKind::Normal(tmp),
109+
kind: ast::AttrKind::Normal(rustc_ad_attr),
110110
id: ast::AttrId::from_u32(0),
111111
style: ast::AttrStyle::Outer,
112112
span: span,
@@ -165,12 +165,13 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
165165
tokens: None,
166166
could_be_bare_literal: false,
167167
}));
168+
let primal_call = gen_primal_call(ecx, span, primal, sig, idents);
168169
// create ::core::hint::black_box(array(arr));
169-
//let black_box0 = ecx.expr_call(
170-
// new_decl_span,
171-
// blackbox_call_expr.clone(),
172-
// thin_vec![primal_call.clone()],
173-
//);
170+
let black_box0 = ecx.expr_call(
171+
new_decl_span,
172+
blackbox_call_expr.clone(),
173+
thin_vec![primal_call.clone()],
174+
);
174175

175176
// create ::core::hint::black_box(grad_arr, tang_y));
176177
let black_box1 = ecx.expr_call(
@@ -188,26 +189,17 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
188189
thin_vec![unsafe_block_with_zeroed_call.clone()],
189190
);
190191

191-
let primal_call = gen_primal_call(ecx, span, primal, sig, idents);
192192

193193
let mut body = ecx.block(span, ThinVec::new());
194194
body.stmts.push(ecx.stmt_semi(primal_call));
195-
//body.stmts.push(ecx.stmt_expr(black_box0));
196-
//body.stmts.push(ecx.stmt_expr(black_box1));
197-
//body.stmts.push(ecx.stmt_expr(black_box2));
195+
body.stmts.push(ecx.stmt_semi(black_box0));
196+
body.stmts.push(ecx.stmt_semi(black_box1));
197+
//body.stmts.push(ecx.stmt_semi(black_box2));
198198
body.stmts.push(ecx.stmt_expr(loop_expr));
199199
body
200200
}
201201

202202
fn gen_primal_call(ecx: &ExtCtxt<'_>, span: Span, primal: Ident, sig: &ast::FnSig, idents: Vec<Ident>) -> P<ast::Expr>{
203-
//pub struct Param {
204-
// pub attrs: AttrVec,
205-
// pub ty: P<Ty>,
206-
// pub pat: P<Pat>,
207-
// pub id: NodeId,
208-
// pub span: Span,
209-
// pub is_placeholder: bool,
210-
//}
211203
let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
212204
let args = idents.iter().map(|arg| {
213205
ecx.expr_path(ecx.path_ident(span, *arg))
@@ -228,16 +220,14 @@ fn gen_primal_call(ecx: &ExtCtxt<'_>, span: Span, primal: Ident, sig: &ast::FnSi
228220
// activity.
229221
fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span: Span, _sig_span: Span)
230222
-> (ast::FnSig, Vec<String>, Vec<String>, Vec<Ident>) {
231-
let decl: P<ast::FnDecl> = sig.decl.clone();
232-
assert!(decl.inputs.len() == x.input_activity.len());
233-
assert!(decl.output.has_ret() == x.has_ret_activity());
234-
let mut d_decl = decl.clone();
223+
assert!(sig.decl.inputs.len() == x.input_activity.len());
224+
assert!(sig.decl.output.has_ret() == x.has_ret_activity());
225+
let mut d_decl = sig.decl.clone();
235226
let mut d_inputs = Vec::new();
236227
let mut new_inputs = Vec::new();
237228
let mut old_names = Vec::new();
238229
let mut idents = Vec::new();
239-
for (arg, activity) in decl.inputs.iter().zip(x.input_activity.iter()) {
240-
//dbg!(&arg);
230+
for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
241231
d_inputs.push(arg.clone());
242232
match activity {
243233
DiffActivity::Duplicated | DiffActivity::Dual => {
@@ -273,7 +263,42 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span
273263
//idents.push(ident);
274264
d_inputs.push(shadow_arg);
275265
}
276-
_ => {},
266+
_ => {dbg!(&activity);},
267+
}
268+
}
269+
270+
// If we return a scalar in the primal and the scalar is active,
271+
// then add it as last arg to the inputs.
272+
if x.mode == DiffMode::Reverse {
273+
match x.ret_activity {
274+
DiffActivity::Active => {
275+
let ty = match d_decl.output {
276+
rustc_ast::FnRetTy::Ty(ref ty) => ty.clone(),
277+
rustc_ast::FnRetTy::Default(span) => {
278+
panic!("Did not expect Default ret ty: {:?}", span);
279+
}
280+
};
281+
let name = "dret".to_string();
282+
let ident = Ident::from_str_and_span(&name, ty.span);
283+
let shadow_arg = ast::Param {
284+
attrs: ThinVec::new(),
285+
ty: ty.clone(),
286+
pat: P(ast::Pat {
287+
id: ast::DUMMY_NODE_ID,
288+
kind: PatKind::Ident(BindingAnnotation::NONE,
289+
ident,
290+
None,
291+
),
292+
span: ty.span,
293+
tokens: None,
294+
}),
295+
id: ast::DUMMY_NODE_ID,
296+
span: ty.span,
297+
is_placeholder: false,
298+
};
299+
d_inputs.push(shadow_arg);
300+
}
301+
_ => {}
277302
}
278303
}
279304
d_decl.inputs = d_inputs.into();

0 commit comments

Comments
 (0)