Skip to content

Commit bf72f16

Browse files
committed
yeet
1 parent 6700542 commit bf72f16

File tree

2 files changed

+42
-25
lines changed

2 files changed

+42
-25
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#![allow(unused_imports)]
2+
#![allow(unused_variables)]
23
//use crate::util::check_builtin_macro_attribute;
34
//use crate::util::check_autodiff;
45

@@ -32,9 +33,7 @@ pub fn expand(
3233
return vec![item];
3334
}
3435
};
35-
let input = item.clone();
3636
let orig_item: P<ast::Item> = item.clone().expect_item();
37-
let mut d_item: P<ast::Item> = item.clone().expect_item();
3837
let primal = orig_item.ident.clone();
3938

4039
// Allow using `#[autodiff(...)]` only on a Fn
@@ -46,29 +45,27 @@ pub fn expand(
4645
ecx.sess
4746
.dcx()
4847
.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
49-
return vec![input];
48+
return vec![item];
5049
};
5150
let x: AutoDiffAttrs = AutoDiffAttrs::from_ast(&meta_item_vec, has_ret);
5251
dbg!(&x);
5352
let span = ecx.with_def_site_ctxt(fn_item.span);
5453

55-
let (d_decl, old_names, new_args) = gen_enzyme_decl(ecx, &sig.decl, &x, span, sig_span);
54+
let (d_sig, old_names, new_args) = gen_enzyme_decl(ecx, &sig, &x, span, sig_span);
5655
let d_body = gen_enzyme_body(ecx, primal, &old_names, &new_args, span, sig_span);
57-
let meta_item_name = meta_item_vec[0].meta_item().unwrap();
58-
d_item.ident = meta_item_name.path.segments[0].ident;
59-
// update d_item
60-
if let ItemKind::Fn(box ast::Fn { sig, body, .. }) = &mut d_item.kind {
61-
*sig.decl = d_decl;
62-
*body = Some(d_body);
63-
} else {
64-
ecx.sess
65-
.dcx()
66-
.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
67-
return vec![input];
68-
}
56+
let d_ident = meta_item_vec[0].meta_item().unwrap().path.segments[0].ident;
57+
58+
// The first element of it is the name of the function to be generated
59+
let asdf = ItemKind::Fn(Box::new(ast::Fn {
60+
defaultness: ast::Defaultness::Final,
61+
sig: d_sig,
62+
generics: Generics::default(),
63+
body: Some(d_body),
64+
}));
65+
let d_fn = ecx.item(span, d_ident, rustc_ast::AttrVec::default(), asdf);
6966

7067
let orig_annotatable = Annotatable::Item(orig_item.clone());
71-
let d_annotatable = Annotatable::Item(d_item.clone());
68+
let d_annotatable = Annotatable::Item(d_fn);
7269
return vec![orig_annotatable, d_annotatable];
7370
}
7471

@@ -98,6 +95,9 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
9895
fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_names: &[String], span: Span, sig_span: Span) -> P<ast::Block> {
9996
let blackbox_path = ecx.std_path(&[Symbol::intern("hint"), Symbol::intern("black_box")]);
10097
let zeroed_path = ecx.std_path(&[Symbol::intern("mem"), Symbol::intern("zeroed")]);
98+
let empty_loop_block = ecx.block(span, ThinVec::new());
99+
let loop_expr = ecx.expr_loop(span, empty_loop_block);
100+
101101

102102
let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
103103
let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
@@ -117,13 +117,18 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
117117
could_be_bare_literal: false,
118118
}));
119119
// create ::core::hint::black_box(array(arr));
120-
let _primal_call = ecx.expr_call(
120+
let primal_call = ecx.expr_call(
121121
span,
122122
primal_call_expr.clone(),
123123
old_names.iter().map(|name| {
124124
ecx.expr_path(ecx.path_ident(span, Ident::from_str(name)))
125125
}).collect(),
126126
);
127+
let black_box0 = ecx.expr_call(
128+
sig_span,
129+
blackbox_call_expr.clone(),
130+
thin_vec![primal_call.clone()],
131+
);
127132

128133
// create ::core::hint::black_box(grad_arr, tang_y));
129134
let black_box1 = ecx.expr_call(
@@ -135,15 +140,18 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
135140
);
136141

137142
// create ::core::hint::black_box(unsafe { ::core::mem::zeroed() })
138-
let black_box2 = ecx.expr_call(
143+
let _black_box2 = ecx.expr_call(
139144
sig_span,
140145
blackbox_call_expr.clone(),
141146
thin_vec![unsafe_block_with_zeroed_call.clone()],
142147
);
143148

144149
let mut body = ecx.block(span, ThinVec::new());
145-
body.stmts.push(ecx.stmt_expr(black_box1));
146-
body.stmts.push(ecx.stmt_expr(black_box2));
150+
body.stmts.push(ecx.stmt_expr(primal_call));
151+
//body.stmts.push(ecx.stmt_expr(black_box0));
152+
//body.stmts.push(ecx.stmt_expr(black_box1));
153+
//body.stmts.push(ecx.stmt_expr(black_box2));
154+
body.stmts.push(ecx.stmt_expr(loop_expr));
147155
body
148156
}
149157

@@ -153,16 +161,17 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
153161
// zero-initialized by Enzyme). Active arguments are not handled yet.
154162
// Each argument of the primal function (and the return type if existing) must be annotated with an
155163
// activity.
156-
fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, decl: &ast::FnDecl, x: &AutoDiffAttrs, _span: Span, _sig_span: Span)
157-
-> (ast::FnDecl, Vec<String>, Vec<String>) {
164+
fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span: Span, _sig_span: Span)
165+
-> (ast::FnSig, Vec<String>, Vec<String>) {
166+
let decl: P<ast::FnDecl> = sig.decl.clone();
158167
assert!(decl.inputs.len() == x.input_activity.len());
159168
assert!(decl.output.has_ret() == x.has_ret_activity());
160169
let mut d_decl = decl.clone();
161170
let mut d_inputs = Vec::new();
162171
let mut new_inputs = Vec::new();
163172
let mut old_names = Vec::new();
164173
for (arg, activity) in decl.inputs.iter().zip(x.input_activity.iter()) {
165-
dbg!(&arg);
174+
//dbg!(&arg);
166175
d_inputs.push(arg.clone());
167176
match activity {
168177
DiffActivity::Duplicated => {
@@ -200,5 +209,10 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, decl: &ast::FnDecl, x: &AutoDiffAttrs, _s
200209
}
201210
}
202211
d_decl.inputs = d_inputs.into();
203-
(d_decl, old_names, new_inputs)
212+
let d_sig = FnSig {
213+
header: sig.header.clone(),
214+
decl: d_decl,
215+
span,
216+
};
217+
(d_sig, old_names, new_inputs)
204218
}

compiler/rustc_expand/src/build.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,9 @@ impl<'a> ExtCtxt<'a> {
405405
pub fn expr_tuple(&self, sp: Span, exprs: ThinVec<P<ast::Expr>>) -> P<ast::Expr> {
406406
self.expr(sp, ast::ExprKind::Tup(exprs))
407407
}
408+
pub fn expr_loop(&self, sp: Span, block: P<ast::Block>) -> P<ast::Expr> {
409+
self.expr(sp, ast::ExprKind::Loop(block, None, sp))
410+
}
408411

409412
pub fn expr_fail(&self, span: Span, msg: Symbol) -> P<ast::Expr> {
410413
self.expr_call_global(

0 commit comments

Comments
 (0)