Skip to content

Commit 001bc8a

Browse files
committed
compiles and runs binaries, but without Enzyme yet
1 parent bf72f16 commit 001bc8a

File tree

2 files changed

+24
-20
lines changed

2 files changed

+24
-20
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ impl FromStr for DiffActivity {
8383
}
8484

8585
#[allow(dead_code)]
86-
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)]
86+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
8787
pub struct AutoDiffAttrs {
8888
pub mode: DiffMode,
8989
pub ret_activity: DiffActivity,
@@ -127,13 +127,13 @@ impl AutoDiffAttrs{
127127
}
128128
}
129129

130-
impl<CTX: HashStableContext> HashStable<CTX> for AutoDiffAttrs {
131-
fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
132-
self.mode.hash_stable(hcx, hasher);
133-
self.ret_activity.hash_stable(hcx, hasher);
134-
self.input_activity.hash_stable(hcx, hasher);
135-
}
136-
}
130+
//impl<CTX: HashStableContext> HashStable<CTX> for AutoDiffAttrs {
131+
// fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
132+
// self.mode.hash_stable(hcx, hasher);
133+
// self.ret_activity.hash_stable(hcx, hasher);
134+
// self.input_activity.hash_stable(hcx, hasher);
135+
// }
136+
//}
137137

138138
impl AutoDiffAttrs {
139139
pub fn inactive() -> Self {

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
1818

1919
pub fn expand(
2020
ecx: &mut ExtCtxt<'_>,
21-
_span: Span,
21+
expand_span: Span,
2222
meta_item: &ast::MetaItem,
2323
item: Annotatable,
2424
) -> Vec<Annotatable> {
@@ -40,7 +40,7 @@ pub fn expand(
4040
let (fn_item, has_ret, sig, sig_span) = if let Annotatable::Item(item) = &item
4141
&& let ItemKind::Fn(box ast::Fn { sig, .. }) = &item.kind
4242
{
43-
(item, sig.decl.output.has_ret(), sig, ecx.with_def_site_ctxt(sig.span))
43+
(item, sig.decl.output.has_ret(), sig, ecx.with_call_site_ctxt(sig.span))
4444
} else {
4545
ecx.sess
4646
.dcx()
@@ -49,10 +49,14 @@ pub fn expand(
4949
};
5050
let x: AutoDiffAttrs = AutoDiffAttrs::from_ast(&meta_item_vec, has_ret);
5151
dbg!(&x);
52-
let span = ecx.with_def_site_ctxt(fn_item.span);
52+
//let span = ecx.with_def_site_ctxt(sig_span);
53+
let span = ecx.with_def_site_ctxt(expand_span);
54+
//let span = ecx.with_def_site_ctxt(fn_item.span);
5355

5456
let (d_sig, old_names, new_args) = gen_enzyme_decl(ecx, &sig, &x, span, sig_span);
55-
let d_body = gen_enzyme_body(ecx, primal, &old_names, &new_args, span, sig_span);
57+
let new_decl_span = d_sig.span;
58+
//let d_body = gen_enzyme_body(ecx, primal, &old_names, &new_args, span, span);
59+
let d_body = gen_enzyme_body(ecx, primal, &old_names, &new_args, span, sig_span, new_decl_span);
5660
let d_ident = meta_item_vec[0].meta_item().unwrap().path.segments[0].ident;
5761

5862
// The first element of it is the name of the function to be generated
@@ -92,7 +96,7 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
9296
// The second will just take the shadow arguments.
9397
// The third will (unsafely) call std::mem::zeroed(), to match the return type of the new function
9498
// (whatever that might be). This way we surpress rustc from optimizing anyt argument away.
95-
fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_names: &[String], span: Span, sig_span: Span) -> P<ast::Block> {
99+
fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_names: &[String], span: Span, sig_span: Span, new_decl_span: Span) -> P<ast::Block> {
96100
let blackbox_path = ecx.std_path(&[Symbol::intern("hint"), Symbol::intern("black_box")]);
97101
let zeroed_path = ecx.std_path(&[Symbol::intern("mem"), Symbol::intern("zeroed")]);
98102
let empty_loop_block = ecx.block(span, ThinVec::new());
@@ -118,14 +122,14 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
118122
}));
119123
// create ::core::hint::black_box(array(arr));
120124
let primal_call = ecx.expr_call(
121-
span,
122-
primal_call_expr.clone(),
125+
new_decl_span,
126+
primal_call_expr,
123127
old_names.iter().map(|name| {
124-
ecx.expr_path(ecx.path_ident(span, Ident::from_str(name)))
128+
ecx.expr_path(ecx.path_ident(new_decl_span, Ident::from_str(name)))
125129
}).collect(),
126130
);
127131
let black_box0 = ecx.expr_call(
128-
sig_span,
132+
new_decl_span,
129133
blackbox_call_expr.clone(),
130134
thin_vec![primal_call.clone()],
131135
);
@@ -140,17 +144,17 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
140144
);
141145

142146
// create ::core::hint::black_box(unsafe { ::core::mem::zeroed() })
143-
let _black_box2 = ecx.expr_call(
147+
let black_box2 = ecx.expr_call(
144148
sig_span,
145149
blackbox_call_expr.clone(),
146150
thin_vec![unsafe_block_with_zeroed_call.clone()],
147151
);
148152

149153
let mut body = ecx.block(span, ThinVec::new());
150-
body.stmts.push(ecx.stmt_expr(primal_call));
154+
//body.stmts.push(ecx.stmt_expr(primal_call));
151155
//body.stmts.push(ecx.stmt_expr(black_box0));
152156
//body.stmts.push(ecx.stmt_expr(black_box1));
153-
//body.stmts.push(ecx.stmt_expr(black_box2));
157+
body.stmts.push(ecx.stmt_expr(black_box2));
154158
body.stmts.push(ecx.stmt_expr(loop_expr));
155159
body
156160
}

0 commit comments

Comments
 (0)