Skip to content

Commit 948848d

Browse files
authored
support self ty (#110)
* finish method / trait support
1 parent af4d766 commit 948848d

File tree

2 files changed

+106
-32
lines changed

2 files changed

+106
-32
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 102 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ use std::string::String;
1919
use thin_vec::{thin_vec, ThinVec};
2020
use std::str::FromStr;
2121

22+
use rustc_ast::AssocItemKind;
23+
2224
#[cfg(not(llvm_enzyme))]
2325
pub fn expand(
2426
ecx: &mut ExtCtxt<'_>,
@@ -82,30 +84,62 @@ pub fn expand(
8284
ecx: &mut ExtCtxt<'_>,
8385
expand_span: Span,
8486
meta_item: &ast::MetaItem,
85-
item: Annotatable,
87+
mut item: Annotatable,
8688
) -> Vec<Annotatable> {
8789
//check_builtin_macro_attribute(ecx, meta_item, sym::alloc_error_handler);
8890

91+
// first get the annotable item:
92+
let (sig, is_impl): (FnSig, bool) = match &item {
93+
Annotatable::Item(ref iitem) => {
94+
let sig = match &iitem.kind {
95+
ItemKind::Fn(box ast::Fn { sig, .. }) => sig,
96+
_ => {
97+
ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
98+
return vec![item];
99+
}
100+
};
101+
(sig.clone(), false)
102+
},
103+
Annotatable::ImplItem(ref assoc_item) => {
104+
let sig = match &assoc_item.kind {
105+
ast::AssocItemKind::Fn(box ast::Fn { sig, .. }) => sig,
106+
_ => {
107+
ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
108+
return vec![item];
109+
}
110+
};
111+
(sig.clone(), true)
112+
},
113+
_ => {
114+
dbg!(&item);
115+
ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
116+
return vec![item];
117+
}
118+
};
119+
89120
let meta_item_vec: ThinVec<NestedMetaItem> = match meta_item.kind {
90121
ast::MetaItemKind::List(ref vec) => vec.clone(),
91122
_ => {
92123
ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
93124
return vec![item];
94125
}
95126
};
96-
// Allow using `#[autodiff(...)]` only on a Fn
97-
let (has_ret, sig, sig_span) = if let Annotatable::Item(item) = &item
98-
&& let ItemKind::Fn(box ast::Fn { sig, .. }) = &item.kind
99-
{
100-
(sig.decl.output.has_ret(), sig, ecx.with_call_site_ctxt(sig.span))
101-
} else {
102-
ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
103-
return vec![item];
104-
};
105127

106-
// Now we know that item is a Item::Fn
107-
let mut orig_item: P<ast::Item> = item.clone().expect_item();
108-
let primal = orig_item.ident.clone();
128+
let has_ret = sig.decl.output.has_ret();
129+
let sig_span = ecx.with_call_site_ctxt(sig.span);
130+
131+
let (vis, primal) = match &item {
132+
Annotatable::Item(ref iitem) => {
133+
(iitem.vis.clone(), iitem.ident.clone())
134+
},
135+
Annotatable::ImplItem(ref assoc_item) => {
136+
(assoc_item.vis.clone(), assoc_item.ident.clone())
137+
},
138+
_ => {
139+
ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
140+
return vec![item];
141+
}
142+
};
109143

110144
// create TokenStream from vec elemtents:
111145
// meta_item doesn't have a .tokens field
@@ -154,12 +188,12 @@ pub fn expand(
154188
let d_ident = first_ident(&meta_item_vec[0]);
155189

156190
// The first element of it is the name of the function to be generated
157-
let asdf = ItemKind::Fn(Box::new(ast::Fn {
191+
let asdf = Box::new(ast::Fn {
158192
defaultness: ast::Defaultness::Final,
159193
sig: d_sig,
160194
generics: Generics::default(),
161195
body: Some(d_body),
162-
}));
196+
});
163197
let mut rustc_ad_attr =
164198
P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
165199
let ts2: Vec<TokenTree> = vec![
@@ -195,13 +229,6 @@ pub fn expand(
195229
style: ast::AttrStyle::Outer,
196230
span,
197231
};
198-
// don't add it multiple times:
199-
if !orig_item.attrs.iter().any(|a| a.id == attr.id) {
200-
orig_item.attrs.push(attr.clone());
201-
}
202-
if !orig_item.attrs.iter().any(|a| a.id == inline_never.id) {
203-
orig_item.attrs.push(inline_never);
204-
}
205232

206233
// Now update for d_fn
207234
rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
@@ -210,13 +237,51 @@ pub fn expand(
210237
tokens: ts,
211238
});
212239
attr.kind = ast::AttrKind::Normal(rustc_ad_attr);
213-
let mut d_fn = ecx.item(span, d_ident, thin_vec![attr], asdf);
214240

215-
// Copy visibility from original function
216-
d_fn.vis = orig_item.vis.clone();
241+
// Don't add it multiple times:
242+
let orig_annotatable: Annotatable = match item {
243+
Annotatable::Item(ref mut iitem) => {
244+
if !iitem.attrs.iter().any(|a| a.id == attr.id) {
245+
iitem.attrs.push(attr.clone());
246+
}
247+
if !iitem.attrs.iter().any(|a| a.id == inline_never.id) {
248+
iitem.attrs.push(inline_never.clone());
249+
}
250+
Annotatable::Item(iitem.clone())
251+
},
252+
Annotatable::ImplItem(ref mut assoc_item) => {
253+
if !assoc_item.attrs.iter().any(|a| a.id == attr.id) {
254+
assoc_item.attrs.push(attr.clone());
255+
}
256+
if !assoc_item.attrs.iter().any(|a| a.id == inline_never.id) {
257+
assoc_item.attrs.push(inline_never.clone());
258+
}
259+
Annotatable::ImplItem(assoc_item.clone())
260+
},
261+
_ => {
262+
panic!("not supported");
263+
}
264+
};
265+
266+
let d_annotatable = if is_impl {
267+
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
268+
let d_fn = P(ast::AssocItem {
269+
attrs: thin_vec![attr.clone(), inline_never],
270+
id: ast::DUMMY_NODE_ID,
271+
span,
272+
vis,
273+
ident: d_ident,
274+
kind: assoc_item,
275+
tokens: None,
276+
});
277+
Annotatable::ImplItem(d_fn)
278+
} else {
279+
let mut d_fn = ecx.item(span, d_ident, thin_vec![attr.clone()], ItemKind::Fn(asdf));
280+
d_fn.vis = vis;
281+
Annotatable::Item(d_fn)
282+
};
283+
trace!("Generated function: {:?}", d_annotatable);
217284

218-
let orig_annotatable = Annotatable::Item(orig_item);
219-
let d_annotatable = Annotatable::Item(d_fn);
220285
return vec![orig_annotatable, d_annotatable];
221286
}
222287

@@ -403,10 +468,16 @@ fn gen_primal_call(
403468
primal: Ident,
404469
idents: Vec<Ident>,
405470
) -> P<ast::Expr> {
406-
let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
407-
let args = idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
408-
let primal_call = ecx.expr_call(span, primal_call_expr, args);
409-
primal_call
471+
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
472+
if has_self {
473+
let args: ThinVec<_> = idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
474+
let self_expr = ecx.expr_self(span);
475+
ecx.expr_method_call(span, self_expr, primal, args.clone())
476+
} else {
477+
let args: ThinVec<_> = idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
478+
let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
479+
ecx.expr_call(span, primal_call_expr, args)
480+
}
410481
}
411482

412483
// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
@@ -427,7 +498,6 @@ fn gen_enzyme_decl(
427498
let mut d_decl = sig.decl.clone();
428499
let mut d_inputs = Vec::new();
429500
let mut new_inputs = Vec::new();
430-
//let mut old_names = Vec::new();
431501
let mut idents = Vec::new();
432502
let mut act_ret = ThinVec::new();
433503
for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {

compiler/rustc_expand/src/build.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,10 @@ impl<'a> ExtCtxt<'a> {
287287
self.expr(sp, ast::ExprKind::Paren(e))
288288
}
289289

290+
pub fn expr_method_call(&self, span: Span, expr: P<ast::Expr>, ident: Ident, args: ThinVec<P<ast::Expr>>) -> P<ast::Expr> {
291+
let seg = ast::PathSegment::from_ident(ident);
292+
self.expr(span, ast::ExprKind::MethodCall(Box::new(ast::MethodCall { seg, receiver: expr, args, span })))
293+
}
290294
pub fn expr_call(
291295
&self,
292296
span: Span,

0 commit comments

Comments
 (0)