Skip to content

Commit 6700542

Browse files
committed
add minimal ast based macro parser
1 parent b394066 commit 6700542

File tree

4 files changed

+231
-57
lines changed

4 files changed

+231
-57
lines changed

compiler/rustc_ast/src/ast.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2574,6 +2574,12 @@ impl FnRetTy {
25742574
FnRetTy::Ty(ty) => ty.span,
25752575
}
25762576
}
2577+
pub fn has_ret(&self) -> bool {
2578+
match self {
2579+
FnRetTy::Default(_) => false,
2580+
FnRetTy::Ty(_) => true,
2581+
}
2582+
}
25772583
}
25782584

25792585
#[derive(Clone, Copy, PartialEq, Encodable, Decodable, Debug)]

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
use super::typetree::TypeTree;
2-
use std::str::FromStr;
31
use rustc_data_structures::stable_hasher::{HashStable, StableHasher};//, StableOrd};
42
use crate::HashStableContext;
3+
use crate::expand::typetree::TypeTree;
4+
use thin_vec::ThinVec;
5+
//use rustc_expand::base::{Annotatable, ExtCtxt};
6+
use std::str::FromStr;
7+
8+
use crate::NestedMetaItem;
59

610
#[allow(dead_code)]
711
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)]
@@ -52,12 +56,22 @@ impl<CTX: HashStableContext> HashStable<CTX> for DiffActivity {
5256
}
5357
}
5458

59+
impl FromStr for DiffMode {
60+
type Err = ();
5561

62+
fn from_str(s: &str) -> Result<DiffMode, ()> { match s {
63+
"Inactive" => Ok(DiffMode::Inactive),
64+
"Source" => Ok(DiffMode::Source),
65+
"Forward" => Ok(DiffMode::Forward),
66+
"Reverse" => Ok(DiffMode::Reverse),
67+
_ => Err(()),
68+
}
69+
}
70+
}
5671
impl FromStr for DiffActivity {
5772
type Err = ();
5873

59-
fn from_str(s: &str) -> Result<DiffActivity, ()> {
60-
match s {
74+
fn from_str(s: &str) -> Result<DiffActivity, ()> { match s {
6175
"None" => Ok(DiffActivity::None),
6276
"Active" => Ok(DiffActivity::Active),
6377
"Const" => Ok(DiffActivity::Const),
@@ -76,6 +90,43 @@ pub struct AutoDiffAttrs {
7690
pub input_activity: Vec<DiffActivity>,
7791
}
7892

93+
fn name(x: &NestedMetaItem) -> String {
94+
let segments = &x.meta_item().unwrap().path.segments;
95+
assert!(segments.len() == 1);
96+
segments[0].ident.name.to_string()
97+
}
98+
99+
impl AutoDiffAttrs{
100+
pub fn has_ret_activity(&self) -> bool {
101+
match self.ret_activity {
102+
DiffActivity::None => false,
103+
_ => true,
104+
}
105+
}
106+
pub fn from_ast(meta_item: &ThinVec<NestedMetaItem>, has_ret: bool) -> Self {
107+
let mode = name(&meta_item[1]);
108+
let mode = DiffMode::from_str(&mode).unwrap();
109+
let activities: Vec<DiffActivity> = meta_item[2..].iter().map(|x| {
110+
let activity_str = name(&x);
111+
DiffActivity::from_str(&activity_str).unwrap()
112+
}).collect();
113+
114+
// If a return type exist, we need to split the last activity,
115+
// otherwise we return None as placeholder.
116+
let (ret_activity, input_activity) = if has_ret {
117+
activities.split_last().unwrap()
118+
} else {
119+
(&DiffActivity::None, activities.as_slice())
120+
};
121+
122+
AutoDiffAttrs {
123+
mode,
124+
ret_activity: *ret_activity,
125+
input_activity: input_activity.to_vec(),
126+
}
127+
}
128+
}
129+
79130
impl<CTX: HashStableContext> HashStable<CTX> for AutoDiffAttrs {
80131
fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
81132
self.mode.hash_stable(hcx, hasher);

compiler/rustc_ast/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ pub mod ptr;
5050
pub mod token;
5151
pub mod tokenstream;
5252
pub mod visit;
53+
//pub mod autodiff_attrs;
5354

5455
pub use self::ast::*;
5556
pub use self::ast_traits::{AstDeref, AstNodeWrapper, HasAttrs, HasNodeId, HasSpan, HasTokens};
Lines changed: 169 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1-
#![allow(unused)]
2-
3-
use crate::errors;
1+
#![allow(unused_imports)]
42
//use crate::util::check_builtin_macro_attribute;
53
//use crate::util::check_autodiff;
64

5+
use std::string::String;
6+
use crate::errors;
77
use rustc_ast::ptr::P;
8-
use rustc_ast::{self as ast, FnHeader, FnSig, Generics, StmtKind};
9-
use rustc_ast::{Fn, ItemKind, Stmt, TyKind, Unsafe};
8+
use rustc_ast::{BindingAnnotation, ByRef};
9+
use rustc_ast::{self as ast, FnHeader, FnSig, Generics, StmtKind, NestedMetaItem, MetaItemKind};
10+
use rustc_ast::{Fn, ItemKind, Stmt, TyKind, Unsafe, PatKind};
1011
use rustc_expand::base::{Annotatable, ExtCtxt};
1112
use rustc_span::symbol::{kw, sym, Ident};
1213
use rustc_span::Span;
1314
use thin_vec::{thin_vec, ThinVec};
1415
use rustc_span::Symbol;
16+
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
1517

1618
pub fn expand(
1719
ecx: &mut ExtCtxt<'_>,
@@ -20,69 +22,183 @@ pub fn expand(
2022
item: Annotatable,
2123
) -> Vec<Annotatable> {
2224
//check_builtin_macro_attribute(ecx, meta_item, sym::alloc_error_handler);
23-
//check_builtin_macro_attribute(ecx, meta_item, sym::autodiff);
2425

25-
dbg!(&meta_item);
26+
let meta_item_vec: ThinVec<NestedMetaItem> = match meta_item.kind {
27+
ast::MetaItemKind::List(ref vec) => vec.clone(),
28+
_ => {
29+
ecx.sess
30+
.dcx()
31+
.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
32+
return vec![item];
33+
}
34+
};
2635
let input = item.clone();
2736
let orig_item: P<ast::Item> = item.clone().expect_item();
2837
let mut d_item: P<ast::Item> = item.clone().expect_item();
38+
let primal = orig_item.ident.clone();
2939

30-
// Allow using `#[autodiff(...)]` on a Fn
31-
let (fn_item, _ty_span) = if let Annotatable::Item(item) = &item
40+
// Allow using `#[autodiff(...)]` only on a Fn
41+
let (fn_item, has_ret, sig, sig_span) = if let Annotatable::Item(item) = &item
3242
&& let ItemKind::Fn(box ast::Fn { sig, .. }) = &item.kind
3343
{
34-
dbg!(&item);
35-
(item, ecx.with_def_site_ctxt(sig.span))
44+
(item, sig.decl.output.has_ret(), sig, ecx.with_def_site_ctxt(sig.span))
3645
} else {
3746
ecx.sess
3847
.dcx()
3948
.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
4049
return vec![input];
4150
};
42-
let _x: &ItemKind = &fn_item.kind;
43-
d_item.ident.name =
44-
Symbol::intern(format!("d_{}", fn_item.ident.name).as_str());
51+
let x: AutoDiffAttrs = AutoDiffAttrs::from_ast(&meta_item_vec, has_ret);
52+
dbg!(&x);
53+
let span = ecx.with_def_site_ctxt(fn_item.span);
54+
55+
let (d_decl, old_names, new_args) = gen_enzyme_decl(ecx, &sig.decl, &x, span, sig_span);
56+
let d_body = gen_enzyme_body(ecx, primal, &old_names, &new_args, span, sig_span);
57+
let meta_item_name = meta_item_vec[0].meta_item().unwrap();
58+
d_item.ident = meta_item_name.path.segments[0].ident;
59+
// update d_item
60+
if let ItemKind::Fn(box ast::Fn { sig, body, .. }) = &mut d_item.kind {
61+
*sig.decl = d_decl;
62+
*body = Some(d_body);
63+
} else {
64+
ecx.sess
65+
.dcx()
66+
.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
67+
return vec![input];
68+
}
69+
4570
let orig_annotatable = Annotatable::Item(orig_item.clone());
4671
let d_annotatable = Annotatable::Item(d_item.clone());
4772
return vec![orig_annotatable, d_annotatable];
4873
}
4974

50-
// #[rustc_std_internal_symbol]
51-
// unsafe fn __rg_oom(size: usize, align: usize) -> ! {
52-
// handler(core::alloc::Layout::from_size_align_unchecked(size, align))
53-
// }
54-
//fn generate_handler(cx: &ExtCtxt<'_>, handler: Ident, span: Span, sig_span: Span) -> Stmt {
55-
// let usize = cx.path_ident(span, Ident::new(sym::usize, span));
56-
// let ty_usize = cx.ty_path(usize);
57-
// let size = Ident::from_str_and_span("size", span);
58-
// let align = Ident::from_str_and_span("align", span);
59-
//
60-
// let layout_new = cx.std_path(&[sym::alloc, sym::Layout, sym::from_size_align_unchecked]);
61-
// let layout_new = cx.expr_path(cx.path(span, layout_new));
62-
// let layout = cx.expr_call(
63-
// span,
64-
// layout_new,
65-
// thin_vec![cx.expr_ident(span, size), cx.expr_ident(span, align)],
66-
// );
67-
//
68-
// let call = cx.expr_call_ident(sig_span, handler, thin_vec![layout]);
69-
//
70-
// let never = ast::FnRetTy::Ty(cx.ty(span, TyKind::Never));
71-
// let params = thin_vec![cx.param(span, size, ty_usize.clone()), cx.param(span, align, ty_usize)];
72-
// let decl = cx.fn_decl(params, never);
73-
// let header = FnHeader { unsafety: Unsafe::Yes(span), ..FnHeader::default() };
74-
// let sig = FnSig { decl, header, span: span };
75-
//
76-
// let body = Some(cx.block_expr(call));
77-
// let kind = ItemKind::Fn(Box::new(Fn {
78-
// defaultness: ast::Defaultness::Final,
79-
// sig,
80-
// generics: Generics::default(),
81-
// body,
82-
// }));
83-
//
84-
// let attrs = thin_vec![cx.attr_word(sym::rustc_std_internal_symbol, span)];
85-
//
86-
// let item = cx.item(span, Ident::from_str_and_span("__rg_oom", span), attrs, kind);
87-
// cx.stmt_item(sig_span, item)
88-
//}
75+
// shadow arguments must be mutable references or ptrs, because Enzyme will write into them.
76+
fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
77+
let mut ty = ty.clone();
78+
match ty.kind {
79+
TyKind::Ptr(ref mut mut_ty) => {
80+
mut_ty.mutbl = ast::Mutability::Mut;
81+
}
82+
TyKind::Ref(_, ref mut mut_ty) => {
83+
mut_ty.mutbl = ast::Mutability::Mut;
84+
}
85+
_ => {
86+
panic!("unsupported type: {:?}", ty);
87+
}
88+
}
89+
ty
90+
}
91+
92+
93+
// The body of our generated functions will consist of three black_Box calls.
94+
// The first will call the primal function with the original arguments.
95+
// The second will just take the shadow arguments.
96+
// The third will (unsafely) call std::mem::zeroed(), to match the return type of the new function
97+
// (whatever that might be). This way we surpress rustc from optimizing anyt argument away.
98+
fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_names: &[String], span: Span, sig_span: Span) -> P<ast::Block> {
99+
let blackbox_path = ecx.std_path(&[Symbol::intern("hint"), Symbol::intern("black_box")]);
100+
let zeroed_path = ecx.std_path(&[Symbol::intern("mem"), Symbol::intern("zeroed")]);
101+
102+
let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
103+
let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
104+
let zeroed_call_expr = ecx.expr_path(ecx.path(span, zeroed_path));
105+
106+
let mem_zeroed_call: Stmt = ecx.stmt_expr(ecx.expr_call(
107+
span,
108+
zeroed_call_expr.clone(),
109+
thin_vec![],
110+
));
111+
let unsafe_block_with_zeroed_call: P<ast::Expr> = ecx.expr_block(P(ast::Block {
112+
stmts: thin_vec![mem_zeroed_call],
113+
id: ast::DUMMY_NODE_ID,
114+
rules: ast::BlockCheckMode::Unsafe(ast::UserProvided),
115+
span: sig_span,
116+
tokens: None,
117+
could_be_bare_literal: false,
118+
}));
119+
// create ::core::hint::black_box(array(arr));
120+
let _primal_call = ecx.expr_call(
121+
span,
122+
primal_call_expr.clone(),
123+
old_names.iter().map(|name| {
124+
ecx.expr_path(ecx.path_ident(span, Ident::from_str(name)))
125+
}).collect(),
126+
);
127+
128+
// create ::core::hint::black_box(grad_arr, tang_y));
129+
let black_box1 = ecx.expr_call(
130+
sig_span,
131+
blackbox_call_expr.clone(),
132+
new_names.iter().map(|arg| {
133+
ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg)))
134+
}).collect(),
135+
);
136+
137+
// create ::core::hint::black_box(unsafe { ::core::mem::zeroed() })
138+
let black_box2 = ecx.expr_call(
139+
sig_span,
140+
blackbox_call_expr.clone(),
141+
thin_vec![unsafe_block_with_zeroed_call.clone()],
142+
);
143+
144+
let mut body = ecx.block(span, ThinVec::new());
145+
body.stmts.push(ecx.stmt_expr(black_box1));
146+
body.stmts.push(ecx.stmt_expr(black_box2));
147+
body
148+
}
149+
150+
// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
151+
// be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
152+
// Active arguments must be scalars. Their shadow argument is added to the return type (and will be
153+
// zero-initialized by Enzyme). Active arguments are not handled yet.
154+
// Each argument of the primal function (and the return type if existing) must be annotated with an
155+
// activity.
156+
fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, decl: &ast::FnDecl, x: &AutoDiffAttrs, _span: Span, _sig_span: Span)
157+
-> (ast::FnDecl, Vec<String>, Vec<String>) {
158+
assert!(decl.inputs.len() == x.input_activity.len());
159+
assert!(decl.output.has_ret() == x.has_ret_activity());
160+
let mut d_decl = decl.clone();
161+
let mut d_inputs = Vec::new();
162+
let mut new_inputs = Vec::new();
163+
let mut old_names = Vec::new();
164+
for (arg, activity) in decl.inputs.iter().zip(x.input_activity.iter()) {
165+
dbg!(&arg);
166+
d_inputs.push(arg.clone());
167+
match activity {
168+
DiffActivity::Duplicated => {
169+
let mut shadow_arg = arg.clone();
170+
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
171+
// adjust name depending on mode
172+
let old_name = if let PatKind::Ident(_, ident, _) = shadow_arg.pat.kind {
173+
ident.name
174+
} else {
175+
dbg!(&shadow_arg.pat);
176+
panic!("not an ident?");
177+
};
178+
old_names.push(old_name.to_string());
179+
let name: String = match x.mode {
180+
DiffMode::Reverse => format!("d{}", old_name),
181+
DiffMode::Forward => format!("b{}", old_name),
182+
_ => panic!("unsupported mode: {}", old_name),
183+
};
184+
dbg!(&name);
185+
new_inputs.push(name.clone());
186+
shadow_arg.pat = P(ast::Pat {
187+
// TODO: Check id
188+
id: ast::DUMMY_NODE_ID,
189+
kind: PatKind::Ident(BindingAnnotation::NONE,
190+
Ident::from_str_and_span(&name, shadow_arg.pat.span),
191+
None,
192+
),
193+
span: shadow_arg.pat.span,
194+
tokens: shadow_arg.pat.tokens.clone(),
195+
});
196+
197+
d_inputs.push(shadow_arg);
198+
}
199+
_ => {},
200+
}
201+
}
202+
d_decl.inputs = d_inputs.into();
203+
(d_decl, old_names, new_inputs)
204+
}

0 commit comments

Comments
 (0)