4
4
//use crate::util::check_builtin_macro_attribute;
5
5
//use crate::util::check_autodiff;
6
6
7
- use std:: string:: String ;
8
7
use crate :: errors;
8
+ use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity , DiffMode } ;
9
9
use rustc_ast:: ptr:: P ;
10
- use rustc_ast:: { BindingAnnotation , ByRef } ;
11
- use rustc_ast:: { self as ast, FnHeader , FnSig , Generics , StmtKind , NestedMetaItem , MetaItemKind } ;
12
- use rustc_ast:: { Fn , ItemKind , Stmt , TyKind , Unsafe , PatKind } ;
10
+ use rustc_ast:: token:: { Token , TokenKind } ;
13
11
use rustc_ast:: tokenstream:: * ;
12
+ use rustc_ast:: { self as ast, FnHeader , FnSig , Generics , MetaItemKind , NestedMetaItem , StmtKind } ;
13
+ use rustc_ast:: { BindingAnnotation , ByRef } ;
14
+ use rustc_ast:: { Fn , ItemKind , PatKind , Stmt , TyKind , Unsafe } ;
14
15
use rustc_expand:: base:: { Annotatable , ExtCtxt } ;
15
16
use rustc_span:: symbol:: { kw, sym, Ident } ;
16
17
use rustc_span:: Span ;
17
- use thin_vec:: { thin_vec, ThinVec } ;
18
18
use rustc_span:: Symbol ;
19
- use rustc_ast :: expand :: autodiff_attrs :: { AutoDiffAttrs , DiffActivity , DiffMode } ;
20
- use rustc_ast :: token :: { Token , TokenKind } ;
19
+ use std :: string :: String ;
20
+ use thin_vec :: { thin_vec , ThinVec } ;
21
21
22
22
fn first_ident ( x : & NestedMetaItem ) -> rustc_span:: symbol:: Ident {
23
23
let segments = & x. meta_item ( ) . unwrap ( ) . path . segments ;
@@ -36,9 +36,7 @@ pub fn expand(
36
36
let meta_item_vec: ThinVec < NestedMetaItem > = match meta_item. kind {
37
37
ast:: MetaItemKind :: List ( ref vec) => vec. clone ( ) ,
38
38
_ => {
39
- ecx. sess
40
- . dcx ( )
41
- . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
39
+ ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
42
40
return vec ! [ item] ;
43
41
}
44
42
} ;
@@ -52,18 +50,19 @@ pub fn expand(
52
50
{
53
51
( item, sig. decl . output . has_ret ( ) , sig, ecx. with_call_site_ctxt ( sig. span ) )
54
52
} else {
55
- ecx. sess
56
- . dcx ( )
57
- . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
53
+ ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
58
54
return vec ! [ item] ;
59
55
} ;
60
56
// create TokenStream from vec elemtents:
61
57
// meta_item doesn't have a .tokens field
62
- let ts: Vec < Token > = meta_item_vec. clone ( ) [ 1 ..] . iter ( ) . map ( |x| {
63
- let val = first_ident ( x) ;
64
- let t = Token :: from_ast_ident ( val) ;
65
- t
66
- } ) . collect ( ) ;
58
+ let ts: Vec < Token > = meta_item_vec. clone ( ) [ 1 ..]
59
+ . iter ( )
60
+ . map ( |x| {
61
+ let val = first_ident ( x) ;
62
+ let t = Token :: from_ast_ident ( val) ;
63
+ t
64
+ } )
65
+ . collect ( ) ;
67
66
let comma: Token = Token :: new ( TokenKind :: Comma , Span :: default ( ) ) ;
68
67
let mut ts: Vec < TokenTree > = vec ! [ ] ;
69
68
for t in meta_item_vec. clone ( ) [ 1 ..] . iter ( ) {
@@ -80,7 +79,18 @@ pub fn expand(
80
79
81
80
let ( d_sig, old_names, new_args, idents) = gen_enzyme_decl ( ecx, & sig, & x, span, sig_span) ;
82
81
let new_decl_span = d_sig. span ;
83
- let d_body = gen_enzyme_body ( ecx, primal, & old_names, & new_args, span, sig_span, new_decl_span, & sig, & d_sig, idents) ;
82
+ let d_body = gen_enzyme_body (
83
+ ecx,
84
+ primal,
85
+ & old_names,
86
+ & new_args,
87
+ span,
88
+ sig_span,
89
+ new_decl_span,
90
+ & sig,
91
+ & d_sig,
92
+ idents,
93
+ ) ;
84
94
let d_ident = meta_item_vec[ 0 ] . meta_item ( ) . unwrap ( ) . path . segments [ 0 ] . ident ;
85
95
86
96
// The first element of it is the name of the function to be generated
@@ -90,7 +100,8 @@ pub fn expand(
90
100
generics : Generics :: default ( ) ,
91
101
body : Some ( d_body) ,
92
102
} ) ) ;
93
- let mut rustc_ad_attr = P ( ast:: NormalAttr :: from_ident ( Ident :: with_dummy_span ( sym:: rustc_autodiff) ) ) ;
103
+ let mut rustc_ad_attr =
104
+ P ( ast:: NormalAttr :: from_ident ( Ident :: with_dummy_span ( sym:: rustc_autodiff) ) ) ;
94
105
let mut attr: ast:: Attribute = ast:: Attribute {
95
106
kind : ast:: AttrKind :: Normal ( rustc_ad_attr. clone ( ) ) ,
96
107
id : ast:: AttrId :: from_u32 ( 0 ) ,
@@ -136,27 +147,33 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
136
147
ty
137
148
}
138
149
139
-
140
150
// The body of our generated functions will consist of three black_Box calls.
141
151
// The first will call the primal function with the original arguments.
142
152
// The second will just take the shadow arguments.
143
153
// The third will (unsafely) call std::mem::zeroed(), to match the return type of the new function
144
154
// (whatever that might be). This way we surpress rustc from optimizing anyt argument away.
145
- fn gen_enzyme_body ( ecx : & ExtCtxt < ' _ > , primal : Ident , old_names : & [ String ] , new_names : & [ String ] , span : Span , sig_span : Span , new_decl_span : Span , sig : & ast:: FnSig , d_sig : & ast:: FnSig , idents : Vec < Ident > ) -> P < ast:: Block > {
155
+ fn gen_enzyme_body (
156
+ ecx : & ExtCtxt < ' _ > ,
157
+ primal : Ident ,
158
+ old_names : & [ String ] ,
159
+ new_names : & [ String ] ,
160
+ span : Span ,
161
+ sig_span : Span ,
162
+ new_decl_span : Span ,
163
+ sig : & ast:: FnSig ,
164
+ d_sig : & ast:: FnSig ,
165
+ idents : Vec < Ident > ,
166
+ ) -> P < ast:: Block > {
146
167
let blackbox_path = ecx. std_path ( & [ Symbol :: intern ( "hint" ) , Symbol :: intern ( "black_box" ) ] ) ;
147
168
let zeroed_path = ecx. std_path ( & [ Symbol :: intern ( "mem" ) , Symbol :: intern ( "zeroed" ) ] ) ;
148
169
let empty_loop_block = ecx. block ( span, ThinVec :: new ( ) ) ;
149
170
let loop_expr = ecx. expr_loop ( span, empty_loop_block) ;
150
171
151
-
152
172
let blackbox_call_expr = ecx. expr_path ( ecx. path ( span, blackbox_path) ) ;
153
173
let zeroed_call_expr = ecx. expr_path ( ecx. path ( span, zeroed_path) ) ;
154
174
155
- let mem_zeroed_call: Stmt = ecx. stmt_expr ( ecx. expr_call (
156
- span,
157
- zeroed_call_expr. clone ( ) ,
158
- thin_vec ! [ ] ,
159
- ) ) ;
175
+ let mem_zeroed_call: Stmt =
176
+ ecx. stmt_expr ( ecx. expr_call ( span, zeroed_call_expr. clone ( ) , thin_vec ! [ ] ) ) ;
160
177
let unsafe_block_with_zeroed_call: P < ast:: Expr > = ecx. expr_block ( P ( ast:: Block {
161
178
stmts : thin_vec ! [ mem_zeroed_call] ,
162
179
id : ast:: DUMMY_NODE_ID ,
@@ -167,19 +184,17 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
167
184
} ) ) ;
168
185
let primal_call = gen_primal_call ( ecx, span, primal, sig, idents) ;
169
186
// create ::core::hint::black_box(array(arr));
170
- let black_box0 = ecx. expr_call (
171
- new_decl_span,
172
- blackbox_call_expr. clone ( ) ,
173
- thin_vec ! [ primal_call. clone( ) ] ,
174
- ) ;
187
+ let black_box0 =
188
+ ecx. expr_call ( new_decl_span, blackbox_call_expr. clone ( ) , thin_vec ! [ primal_call. clone( ) ] ) ;
175
189
176
190
// create ::core::hint::black_box(grad_arr, tang_y));
177
191
let black_box1 = ecx. expr_call (
178
192
sig_span,
179
193
blackbox_call_expr. clone ( ) ,
180
- new_names. iter ( ) . map ( |arg| {
181
- ecx. expr_path ( ecx. path_ident ( span, Ident :: from_str ( arg) ) )
182
- } ) . collect ( ) ,
194
+ new_names
195
+ . iter ( )
196
+ . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, Ident :: from_str ( arg) ) ) )
197
+ . collect ( ) ,
183
198
) ;
184
199
185
200
// create ::core::hint::black_box(unsafe { ::core::mem::zeroed() })
@@ -189,7 +204,6 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
189
204
thin_vec ! [ unsafe_block_with_zeroed_call. clone( ) ] ,
190
205
) ;
191
206
192
-
193
207
let mut body = ecx. block ( span, ThinVec :: new ( ) ) ;
194
208
body. stmts . push ( ecx. stmt_semi ( primal_call) ) ;
195
209
body. stmts . push ( ecx. stmt_semi ( black_box0) ) ;
@@ -199,16 +213,16 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
199
213
body
200
214
}
201
215
202
- fn gen_primal_call ( ecx : & ExtCtxt < ' _ > , span : Span , primal : Ident , sig : & ast:: FnSig , idents : Vec < Ident > ) -> P < ast:: Expr > {
216
+ fn gen_primal_call (
217
+ ecx : & ExtCtxt < ' _ > ,
218
+ span : Span ,
219
+ primal : Ident ,
220
+ sig : & ast:: FnSig ,
221
+ idents : Vec < Ident > ,
222
+ ) -> P < ast:: Expr > {
203
223
let primal_call_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
204
- let args = idents. iter ( ) . map ( |arg| {
205
- ecx. expr_path ( ecx. path_ident ( span, * arg) )
206
- } ) . collect ( ) ;
207
- let primal_call = ecx. expr_call (
208
- span,
209
- primal_call_expr,
210
- args,
211
- ) ;
224
+ let args = idents. iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
225
+ let primal_call = ecx. expr_call ( span, primal_call_expr, args) ;
212
226
primal_call
213
227
}
214
228
@@ -218,8 +232,13 @@ fn gen_primal_call(ecx: &ExtCtxt<'_>, span: Span, primal: Ident, sig: &ast::FnSi
218
232
// zero-initialized by Enzyme). Active arguments are not handled yet.
219
233
// Each argument of the primal function (and the return type if existing) must be annotated with an
220
234
// activity.
221
- fn gen_enzyme_decl ( _ecx : & ExtCtxt < ' _ > , sig : & ast:: FnSig , x : & AutoDiffAttrs , span : Span , _sig_span : Span )
222
- -> ( ast:: FnSig , Vec < String > , Vec < String > , Vec < Ident > ) {
235
+ fn gen_enzyme_decl (
236
+ _ecx : & ExtCtxt < ' _ > ,
237
+ sig : & ast:: FnSig ,
238
+ x : & AutoDiffAttrs ,
239
+ span : Span ,
240
+ _sig_span : Span ,
241
+ ) -> ( ast:: FnSig , Vec < String > , Vec < String > , Vec < Ident > ) {
223
242
assert ! ( sig. decl. inputs. len( ) == x. input_activity. len( ) ) ;
224
243
assert ! ( sig. decl. output. has_ret( ) == x. has_ret_activity( ) ) ;
225
244
let mut d_decl = sig. decl . clone ( ) ;
@@ -253,17 +272,16 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span
253
272
shadow_arg. pat = P ( ast:: Pat {
254
273
// TODO: Check id
255
274
id : ast:: DUMMY_NODE_ID ,
256
- kind : PatKind :: Ident ( BindingAnnotation :: NONE ,
257
- ident,
258
- None ,
259
- ) ,
275
+ kind : PatKind :: Ident ( BindingAnnotation :: NONE , ident, None ) ,
260
276
span : shadow_arg. pat . span ,
261
277
tokens : shadow_arg. pat . tokens . clone ( ) ,
262
278
} ) ;
263
279
//idents.push(ident);
264
280
d_inputs. push ( shadow_arg) ;
265
281
}
266
- _ => { dbg ! ( & activity) ; } ,
282
+ _ => {
283
+ dbg ! ( & activity) ;
284
+ }
267
285
}
268
286
}
269
287
@@ -285,10 +303,7 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span
285
303
ty : ty. clone ( ) ,
286
304
pat : P ( ast:: Pat {
287
305
id : ast:: DUMMY_NODE_ID ,
288
- kind : PatKind :: Ident ( BindingAnnotation :: NONE ,
289
- ident,
290
- None ,
291
- ) ,
306
+ kind : PatKind :: Ident ( BindingAnnotation :: NONE , ident, None ) ,
292
307
span : ty. span ,
293
308
tokens : None ,
294
309
} ) ,
@@ -302,10 +317,6 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span
302
317
}
303
318
}
304
319
d_decl. inputs = d_inputs. into ( ) ;
305
- let d_sig = FnSig {
306
- header : sig. header . clone ( ) ,
307
- decl : d_decl,
308
- span,
309
- } ;
320
+ let d_sig = FnSig { header : sig. header . clone ( ) , decl : d_decl, span } ;
310
321
( d_sig, old_names, new_inputs, idents)
311
322
}
0 commit comments