Skip to content

Commit 494b102

Browse files
committed
It works (somewhat)
1 parent c97f625 commit 494b102

File tree

2 files changed

+44
-27
lines changed

2 files changed

+44
-27
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -72,20 +72,15 @@ pub fn expand(
7272
ts.push(TokenTree::Token(t, Spacing::Joint));
7373
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
7474
}
75-
dbg!(&ts);
7675
let ts: TokenStream = TokenStream::from_iter(ts);
77-
dbg!(&ts);
7876

7977
let x: AutoDiffAttrs = AutoDiffAttrs::from_ast(&meta_item_vec, has_ret);
8078
dbg!(&x);
81-
//let span = ecx.with_def_site_ctxt(sig_span);
8279
let span = ecx.with_def_site_ctxt(expand_span);
83-
//let span = ecx.with_def_site_ctxt(fn_item.span);
8480

85-
let (d_sig, old_names, new_args) = gen_enzyme_decl(ecx, &sig, &x, span, sig_span);
81+
let (d_sig, old_names, new_args, idents) = gen_enzyme_decl(ecx, &sig, &x, span, sig_span);
8682
let new_decl_span = d_sig.span;
87-
//let d_body = gen_enzyme_body(ecx, primal, &old_names, &new_args, span, span);
88-
let d_body = gen_enzyme_body(ecx, primal, &old_names, &new_args, span, sig_span, new_decl_span);
83+
let d_body = gen_enzyme_body(ecx, primal, &old_names, &new_args, span, sig_span, new_decl_span, &sig, &d_sig, idents);
8984
let d_ident = meta_item_vec[0].meta_item().unwrap().path.segments[0].ident;
9085

9186
// The first element of it is the name of the function to be generated
@@ -147,14 +142,13 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
147142
// The second will just take the shadow arguments.
148143
// The third will (unsafely) call std::mem::zeroed(), to match the return type of the new function
149144
// (whatever that might be). This way we surpress rustc from optimizing anyt argument away.
150-
fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_names: &[String], span: Span, sig_span: Span, new_decl_span: Span) -> P<ast::Block> {
145+
fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_names: &[String], span: Span, sig_span: Span, new_decl_span: Span, sig: &ast::FnSig, d_sig: &ast::FnSig, idents: Vec<Ident>) -> P<ast::Block> {
151146
let blackbox_path = ecx.std_path(&[Symbol::intern("hint"), Symbol::intern("black_box")]);
152147
let zeroed_path = ecx.std_path(&[Symbol::intern("mem"), Symbol::intern("zeroed")]);
153148
let empty_loop_block = ecx.block(span, ThinVec::new());
154149
let loop_expr = ecx.expr_loop(span, empty_loop_block);
155150

156151

157-
let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
158152
let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
159153
let zeroed_call_expr = ecx.expr_path(ecx.path(span, zeroed_path));
160154

@@ -172,18 +166,11 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
172166
could_be_bare_literal: false,
173167
}));
174168
// create ::core::hint::black_box(array(arr));
175-
let primal_call = ecx.expr_call(
176-
new_decl_span,
177-
primal_call_expr,
178-
old_names.iter().map(|name| {
179-
ecx.expr_path(ecx.path_ident(new_decl_span, Ident::from_str(name)))
180-
}).collect(),
181-
);
182-
let black_box0 = ecx.expr_call(
183-
new_decl_span,
184-
blackbox_call_expr.clone(),
185-
thin_vec![primal_call.clone()],
186-
);
169+
//let black_box0 = ecx.expr_call(
170+
// new_decl_span,
171+
// blackbox_call_expr.clone(),
172+
// thin_vec![primal_call.clone()],
173+
//);
187174

188175
// create ::core::hint::black_box(grad_arr, tang_y));
189176
let black_box1 = ecx.expr_call(
@@ -201,30 +188,54 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
201188
thin_vec![unsafe_block_with_zeroed_call.clone()],
202189
);
203190

191+
let primal_call = gen_primal_call(ecx, span, primal, sig, idents);
192+
204193
let mut body = ecx.block(span, ThinVec::new());
205-
//body.stmts.push(ecx.stmt_expr(primal_call));
194+
body.stmts.push(ecx.stmt_semi(primal_call));
206195
//body.stmts.push(ecx.stmt_expr(black_box0));
207196
//body.stmts.push(ecx.stmt_expr(black_box1));
208-
body.stmts.push(ecx.stmt_expr(black_box2));
197+
//body.stmts.push(ecx.stmt_expr(black_box2));
209198
body.stmts.push(ecx.stmt_expr(loop_expr));
210199
body
211200
}
212201

202+
fn gen_primal_call(ecx: &ExtCtxt<'_>, span: Span, primal: Ident, sig: &ast::FnSig, idents: Vec<Ident>) -> P<ast::Expr>{
203+
//pub struct Param {
204+
// pub attrs: AttrVec,
205+
// pub ty: P<Ty>,
206+
// pub pat: P<Pat>,
207+
// pub id: NodeId,
208+
// pub span: Span,
209+
// pub is_placeholder: bool,
210+
//}
211+
let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
212+
let args = idents.iter().map(|arg| {
213+
ecx.expr_path(ecx.path_ident(span, *arg))
214+
}).collect();
215+
let primal_call = ecx.expr_call(
216+
span,
217+
primal_call_expr,
218+
args,
219+
);
220+
primal_call
221+
}
222+
213223
// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
214224
// be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
215225
// Active arguments must be scalars. Their shadow argument is added to the return type (and will be
216226
// zero-initialized by Enzyme). Active arguments are not handled yet.
217227
// Each argument of the primal function (and the return type if existing) must be annotated with an
218228
// activity.
219229
fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span: Span, _sig_span: Span)
220-
-> (ast::FnSig, Vec<String>, Vec<String>) {
230+
-> (ast::FnSig, Vec<String>, Vec<String>, Vec<Ident>) {
221231
let decl: P<ast::FnDecl> = sig.decl.clone();
222232
assert!(decl.inputs.len() == x.input_activity.len());
223233
assert!(decl.output.has_ret() == x.has_ret_activity());
224234
let mut d_decl = decl.clone();
225235
let mut d_inputs = Vec::new();
226236
let mut new_inputs = Vec::new();
227237
let mut old_names = Vec::new();
238+
let mut idents = Vec::new();
228239
for (arg, activity) in decl.inputs.iter().zip(x.input_activity.iter()) {
229240
//dbg!(&arg);
230241
d_inputs.push(arg.clone());
@@ -234,6 +245,7 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span
234245
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
235246
// adjust name depending on mode
236247
let old_name = if let PatKind::Ident(_, ident, _) = shadow_arg.pat.kind {
248+
idents.push(ident.clone());
237249
ident.name
238250
} else {
239251
dbg!(&shadow_arg.pat);
@@ -247,17 +259,18 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span
247259
};
248260
dbg!(&name);
249261
new_inputs.push(name.clone());
262+
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
250263
shadow_arg.pat = P(ast::Pat {
251264
// TODO: Check id
252265
id: ast::DUMMY_NODE_ID,
253266
kind: PatKind::Ident(BindingAnnotation::NONE,
254-
Ident::from_str_and_span(&name, shadow_arg.pat.span),
267+
ident,
255268
None,
256269
),
257270
span: shadow_arg.pat.span,
258271
tokens: shadow_arg.pat.tokens.clone(),
259272
});
260-
273+
//idents.push(ident);
261274
d_inputs.push(shadow_arg);
262275
}
263276
_ => {},
@@ -269,5 +282,5 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span
269282
decl: d_decl,
270283
span,
271284
};
272-
(d_sig, old_names, new_inputs)
285+
(d_sig, old_names, new_inputs, idents)
273286
}

compiler/rustc_expand/src/build.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ impl<'a> ExtCtxt<'a> {
157157
ast::Stmt { id: ast::DUMMY_NODE_ID, span: expr.span, kind: ast::StmtKind::Expr(expr) }
158158
}
159159

160+
pub fn stmt_semi(&self, expr: P<ast::Expr>) -> ast::Stmt {
161+
ast::Stmt { id: ast::DUMMY_NODE_ID, span: expr.span, kind: ast::StmtKind::Semi(expr) }
162+
}
163+
160164
pub fn stmt_let_pat(&self, sp: Span, pat: P<ast::Pat>, ex: P<ast::Expr>) -> ast::Stmt {
161165
let local = P(ast::Local {
162166
pat,

0 commit comments

Comments
 (0)