1
- #![ allow( unused) ]
2
-
3
- use crate :: errors;
1
+ #![ allow( unused_imports) ]
4
2
//use crate::util::check_builtin_macro_attribute;
5
3
//use crate::util::check_autodiff;
6
4
5
+ use std:: string:: String ;
6
+ use crate :: errors;
7
7
use rustc_ast:: ptr:: P ;
8
- use rustc_ast:: { self as ast, FnHeader , FnSig , Generics , StmtKind } ;
9
- use rustc_ast:: { Fn , ItemKind , Stmt , TyKind , Unsafe } ;
8
+ use rustc_ast:: { BindingAnnotation , ByRef } ;
9
+ use rustc_ast:: { self as ast, FnHeader , FnSig , Generics , StmtKind , NestedMetaItem , MetaItemKind } ;
10
+ use rustc_ast:: { Fn , ItemKind , Stmt , TyKind , Unsafe , PatKind } ;
10
11
use rustc_expand:: base:: { Annotatable , ExtCtxt } ;
11
12
use rustc_span:: symbol:: { kw, sym, Ident } ;
12
13
use rustc_span:: Span ;
13
14
use thin_vec:: { thin_vec, ThinVec } ;
14
15
use rustc_span:: Symbol ;
16
+ use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity , DiffMode } ;
15
17
16
18
pub fn expand (
17
19
ecx : & mut ExtCtxt < ' _ > ,
@@ -20,69 +22,183 @@ pub fn expand(
20
22
item : Annotatable ,
21
23
) -> Vec < Annotatable > {
22
24
//check_builtin_macro_attribute(ecx, meta_item, sym::alloc_error_handler);
23
- //check_builtin_macro_attribute(ecx, meta_item, sym::autodiff);
24
25
25
- dbg ! ( & meta_item) ;
26
+ let meta_item_vec: ThinVec < NestedMetaItem > = match meta_item. kind {
27
+ ast:: MetaItemKind :: List ( ref vec) => vec. clone ( ) ,
28
+ _ => {
29
+ ecx. sess
30
+ . dcx ( )
31
+ . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
32
+ return vec ! [ item] ;
33
+ }
34
+ } ;
26
35
let input = item. clone ( ) ;
27
36
let orig_item: P < ast:: Item > = item. clone ( ) . expect_item ( ) ;
28
37
let mut d_item: P < ast:: Item > = item. clone ( ) . expect_item ( ) ;
38
+ let primal = orig_item. ident . clone ( ) ;
29
39
30
- // Allow using `#[autodiff(...)]` on a Fn
31
- let ( fn_item, _ty_span ) = if let Annotatable :: Item ( item) = & item
40
+ // Allow using `#[autodiff(...)]` only on a Fn
41
+ let ( fn_item, has_ret , sig , sig_span ) = if let Annotatable :: Item ( item) = & item
32
42
&& let ItemKind :: Fn ( box ast:: Fn { sig, .. } ) = & item. kind
33
43
{
34
- dbg ! ( & item) ;
35
- ( item, ecx. with_def_site_ctxt ( sig. span ) )
44
+ ( item, sig. decl . output . has_ret ( ) , sig, ecx. with_def_site_ctxt ( sig. span ) )
36
45
} else {
37
46
ecx. sess
38
47
. dcx ( )
39
48
. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
40
49
return vec ! [ input] ;
41
50
} ;
42
- let _x: & ItemKind = & fn_item. kind ;
43
- d_item. ident . name =
44
- Symbol :: intern ( format ! ( "d_{}" , fn_item. ident. name) . as_str ( ) ) ;
51
+ let x: AutoDiffAttrs = AutoDiffAttrs :: from_ast ( & meta_item_vec, has_ret) ;
52
+ dbg ! ( & x) ;
53
+ let span = ecx. with_def_site_ctxt ( fn_item. span ) ;
54
+
55
+ let ( d_decl, old_names, new_args) = gen_enzyme_decl ( ecx, & sig. decl , & x, span, sig_span) ;
56
+ let d_body = gen_enzyme_body ( ecx, primal, & old_names, & new_args, span, sig_span) ;
57
+ let meta_item_name = meta_item_vec[ 0 ] . meta_item ( ) . unwrap ( ) ;
58
+ d_item. ident = meta_item_name. path . segments [ 0 ] . ident ;
59
+ // update d_item
60
+ if let ItemKind :: Fn ( box ast:: Fn { sig, body, .. } ) = & mut d_item. kind {
61
+ * sig. decl = d_decl;
62
+ * body = Some ( d_body) ;
63
+ } else {
64
+ ecx. sess
65
+ . dcx ( )
66
+ . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
67
+ return vec ! [ input] ;
68
+ }
69
+
45
70
let orig_annotatable = Annotatable :: Item ( orig_item. clone ( ) ) ;
46
71
let d_annotatable = Annotatable :: Item ( d_item. clone ( ) ) ;
47
72
return vec ! [ orig_annotatable, d_annotatable] ;
48
73
}
49
74
50
- // #[rustc_std_internal_symbol]
51
- // unsafe fn __rg_oom(size: usize, align: usize) -> ! {
52
- // handler(core::alloc::Layout::from_size_align_unchecked(size, align))
53
- // }
54
- //fn generate_handler(cx: &ExtCtxt<'_>, handler: Ident, span: Span, sig_span: Span) -> Stmt {
55
- // let usize = cx.path_ident(span, Ident::new(sym::usize, span));
56
- // let ty_usize = cx.ty_path(usize);
57
- // let size = Ident::from_str_and_span("size", span);
58
- // let align = Ident::from_str_and_span("align", span);
59
- //
60
- // let layout_new = cx.std_path(&[sym::alloc, sym::Layout, sym::from_size_align_unchecked]);
61
- // let layout_new = cx.expr_path(cx.path(span, layout_new));
62
- // let layout = cx.expr_call(
63
- // span,
64
- // layout_new,
65
- // thin_vec![cx.expr_ident(span, size), cx.expr_ident(span, align)],
66
- // );
67
- //
68
- // let call = cx.expr_call_ident(sig_span, handler, thin_vec![layout]);
69
- //
70
- // let never = ast::FnRetTy::Ty(cx.ty(span, TyKind::Never));
71
- // let params = thin_vec![cx.param(span, size, ty_usize.clone()), cx.param(span, align, ty_usize)];
72
- // let decl = cx.fn_decl(params, never);
73
- // let header = FnHeader { unsafety: Unsafe::Yes(span), ..FnHeader::default() };
74
- // let sig = FnSig { decl, header, span: span };
75
- //
76
- // let body = Some(cx.block_expr(call));
77
- // let kind = ItemKind::Fn(Box::new(Fn {
78
- // defaultness: ast::Defaultness::Final,
79
- // sig,
80
- // generics: Generics::default(),
81
- // body,
82
- // }));
83
- //
84
- // let attrs = thin_vec![cx.attr_word(sym::rustc_std_internal_symbol, span)];
85
- //
86
- // let item = cx.item(span, Ident::from_str_and_span("__rg_oom", span), attrs, kind);
87
- // cx.stmt_item(sig_span, item)
88
- //}
75
+ // shadow arguments must be mutable references or ptrs, because Enzyme will write into them.
76
+ fn assure_mut_ref ( ty : & ast:: Ty ) -> ast:: Ty {
77
+ let mut ty = ty. clone ( ) ;
78
+ match ty. kind {
79
+ TyKind :: Ptr ( ref mut mut_ty) => {
80
+ mut_ty. mutbl = ast:: Mutability :: Mut ;
81
+ }
82
+ TyKind :: Ref ( _, ref mut mut_ty) => {
83
+ mut_ty. mutbl = ast:: Mutability :: Mut ;
84
+ }
85
+ _ => {
86
+ panic ! ( "unsupported type: {:?}" , ty) ;
87
+ }
88
+ }
89
+ ty
90
+ }
91
+
92
+
93
+ // The body of our generated functions will consist of three black_Box calls.
94
+ // The first will call the primal function with the original arguments.
95
+ // The second will just take the shadow arguments.
96
+ // The third will (unsafely) call std::mem::zeroed(), to match the return type of the new function
97
+ // (whatever that might be). This way we surpress rustc from optimizing anyt argument away.
98
+ fn gen_enzyme_body ( ecx : & ExtCtxt < ' _ > , primal : Ident , old_names : & [ String ] , new_names : & [ String ] , span : Span , sig_span : Span ) -> P < ast:: Block > {
99
+ let blackbox_path = ecx. std_path ( & [ Symbol :: intern ( "hint" ) , Symbol :: intern ( "black_box" ) ] ) ;
100
+ let zeroed_path = ecx. std_path ( & [ Symbol :: intern ( "mem" ) , Symbol :: intern ( "zeroed" ) ] ) ;
101
+
102
+ let primal_call_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
103
+ let blackbox_call_expr = ecx. expr_path ( ecx. path ( span, blackbox_path) ) ;
104
+ let zeroed_call_expr = ecx. expr_path ( ecx. path ( span, zeroed_path) ) ;
105
+
106
+ let mem_zeroed_call: Stmt = ecx. stmt_expr ( ecx. expr_call (
107
+ span,
108
+ zeroed_call_expr. clone ( ) ,
109
+ thin_vec ! [ ] ,
110
+ ) ) ;
111
+ let unsafe_block_with_zeroed_call: P < ast:: Expr > = ecx. expr_block ( P ( ast:: Block {
112
+ stmts : thin_vec ! [ mem_zeroed_call] ,
113
+ id : ast:: DUMMY_NODE_ID ,
114
+ rules : ast:: BlockCheckMode :: Unsafe ( ast:: UserProvided ) ,
115
+ span : sig_span,
116
+ tokens : None ,
117
+ could_be_bare_literal : false ,
118
+ } ) ) ;
119
+ // create ::core::hint::black_box(array(arr));
120
+ let _primal_call = ecx. expr_call (
121
+ span,
122
+ primal_call_expr. clone ( ) ,
123
+ old_names. iter ( ) . map ( |name| {
124
+ ecx. expr_path ( ecx. path_ident ( span, Ident :: from_str ( name) ) )
125
+ } ) . collect ( ) ,
126
+ ) ;
127
+
128
+ // create ::core::hint::black_box(grad_arr, tang_y));
129
+ let black_box1 = ecx. expr_call (
130
+ sig_span,
131
+ blackbox_call_expr. clone ( ) ,
132
+ new_names. iter ( ) . map ( |arg| {
133
+ ecx. expr_path ( ecx. path_ident ( span, Ident :: from_str ( arg) ) )
134
+ } ) . collect ( ) ,
135
+ ) ;
136
+
137
+ // create ::core::hint::black_box(unsafe { ::core::mem::zeroed() })
138
+ let black_box2 = ecx. expr_call (
139
+ sig_span,
140
+ blackbox_call_expr. clone ( ) ,
141
+ thin_vec ! [ unsafe_block_with_zeroed_call. clone( ) ] ,
142
+ ) ;
143
+
144
+ let mut body = ecx. block ( span, ThinVec :: new ( ) ) ;
145
+ body. stmts . push ( ecx. stmt_expr ( black_box1) ) ;
146
+ body. stmts . push ( ecx. stmt_expr ( black_box2) ) ;
147
+ body
148
+ }
149
+
150
+ // Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
151
+ // be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
152
+ // Active arguments must be scalars. Their shadow argument is added to the return type (and will be
153
+ // zero-initialized by Enzyme). Active arguments are not handled yet.
154
+ // Each argument of the primal function (and the return type if existing) must be annotated with an
155
+ // activity.
156
+ fn gen_enzyme_decl ( _ecx : & ExtCtxt < ' _ > , decl : & ast:: FnDecl , x : & AutoDiffAttrs , _span : Span , _sig_span : Span )
157
+ -> ( ast:: FnDecl , Vec < String > , Vec < String > ) {
158
+ assert ! ( decl. inputs. len( ) == x. input_activity. len( ) ) ;
159
+ assert ! ( decl. output. has_ret( ) == x. has_ret_activity( ) ) ;
160
+ let mut d_decl = decl. clone ( ) ;
161
+ let mut d_inputs = Vec :: new ( ) ;
162
+ let mut new_inputs = Vec :: new ( ) ;
163
+ let mut old_names = Vec :: new ( ) ;
164
+ for ( arg, activity) in decl. inputs . iter ( ) . zip ( x. input_activity . iter ( ) ) {
165
+ dbg ! ( & arg) ;
166
+ d_inputs. push ( arg. clone ( ) ) ;
167
+ match activity {
168
+ DiffActivity :: Duplicated => {
169
+ let mut shadow_arg = arg. clone ( ) ;
170
+ shadow_arg. ty = P ( assure_mut_ref ( & arg. ty ) ) ;
171
+ // adjust name depending on mode
172
+ let old_name = if let PatKind :: Ident ( _, ident, _) = shadow_arg. pat . kind {
173
+ ident. name
174
+ } else {
175
+ dbg ! ( & shadow_arg. pat) ;
176
+ panic ! ( "not an ident?" ) ;
177
+ } ;
178
+ old_names. push ( old_name. to_string ( ) ) ;
179
+ let name: String = match x. mode {
180
+ DiffMode :: Reverse => format ! ( "d{}" , old_name) ,
181
+ DiffMode :: Forward => format ! ( "b{}" , old_name) ,
182
+ _ => panic ! ( "unsupported mode: {}" , old_name) ,
183
+ } ;
184
+ dbg ! ( & name) ;
185
+ new_inputs. push ( name. clone ( ) ) ;
186
+ shadow_arg. pat = P ( ast:: Pat {
187
+ // TODO: Check id
188
+ id : ast:: DUMMY_NODE_ID ,
189
+ kind : PatKind :: Ident ( BindingAnnotation :: NONE ,
190
+ Ident :: from_str_and_span ( & name, shadow_arg. pat . span ) ,
191
+ None ,
192
+ ) ,
193
+ span : shadow_arg. pat . span ,
194
+ tokens : shadow_arg. pat . tokens . clone ( ) ,
195
+ } ) ;
196
+
197
+ d_inputs. push ( shadow_arg) ;
198
+ }
199
+ _ => { } ,
200
+ }
201
+ }
202
+ d_decl. inputs = d_inputs. into ( ) ;
203
+ ( d_decl, old_names, new_inputs)
204
+ }
0 commit comments