Skip to content

Commit e3ba2d2

Browse files
committed
I will clean this up, I promise!
1 parent 001bc8a commit e3ba2d2

File tree

8 files changed

+92
-28
lines changed

8 files changed

+92
-28
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,13 @@ pub struct AutoDiffAttrs {
9090
pub input_activity: Vec<DiffActivity>,
9191
}
9292

93-
fn name(x: &NestedMetaItem) -> String {
93+
fn first_ident(x: &NestedMetaItem) -> rustc_span::symbol::Ident {
9494
let segments = &x.meta_item().unwrap().path.segments;
9595
assert!(segments.len() == 1);
96-
segments[0].ident.name.to_string()
96+
segments[0].ident
97+
}
98+
fn name(x: &NestedMetaItem) -> String {
99+
first_ident(x).name.to_string()
97100
}
98101

99102
impl AutoDiffAttrs{
@@ -143,6 +146,13 @@ impl AutoDiffAttrs {
143146
input_activity: Vec::new(),
144147
}
145148
}
149+
pub fn source() -> Self {
150+
AutoDiffAttrs {
151+
mode: DiffMode::Source,
152+
ret_activity: DiffActivity::None,
153+
input_activity: Vec::new(),
154+
}
155+
}
146156

147157
pub fn is_active(&self) -> bool {
148158
match self.mode {

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#![allow(unused_imports)]
22
#![allow(unused_variables)]
3+
#![allow(unused_mut)]
34
//use crate::util::check_builtin_macro_attribute;
45
//use crate::util::check_autodiff;
56

@@ -9,12 +10,20 @@ use rustc_ast::ptr::P;
910
use rustc_ast::{BindingAnnotation, ByRef};
1011
use rustc_ast::{self as ast, FnHeader, FnSig, Generics, StmtKind, NestedMetaItem, MetaItemKind};
1112
use rustc_ast::{Fn, ItemKind, Stmt, TyKind, Unsafe, PatKind};
13+
use rustc_ast::tokenstream::*;
1214
use rustc_expand::base::{Annotatable, ExtCtxt};
1315
use rustc_span::symbol::{kw, sym, Ident};
1416
use rustc_span::Span;
1517
use thin_vec::{thin_vec, ThinVec};
1618
use rustc_span::Symbol;
1719
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
20+
use rustc_ast::token::{Token, TokenKind};
21+
22+
fn first_ident(x: &NestedMetaItem) -> rustc_span::symbol::Ident {
23+
let segments = &x.meta_item().unwrap().path.segments;
24+
assert!(segments.len() == 1);
25+
segments[0].ident
26+
}
1827

1928
pub fn expand(
2029
ecx: &mut ExtCtxt<'_>,
@@ -33,7 +42,8 @@ pub fn expand(
3342
return vec![item];
3443
}
3544
};
36-
let orig_item: P<ast::Item> = item.clone().expect_item();
45+
let mut orig_item: P<ast::Item> = item.clone().expect_item();
46+
//dbg!(&orig_item.tokens);
3747
let primal = orig_item.ident.clone();
3848

3949
// Allow using `#[autodiff(...)]` only on a Fn
@@ -47,6 +57,25 @@ pub fn expand(
4757
.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
4858
return vec![item];
4959
};
60+
// create TokenStream from vec elemtents:
61+
// meta_item doesn't have a .tokens field
62+
let ts: Vec<Token> = meta_item_vec.clone()[1..].iter().map(|x| {
63+
let val = first_ident(x);
64+
let t = Token::from_ast_ident(val);
65+
t
66+
}).collect();
67+
let comma: Token = Token::new(TokenKind::Comma, Span::default());
68+
let mut ts: Vec<TokenTree> = vec![];
69+
for t in meta_item_vec.clone()[1..].iter() {
70+
let val = first_ident(t);
71+
let t = Token::from_ast_ident(val);
72+
ts.push(TokenTree::Token(t, Spacing::Joint));
73+
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
74+
}
75+
dbg!(&ts);
76+
let ts: TokenStream = TokenStream::from_iter(ts);
77+
dbg!(&ts);
78+
5079
let x: AutoDiffAttrs = AutoDiffAttrs::from_ast(&meta_item_vec, has_ret);
5180
dbg!(&x);
5281
//let span = ecx.with_def_site_ctxt(sig_span);
@@ -66,7 +95,29 @@ pub fn expand(
6695
generics: Generics::default(),
6796
body: Some(d_body),
6897
}));
69-
let d_fn = ecx.item(span, d_ident, rustc_ast::AttrVec::default(), asdf);
98+
let mut tmp = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::autodiff_into)));
99+
let mut attr: ast::Attribute = ast::Attribute {
100+
kind: ast::AttrKind::Normal(tmp.clone()),
101+
id: ast::AttrId::from_u32(0),
102+
style: ast::AttrStyle::Outer,
103+
span: span,
104+
};
105+
orig_item.attrs.push(attr);
106+
107+
// Now update for d_fn
108+
tmp.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
109+
dspan: DelimSpan::dummy(),
110+
delim: rustc_ast::token::Delimiter::Parenthesis,
111+
tokens: ts,
112+
});
113+
let mut attr2: ast::Attribute = ast::Attribute {
114+
kind: ast::AttrKind::Normal(tmp),
115+
id: ast::AttrId::from_u32(0),
116+
style: ast::AttrStyle::Outer,
117+
span: span,
118+
};
119+
let attr_vec: rustc_ast::AttrVec = thin_vec![attr2];
120+
let d_fn = ecx.item(span, d_ident, attr_vec, asdf);
70121

71122
let orig_annotatable = Annotatable::Item(orig_item.clone());
72123
let d_annotatable = Annotatable::Item(d_fn);

compiler/rustc_codegen_llvm/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ impl WriteBackendMethods for LlvmCodegenBackend {
269269
config: &ModuleConfig,
270270
) -> Result<(), FatalError> {
271271
dbg!("cg_llvm autodiff");
272+
dbg!("Differentiating {} functions", diff_fncs.len());
272273
unsafe { back::write::differentiate(module, cgcx, diff_fncs, typetrees, config) }
273274
}
274275

compiler/rustc_codegen_ssa/src/back/write.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,7 @@ fn generate_lto_work<B: ExtraBackendMethods>(
382382
import_only_modules: Vec<(SerializedModule<B::ModuleBuffer>, WorkProduct)>,
383383
) -> Vec<(WorkItem<B>, u64)> {
384384
let _prof_timer = cgcx.prof.generic_activity("codegen_generate_lto_work");
385+
dbg!("Differentiating {} functions", autodiff.len());
385386

386387
if !needs_fat_lto.is_empty() {
387388
assert!(needs_thin_lto.is_empty());

compiler/rustc_codegen_ssa/src/codegen_attrs.rs

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -693,16 +693,15 @@ fn check_link_name_xor_ordinal(
693693
}
694694

695695
fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
696-
//let attrs = tcx.get_attrs(id, sym::autodiff_into);
697-
let attrs = tcx.get_attrs(id, sym::autodiff);
696+
let attrs = tcx.get_attrs(id, sym::autodiff_into);
698697

699698
let attrs = attrs
700699
.into_iter()
701-
.filter(|attr| attr.name_or_empty() == sym::autodiff)
702-
//.filter(|attr| attr.name_or_empty() == sym::autodiff_into)
700+
.filter(|attr| attr.name_or_empty() == sym::autodiff_into)
703701
.collect::<Vec<_>>();
704-
if attrs.len() > 0 {
705-
dbg!("autodiff_attrs len = > 0: {}", attrs.len());
702+
703+
if !attrs.is_empty() {
704+
dbg!("autodiff_attrs amount = {}", attrs.len());
706705
}
707706

708707
// check for exactly one autodiff attribute on extern block
@@ -723,18 +722,12 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
723722
let list = attr.meta_item_list().unwrap_or_default();
724723

725724
// empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions
726-
if list.len() == 0 {
727-
return AutoDiffAttrs {
728-
mode: DiffMode::Source,
729-
ret_activity: DiffActivity::None,
730-
input_activity: Vec::new(),
731-
};
732-
}
725+
if list.len() == 0 { return AutoDiffAttrs::source(); }
733726

734727
let msg_ad_mode = "autodiff attribute must contain autodiff mode";
735-
let mode = match &list[0] {
736-
NestedMetaItem::MetaItem(MetaItem { path: ref p2, kind: MetaItemKind::Word, .. }) => {
737-
p2.segments.first().unwrap().ident
728+
let (mode, list) = match list.split_first() {
729+
Some((NestedMetaItem::MetaItem(MetaItem { path: ref p1, kind: MetaItemKind::Word, .. }), list)) => {
730+
(p1.segments.first().unwrap().ident, list)
738731
}
739732
_ => {
740733
tcx.sess
@@ -749,7 +742,6 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
749742
// parse mode
750743
let msg_mode = "mode should be either forward or reverse";
751744
let mode = match mode.as_str() {
752-
//map(|x| x.as_str()) {
753745
"Forward" => DiffMode::Forward,
754746
"Reverse" => DiffMode::Reverse,
755747
_ => {
@@ -763,9 +755,9 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
763755
};
764756

765757
let msg_ret_activity = "autodiff attribute must contain the return activity";
766-
let ret_symbol = match &list[1] {
767-
NestedMetaItem::MetaItem(MetaItem { path: ref p2, kind: MetaItemKind::Word, .. }) => {
768-
p2.segments.first().unwrap().ident
758+
let (ret_symbol, list) = match list.split_last() {
759+
Some((NestedMetaItem::MetaItem(MetaItem { path: ref p1, kind: MetaItemKind::Word, .. }), list)) => {
760+
(p1.segments.first().unwrap().ident, list)
769761
}
770762
_ => {
771763
tcx.sess
@@ -792,7 +784,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
792784

793785
let msg_arg_activity = "autodiff attribute must contain the return activity";
794786
let mut arg_activities: Vec<DiffActivity> = vec![];
795-
for arg in &list[2..] {
787+
for arg in list {
796788
let arg_symbol = match arg {
797789
NestedMetaItem::MetaItem(MetaItem {
798790
path: ref p2, kind: MetaItemKind::Word, ..
@@ -846,6 +838,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
846838
if ret_activity == DiffActivity::Duplicated
847839
|| ret_activity == DiffActivity::DuplicatedNoNeed
848840
{
841+
dbg!("ret_activity = {:?}", ret_activity);
849842
tcx.sess
850843
.struct_span_err(
851844
attr.span, msg_rev_incompatible_arg,

compiler/rustc_feature/src/builtin_attrs.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,13 @@ pub const BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[
361361
large_assignments, experimental!(move_size_limit)
362362
),
363363

364+
// Autodiff
365+
ungated!(
366+
autodiff_into, Normal,
367+
template!(Word, List: r#""...""#),
368+
DuplicatesOk,
369+
),
370+
364371
// Entry point:
365372
gated!(unix_sigpipe, Normal, template!(Word, NameValueStr: "inherit|sig_ign|sig_dfl"), ErrorFollowing, experimental!(unix_sigpipe)),
366373
ungated!(start, Normal, template!(Word), WarnFollowing),

compiler/rustc_monomorphize/src/partitioning.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,7 +1184,7 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Au
11841184
println!("item: {:?}", item);
11851185
let source = usage_map.used_map.get(&item).unwrap()
11861186
.into_iter()
1187-
.filter_map(|item| match *item {
1187+
.find_map(|item| match *item {
11881188
MonoItem::Fn(ref instance_s) => {
11891189
let source_id = instance_s.def_id();
11901190
println!("source_id_inner: {:?}", source_id);
@@ -1205,8 +1205,8 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Au
12051205
None
12061206
}
12071207
_ => None,
1208-
})
1209-
.next();
1208+
});
1209+
//.next();
12101210
println!("source: {:?}", source);
12111211

12121212
source.map(|inst| {

compiler/rustc_span/src/symbol.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ symbols! {
440440
augmented_assignments,
441441
auto_traits,
442442
autodiff,
443+
autodiff_into,
443444
automatically_derived,
444445
avx,
445446
avx512_target_feature,

0 commit comments

Comments
 (0)