5
5
//use crate::util::check_autodiff;
6
6
7
7
use crate :: errors;
8
+ use rustc_ast:: FnRetTy ;
8
9
use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity , DiffMode } ;
9
10
use rustc_ast:: ptr:: P ;
10
11
use rustc_ast:: token:: { Token , TokenKind } ;
@@ -41,7 +42,6 @@ pub fn expand(
41
42
}
42
43
} ;
43
44
let mut orig_item: P < ast:: Item > = item. clone ( ) . expect_item ( ) ;
44
- //dbg!(&orig_item.tokens);
45
45
let primal = orig_item. ident . clone ( ) ;
46
46
47
47
// Allow using `#[autodiff(...)]` only on a Fn
@@ -77,7 +77,7 @@ pub fn expand(
77
77
dbg ! ( & x) ;
78
78
let span = ecx. with_def_site_ctxt ( expand_span) ;
79
79
80
- let ( d_sig, old_names, new_args, idents) = gen_enzyme_decl ( ecx , & sig, & x, span, sig_span ) ;
80
+ let ( d_sig, old_names, new_args, idents) = gen_enzyme_decl ( & sig, & x, span) ;
81
81
let new_decl_span = d_sig. span ;
82
82
let d_body = gen_enzyme_body (
83
83
ecx,
@@ -147,11 +147,11 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
147
147
ty
148
148
}
149
149
150
- // The body of our generated functions will consist of three black_Box calls.
150
+ // The body of our generated functions will consist of two black_box calls.
151
151
// The first will call the primal function with the original arguments.
152
- // The second will just take the shadow arguments.
153
- // The third will (unsafely) call std::mem::zeroed(), to match the return type of the new function
154
- // (whatever that might be). This way we surpress rustc from optimizing anyt argument away.
152
+ // The second will just take a tuple containing the new arguments.
153
+ // This way we suppress rustc from optimizing any argument away.
154
+ // The last line will 'loop {}', to match the return type of the new function
155
155
fn gen_enzyme_body (
156
156
ecx : & ExtCtxt < ' _ > ,
157
157
primal : Ident ,
@@ -184,31 +184,25 @@ fn gen_enzyme_body(
184
184
} ) ) ;
185
185
let primal_call = gen_primal_call ( ecx, span, primal, sig, idents) ;
186
186
// create ::core::hint::black_box(array(arr));
187
- let black_box0 =
187
+ let black_box_primal_call =
188
188
ecx. expr_call ( new_decl_span, blackbox_call_expr. clone ( ) , thin_vec ! [ primal_call. clone( ) ] ) ;
189
189
190
- // create ::core::hint::black_box(grad_arr, tang_y));
191
- let black_box1 = ecx. expr_call (
192
- sig_span,
193
- blackbox_call_expr. clone ( ) ,
194
- new_names
195
- . iter ( )
196
- . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, Ident :: from_str ( arg) ) ) )
197
- . collect ( ) ,
198
- ) ;
190
+ // create ::core::hint::black_box((grad_arr, tang_y));
191
+ let tup_args = new_names
192
+ . iter ( )
193
+ . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, Ident :: from_str ( arg) ) ) )
194
+ . collect ( ) ;
199
195
200
- // create ::core::hint::black_box(unsafe { ::core::mem::zeroed() })
201
- let black_box2 = ecx. expr_call (
196
+ let black_box_remaining_args = ecx. expr_call (
202
197
sig_span,
203
198
blackbox_call_expr. clone ( ) ,
204
- thin_vec ! [ unsafe_block_with_zeroed_call . clone ( ) ] ,
199
+ thin_vec ! [ ecx . expr_tuple ( sig_span , tup_args ) ] ,
205
200
) ;
206
201
207
202
let mut body = ecx. block ( span, ThinVec :: new ( ) ) ;
208
203
body. stmts . push ( ecx. stmt_semi ( primal_call) ) ;
209
- body. stmts . push ( ecx. stmt_semi ( black_box0) ) ;
210
- body. stmts . push ( ecx. stmt_semi ( black_box1) ) ;
211
- //body.stmts.push(ecx.stmt_semi(black_box2));
204
+ body. stmts . push ( ecx. stmt_semi ( black_box_primal_call) ) ;
205
+ body. stmts . push ( ecx. stmt_semi ( black_box_remaining_args) ) ;
212
206
body. stmts . push ( ecx. stmt_expr ( loop_expr) ) ;
213
207
body
214
208
}
@@ -233,11 +227,9 @@ fn gen_primal_call(
233
227
// Each argument of the primal function (and the return type if existing) must be annotated with an
234
228
// activity.
235
229
fn gen_enzyme_decl (
236
- _ecx : & ExtCtxt < ' _ > ,
237
230
sig : & ast:: FnSig ,
238
231
x : & AutoDiffAttrs ,
239
232
span : Span ,
240
- _sig_span : Span ,
241
233
) -> ( ast:: FnSig , Vec < String > , Vec < String > , Vec < Ident > ) {
242
234
assert ! ( sig. decl. inputs. len( ) == x. input_activity. len( ) ) ;
243
235
assert ! ( sig. decl. output. has_ret( ) == x. has_ret_activity( ) ) ;
@@ -246,15 +238,19 @@ fn gen_enzyme_decl(
246
238
let mut new_inputs = Vec :: new ( ) ;
247
239
let mut old_names = Vec :: new ( ) ;
248
240
let mut idents = Vec :: new ( ) ;
241
+ let mut act_ret = ThinVec :: new ( ) ;
249
242
for ( arg, activity) in sig. decl . inputs . iter ( ) . zip ( x. input_activity . iter ( ) ) {
250
243
d_inputs. push ( arg. clone ( ) ) ;
251
244
match activity {
245
+ DiffActivity :: Active => {
246
+ assert ! ( x. mode == DiffMode :: Reverse ) ;
247
+ act_ret. push ( arg. ty . clone ( ) ) ;
248
+ }
252
249
DiffActivity :: Duplicated | DiffActivity :: Dual => {
253
250
let mut shadow_arg = arg. clone ( ) ;
254
251
shadow_arg. ty = P ( assure_mut_ref ( & arg. ty ) ) ;
255
252
// adjust name depending on mode
256
- let old_name = if let PatKind :: Ident ( _, ident, _) = shadow_arg. pat . kind {
257
- idents. push ( ident. clone ( ) ) ;
253
+ let old_name = if let PatKind :: Ident ( _, ident, _) = arg. pat . kind {
258
254
ident. name
259
255
} else {
260
256
dbg ! ( & shadow_arg. pat) ;
@@ -276,47 +272,72 @@ fn gen_enzyme_decl(
276
272
span : shadow_arg. pat . span ,
277
273
tokens : shadow_arg. pat . tokens . clone ( ) ,
278
274
} ) ;
279
- //idents.push(ident);
280
275
d_inputs. push ( shadow_arg) ;
281
276
}
282
277
_ => {
283
278
dbg ! ( & activity) ;
284
279
}
285
280
}
281
+ if let PatKind :: Ident ( _, ident, _) = arg. pat . kind {
282
+ idents. push ( ident. clone ( ) ) ;
283
+ } else {
284
+ panic ! ( "not an ident?" ) ;
285
+ }
286
286
}
287
287
288
288
// If we return a scalar in the primal and the scalar is active,
289
289
// then add it as last arg to the inputs.
290
- if x. mode == DiffMode :: Reverse {
291
- match x. ret_activity {
292
- DiffActivity :: Active => {
293
- let ty = match d_decl. output {
294
- rustc_ast:: FnRetTy :: Ty ( ref ty) => ty. clone ( ) ,
295
- rustc_ast:: FnRetTy :: Default ( span) => {
296
- panic ! ( "Did not expect Default ret ty: {:?}" , span) ;
297
- }
298
- } ;
299
- let name = "dret" . to_string ( ) ;
300
- let ident = Ident :: from_str_and_span ( & name, ty. span ) ;
301
- let shadow_arg = ast:: Param {
302
- attrs : ThinVec :: new ( ) ,
303
- ty : ty. clone ( ) ,
304
- pat : P ( ast:: Pat {
305
- id : ast:: DUMMY_NODE_ID ,
306
- kind : PatKind :: Ident ( BindingAnnotation :: NONE , ident, None ) ,
307
- span : ty. span ,
308
- tokens : None ,
309
- } ) ,
290
+ if let DiffMode :: Reverse = x. mode {
291
+ if let DiffActivity :: Active = x. ret_activity {
292
+ let ty = match d_decl. output {
293
+ FnRetTy :: Ty ( ref ty) => ty. clone ( ) ,
294
+ FnRetTy :: Default ( span) => {
295
+ panic ! ( "Did not expect Default ret ty: {:?}" , span) ;
296
+ }
297
+ } ;
298
+ let name = "dret" . to_string ( ) ;
299
+ let ident = Ident :: from_str_and_span ( & name, ty. span ) ;
300
+ let shadow_arg = ast:: Param {
301
+ attrs : ThinVec :: new ( ) ,
302
+ ty : ty. clone ( ) ,
303
+ pat : P ( ast:: Pat {
310
304
id : ast:: DUMMY_NODE_ID ,
305
+ kind : PatKind :: Ident ( BindingAnnotation :: NONE , ident, None ) ,
311
306
span : ty. span ,
312
- is_placeholder : false ,
313
- } ;
314
- d_inputs. push ( shadow_arg) ;
315
- }
316
- _ => { }
307
+ tokens : None ,
308
+ } ) ,
309
+ id : ast:: DUMMY_NODE_ID ,
310
+ span : ty. span ,
311
+ is_placeholder : false ,
312
+ } ;
313
+ d_inputs. push ( shadow_arg) ;
314
+ new_inputs. push ( name) ;
317
315
}
318
316
}
319
317
d_decl. inputs = d_inputs. into ( ) ;
318
+
319
+ // If we have an active input scalar, add its gradient to the
320
+ // return type. This might require changing the return type to a
321
+ // tuple.
322
+ if act_ret. len ( ) > 0 {
323
+ let mut ret_ty = match d_decl. output {
324
+ FnRetTy :: Ty ( ref ty) => {
325
+ act_ret. insert ( 0 , ty. clone ( ) ) ;
326
+ let kind = TyKind :: Tup ( act_ret) ;
327
+ P ( rustc_ast:: Ty { kind, id : ty. id , span : ty. span , tokens : None } )
328
+ }
329
+ FnRetTy :: Default ( span) => {
330
+ if act_ret. len ( ) == 1 {
331
+ act_ret[ 0 ] . clone ( )
332
+ } else {
333
+ let kind = TyKind :: Tup ( act_ret. iter ( ) . map ( |arg| arg. clone ( ) ) . collect ( ) ) ;
334
+ P ( rustc_ast:: Ty { kind, id : ast:: DUMMY_NODE_ID , span, tokens : None } )
335
+ }
336
+ }
337
+ } ;
338
+ d_decl. output = FnRetTy :: Ty ( ret_ty) ;
339
+ }
340
+
320
341
let d_sig = FnSig { header : sig. header . clone ( ) , decl : d_decl, span } ;
321
342
( d_sig, old_names, new_inputs, idents)
322
343
}
0 commit comments