Skip to content

Commit 12292ca

Browse files
committed
got autodiff macro to do something
1 parent 79d1834 commit 12292ca

File tree

3 files changed

+67
-68
lines changed

3 files changed

+67
-68
lines changed

compiler/rustc_builtin_macros/messages.ftl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
builtin_macros_alloc_error_must_be_fn = alloc_error_handler must be a function
22
builtin_macros_alloc_must_statics = allocators must be statics
33
4+
builtin_macros_autodiff = autodiff must be applied to function
5+
46
builtin_macros_asm_clobber_abi = clobber_abi
57
builtin_macros_asm_clobber_no_reg = asm with `clobber_abi` must specify explicit registers for outputs
68
builtin_macros_asm_clobber_outputs = generic outputs
Lines changed: 58 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
#![allow(unused)]
2+
13
use crate::errors;
2-
use crate::util::check_builtin_macro_attribute;
4+
//use crate::util::check_builtin_macro_attribute;
35

46
use rustc_ast::ptr::P;
57
use rustc_ast::{self as ast, FnHeader, FnSig, Generics, StmtKind};
@@ -8,6 +10,7 @@ use rustc_expand::base::{Annotatable, ExtCtxt};
810
use rustc_span::symbol::{kw, sym, Ident};
911
use rustc_span::Span;
1012
use thin_vec::{thin_vec, ThinVec};
13+
use rustc_span::Symbol;
1114

1215
pub fn expand(
1316
ecx: &mut ExtCtxt<'_>,
@@ -18,80 +21,67 @@ pub fn expand(
1821
//check_builtin_macro_attribute(ecx, meta_item, sym::alloc_error_handler);
1922
//check_builtin_macro_attribute(ecx, meta_item, sym::autodiff);
2023

21-
let orig_item = item.clone();
24+
dbg!(&meta_item);
25+
let input = item.clone();
26+
let orig_item: P<ast::Item> = item.clone().expect_item();
27+
let mut d_item: P<ast::Item> = item.clone().expect_item();
2228

23-
// Allow using `#[alloc_error_handler]` on an item statement
24-
// FIXME - if we get deref patterns, use them to reduce duplication here
25-
let (item, is_stmt, sig_span) = if let Annotatable::Item(item) = &item
26-
&& let ItemKind::Fn(fn_kind) = &item.kind
27-
{
28-
(item, false, ecx.with_def_site_ctxt(fn_kind.sig.span))
29-
} else if let Annotatable::Stmt(stmt) = &item
30-
&& let StmtKind::Item(item) = &stmt.kind
31-
&& let ItemKind::Fn(fn_kind) = &item.kind
29+
// Allow using `#[autodiff(...)]` on a Fn
30+
let (fn_item, _ty_span) = if let Annotatable::Item(item) = &item
31+
&& let ItemKind::Fn(box ast::Fn { sig, .. }) = &item.kind
3232
{
33-
(item, true, ecx.with_def_site_ctxt(fn_kind.sig.span))
34-
} else {
35-
ecx.sess.dcx().emit_err(errors::AllocErrorMustBeFn { span: item.span() });
36-
return vec![orig_item];
37-
};
38-
39-
// Generate a bunch of new items using the AllocFnFactory
40-
let span = ecx.with_def_site_ctxt(item.span);
41-
42-
// Generate item statements for the allocator methods.
43-
let stmts = thin_vec![generate_handler(ecx, item.ident, span, sig_span)];
44-
45-
// Generate anonymous constant serving as container for the allocator methods.
46-
let const_ty = ecx.ty(sig_span, TyKind::Tup(ThinVec::new()));
47-
let const_body = ecx.expr_block(ecx.block(span, stmts));
48-
let const_item = ecx.item_const(span, Ident::new(kw::Underscore, span), const_ty, const_body);
49-
let const_item = if is_stmt {
50-
Annotatable::Stmt(P(ecx.stmt_item(span, const_item)))
33+
dbg!(&item);
34+
(item, ecx.with_def_site_ctxt(sig.span))
5135
} else {
52-
Annotatable::Item(const_item)
36+
ecx.sess
37+
.dcx()
38+
.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
39+
return vec![input];
5340
};
54-
55-
// Return the original item and the new methods.
56-
vec![orig_item, const_item]
41+
let _x: &ItemKind = &fn_item.kind;
42+
d_item.ident.name =
43+
Symbol::intern(format!("d_{}", fn_item.ident.name).as_str());
44+
let orig_annotatable = Annotatable::Item(orig_item.clone());
45+
let d_annotatable = Annotatable::Item(d_item.clone());
46+
return vec![orig_annotatable, d_annotatable];
5747
}
5848

5949
// #[rustc_std_internal_symbol]
6050
// unsafe fn __rg_oom(size: usize, align: usize) -> ! {
6151
// handler(core::alloc::Layout::from_size_align_unchecked(size, align))
6252
// }
63-
fn generate_handler(cx: &ExtCtxt<'_>, handler: Ident, span: Span, sig_span: Span) -> Stmt {
64-
let usize = cx.path_ident(span, Ident::new(sym::usize, span));
65-
let ty_usize = cx.ty_path(usize);
66-
let size = Ident::from_str_and_span("size", span);
67-
let align = Ident::from_str_and_span("align", span);
68-
69-
let layout_new = cx.std_path(&[sym::alloc, sym::Layout, sym::from_size_align_unchecked]);
70-
let layout_new = cx.expr_path(cx.path(span, layout_new));
71-
let layout = cx.expr_call(
72-
span,
73-
layout_new,
74-
thin_vec![cx.expr_ident(span, size), cx.expr_ident(span, align)],
75-
);
76-
77-
let call = cx.expr_call_ident(sig_span, handler, thin_vec![layout]);
78-
79-
let never = ast::FnRetTy::Ty(cx.ty(span, TyKind::Never));
80-
let params = thin_vec![cx.param(span, size, ty_usize.clone()), cx.param(span, align, ty_usize)];
81-
let decl = cx.fn_decl(params, never);
82-
let header = FnHeader { unsafety: Unsafe::Yes(span), ..FnHeader::default() };
83-
let sig = FnSig { decl, header, span: span };
84-
85-
let body = Some(cx.block_expr(call));
86-
let kind = ItemKind::Fn(Box::new(Fn {
87-
defaultness: ast::Defaultness::Final,
88-
sig,
89-
generics: Generics::default(),
90-
body,
91-
}));
92-
93-
let attrs = thin_vec![cx.attr_word(sym::rustc_std_internal_symbol, span)];
94-
95-
let item = cx.item(span, Ident::from_str_and_span("__rg_oom", span), attrs, kind);
96-
cx.stmt_item(sig_span, item)
97-
}
53+
//fn generate_handler(cx: &ExtCtxt<'_>, handler: Ident, span: Span, sig_span: Span) -> Stmt {
54+
// let usize = cx.path_ident(span, Ident::new(sym::usize, span));
55+
// let ty_usize = cx.ty_path(usize);
56+
// let size = Ident::from_str_and_span("size", span);
57+
// let align = Ident::from_str_and_span("align", span);
58+
//
59+
// let layout_new = cx.std_path(&[sym::alloc, sym::Layout, sym::from_size_align_unchecked]);
60+
// let layout_new = cx.expr_path(cx.path(span, layout_new));
61+
// let layout = cx.expr_call(
62+
// span,
63+
// layout_new,
64+
// thin_vec![cx.expr_ident(span, size), cx.expr_ident(span, align)],
65+
// );
66+
//
67+
// let call = cx.expr_call_ident(sig_span, handler, thin_vec![layout]);
68+
//
69+
// let never = ast::FnRetTy::Ty(cx.ty(span, TyKind::Never));
70+
// let params = thin_vec![cx.param(span, size, ty_usize.clone()), cx.param(span, align, ty_usize)];
71+
// let decl = cx.fn_decl(params, never);
72+
// let header = FnHeader { unsafety: Unsafe::Yes(span), ..FnHeader::default() };
73+
// let sig = FnSig { decl, header, span: span };
74+
//
75+
// let body = Some(cx.block_expr(call));
76+
// let kind = ItemKind::Fn(Box::new(Fn {
77+
// defaultness: ast::Defaultness::Final,
78+
// sig,
79+
// generics: Generics::default(),
80+
// body,
81+
// }));
82+
//
83+
// let attrs = thin_vec![cx.attr_word(sym::rustc_std_internal_symbol, span)];
84+
//
85+
// let item = cx.item(span, Ident::from_str_and_span("__rg_oom", span), attrs, kind);
86+
// cx.stmt_item(sig_span, item)
87+
//}

compiler/rustc_builtin_macros/src/errors.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,13 @@ pub(crate) struct AllocMustStatics {
164164
pub(crate) span: Span,
165165
}
166166

167+
#[derive(Diagnostic)]
168+
#[diag(builtin_macros_autodiff)]
169+
pub(crate) struct AutoDiffInvalidApplication {
170+
#[primary_span]
171+
pub(crate) span: Span,
172+
}
173+
167174
#[derive(Diagnostic)]
168175
#[diag(builtin_macros_concat_bytes_invalid)]
169176
pub(crate) struct ConcatBytesInvalid {

0 commit comments

Comments
 (0)