Skip to content

Commit e4b7ef1

Browse files
committed
various fn_decl fixes
1 parent 47d6d3c commit e4b7ef1

File tree

1 file changed

+73
-52
lines changed

1 file changed

+73
-52
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 73 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
//use crate::util::check_autodiff;
66

77
use crate::errors;
8+
use rustc_ast::FnRetTy;
89
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
910
use rustc_ast::ptr::P;
1011
use rustc_ast::token::{Token, TokenKind};
@@ -41,7 +42,6 @@ pub fn expand(
4142
}
4243
};
4344
let mut orig_item: P<ast::Item> = item.clone().expect_item();
44-
//dbg!(&orig_item.tokens);
4545
let primal = orig_item.ident.clone();
4646

4747
// Allow using `#[autodiff(...)]` only on a Fn
@@ -77,7 +77,7 @@ pub fn expand(
7777
dbg!(&x);
7878
let span = ecx.with_def_site_ctxt(expand_span);
7979

80-
let (d_sig, old_names, new_args, idents) = gen_enzyme_decl(ecx, &sig, &x, span, sig_span);
80+
let (d_sig, old_names, new_args, idents) = gen_enzyme_decl(&sig, &x, span);
8181
let new_decl_span = d_sig.span;
8282
let d_body = gen_enzyme_body(
8383
ecx,
@@ -147,11 +147,11 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
147147
ty
148148
}
149149

150-
// The body of our generated functions will consist of three black_Box calls.
150+
// The body of our generated functions will consist of two black_Box calls.
151151
// The first will call the primal function with the original arguments.
152-
// The second will just take the shadow arguments.
153-
// The third will (unsafely) call std::mem::zeroed(), to match the return type of the new function
154-
// (whatever that might be). This way we surpress rustc from optimizing anyt argument away.
152+
// The second will just take a tuple containing the new arguments.
153+
// This way we surpress rustc from optimizing any argument away.
154+
// The last line will 'loop {}', to match the return type of the new function
155155
fn gen_enzyme_body(
156156
ecx: &ExtCtxt<'_>,
157157
primal: Ident,
@@ -184,31 +184,25 @@ fn gen_enzyme_body(
184184
}));
185185
let primal_call = gen_primal_call(ecx, span, primal, sig, idents);
186186
// create ::core::hint::black_box(array(arr));
187-
let black_box0 =
187+
let black_box_primal_call =
188188
ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![primal_call.clone()]);
189189

190-
// create ::core::hint::black_box(grad_arr, tang_y));
191-
let black_box1 = ecx.expr_call(
192-
sig_span,
193-
blackbox_call_expr.clone(),
194-
new_names
195-
.iter()
196-
.map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
197-
.collect(),
198-
);
190+
// create ::core::hint::black_box((grad_arr, tang_y));
191+
let tup_args = new_names
192+
.iter()
193+
.map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
194+
.collect();
199195

200-
// create ::core::hint::black_box(unsafe { ::core::mem::zeroed() })
201-
let black_box2 = ecx.expr_call(
196+
let black_box_remaining_args = ecx.expr_call(
202197
sig_span,
203198
blackbox_call_expr.clone(),
204-
thin_vec![unsafe_block_with_zeroed_call.clone()],
199+
thin_vec![ecx.expr_tuple(sig_span, tup_args)],
205200
);
206201

207202
let mut body = ecx.block(span, ThinVec::new());
208203
body.stmts.push(ecx.stmt_semi(primal_call));
209-
body.stmts.push(ecx.stmt_semi(black_box0));
210-
body.stmts.push(ecx.stmt_semi(black_box1));
211-
//body.stmts.push(ecx.stmt_semi(black_box2));
204+
body.stmts.push(ecx.stmt_semi(black_box_primal_call));
205+
body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
212206
body.stmts.push(ecx.stmt_expr(loop_expr));
213207
body
214208
}
@@ -233,11 +227,9 @@ fn gen_primal_call(
233227
// Each argument of the primal function (and the return type if existing) must be annotated with an
234228
// activity.
235229
fn gen_enzyme_decl(
236-
_ecx: &ExtCtxt<'_>,
237230
sig: &ast::FnSig,
238231
x: &AutoDiffAttrs,
239232
span: Span,
240-
_sig_span: Span,
241233
) -> (ast::FnSig, Vec<String>, Vec<String>, Vec<Ident>) {
242234
assert!(sig.decl.inputs.len() == x.input_activity.len());
243235
assert!(sig.decl.output.has_ret() == x.has_ret_activity());
@@ -246,15 +238,19 @@ fn gen_enzyme_decl(
246238
let mut new_inputs = Vec::new();
247239
let mut old_names = Vec::new();
248240
let mut idents = Vec::new();
241+
let mut act_ret = ThinVec::new();
249242
for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
250243
d_inputs.push(arg.clone());
251244
match activity {
245+
DiffActivity::Active => {
246+
assert!(x.mode == DiffMode::Reverse);
247+
act_ret.push(arg.ty.clone());
248+
}
252249
DiffActivity::Duplicated | DiffActivity::Dual => {
253250
let mut shadow_arg = arg.clone();
254251
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
255252
// adjust name depending on mode
256-
let old_name = if let PatKind::Ident(_, ident, _) = shadow_arg.pat.kind {
257-
idents.push(ident.clone());
253+
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
258254
ident.name
259255
} else {
260256
dbg!(&shadow_arg.pat);
@@ -276,47 +272,72 @@ fn gen_enzyme_decl(
276272
span: shadow_arg.pat.span,
277273
tokens: shadow_arg.pat.tokens.clone(),
278274
});
279-
//idents.push(ident);
280275
d_inputs.push(shadow_arg);
281276
}
282277
_ => {
283278
dbg!(&activity);
284279
}
285280
}
281+
if let PatKind::Ident(_, ident, _) = arg.pat.kind {
282+
idents.push(ident.clone());
283+
} else {
284+
panic!("not an ident?");
285+
}
286286
}
287287

288288
// If we return a scalar in the primal and the scalar is active,
289289
// then add it as last arg to the inputs.
290-
if x.mode == DiffMode::Reverse {
291-
match x.ret_activity {
292-
DiffActivity::Active => {
293-
let ty = match d_decl.output {
294-
rustc_ast::FnRetTy::Ty(ref ty) => ty.clone(),
295-
rustc_ast::FnRetTy::Default(span) => {
296-
panic!("Did not expect Default ret ty: {:?}", span);
297-
}
298-
};
299-
let name = "dret".to_string();
300-
let ident = Ident::from_str_and_span(&name, ty.span);
301-
let shadow_arg = ast::Param {
302-
attrs: ThinVec::new(),
303-
ty: ty.clone(),
304-
pat: P(ast::Pat {
305-
id: ast::DUMMY_NODE_ID,
306-
kind: PatKind::Ident(BindingAnnotation::NONE, ident, None),
307-
span: ty.span,
308-
tokens: None,
309-
}),
290+
if let DiffMode::Reverse = x.mode {
291+
if let DiffActivity::Active = x.ret_activity {
292+
let ty = match d_decl.output {
293+
FnRetTy::Ty(ref ty) => ty.clone(),
294+
FnRetTy::Default(span) => {
295+
panic!("Did not expect Default ret ty: {:?}", span);
296+
}
297+
};
298+
let name = "dret".to_string();
299+
let ident = Ident::from_str_and_span(&name, ty.span);
300+
let shadow_arg = ast::Param {
301+
attrs: ThinVec::new(),
302+
ty: ty.clone(),
303+
pat: P(ast::Pat {
310304
id: ast::DUMMY_NODE_ID,
305+
kind: PatKind::Ident(BindingAnnotation::NONE, ident, None),
311306
span: ty.span,
312-
is_placeholder: false,
313-
};
314-
d_inputs.push(shadow_arg);
315-
}
316-
_ => {}
307+
tokens: None,
308+
}),
309+
id: ast::DUMMY_NODE_ID,
310+
span: ty.span,
311+
is_placeholder: false,
312+
};
313+
d_inputs.push(shadow_arg);
314+
new_inputs.push(name);
317315
}
318316
}
319317
d_decl.inputs = d_inputs.into();
318+
319+
// If we have an active input scalar, add it's gradient to the
320+
// return type. This might require changing the return type to a
321+
// tuple.
322+
if act_ret.len() > 0 {
323+
let mut ret_ty = match d_decl.output {
324+
FnRetTy::Ty(ref ty) => {
325+
act_ret.insert(0, ty.clone());
326+
let kind = TyKind::Tup(act_ret);
327+
P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
328+
}
329+
FnRetTy::Default(span) => {
330+
if act_ret.len() == 1 {
331+
act_ret[0].clone()
332+
} else {
333+
let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());
334+
P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })
335+
}
336+
}
337+
};
338+
d_decl.output = FnRetTy::Ty(ret_ty);
339+
}
340+
320341
let d_sig = FnSig { header: sig.header.clone(), decl: d_decl, span };
321342
(d_sig, old_names, new_inputs, idents)
322343
}

0 commit comments

Comments
 (0)