@@ -90,23 +90,23 @@ pub fn expand(
90
90
generics : Generics :: default ( ) ,
91
91
body : Some ( d_body) ,
92
92
} ) ) ;
93
- let mut tmp = P ( ast:: NormalAttr :: from_ident ( Ident :: with_dummy_span ( sym:: rustc_autodiff) ) ) ;
93
+ let mut rustc_ad_attr = P ( ast:: NormalAttr :: from_ident ( Ident :: with_dummy_span ( sym:: rustc_autodiff) ) ) ;
94
94
let mut attr: ast:: Attribute = ast:: Attribute {
95
- kind : ast:: AttrKind :: Normal ( tmp . clone ( ) ) ,
95
+ kind : ast:: AttrKind :: Normal ( rustc_ad_attr . clone ( ) ) ,
96
96
id : ast:: AttrId :: from_u32 ( 0 ) ,
97
97
style : ast:: AttrStyle :: Outer ,
98
98
span : span,
99
99
} ;
100
- orig_item. attrs . push ( attr) ;
100
+ orig_item. attrs . push ( attr. clone ( ) ) ;
101
101
102
102
// Now update for d_fn
103
- tmp . item . args = rustc_ast:: AttrArgs :: Delimited ( rustc_ast:: DelimArgs {
103
+ rustc_ad_attr . item . args = rustc_ast:: AttrArgs :: Delimited ( rustc_ast:: DelimArgs {
104
104
dspan : DelimSpan :: dummy ( ) ,
105
105
delim : rustc_ast:: token:: Delimiter :: Parenthesis ,
106
106
tokens : ts,
107
107
} ) ;
108
108
let mut attr2: ast:: Attribute = ast:: Attribute {
109
- kind : ast:: AttrKind :: Normal ( tmp ) ,
109
+ kind : ast:: AttrKind :: Normal ( rustc_ad_attr ) ,
110
110
id : ast:: AttrId :: from_u32 ( 0 ) ,
111
111
style : ast:: AttrStyle :: Outer ,
112
112
span : span,
@@ -165,12 +165,13 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
165
165
tokens : None ,
166
166
could_be_bare_literal : false ,
167
167
} ) ) ;
168
+ let primal_call = gen_primal_call ( ecx, span, primal, sig, idents) ;
168
169
// create ::core::hint::black_box(array(arr));
169
- // let black_box0 = ecx.expr_call(
170
- // new_decl_span,
171
- // blackbox_call_expr.clone(),
172
- // thin_vec![primal_call.clone()],
173
- // );
170
+ let black_box0 = ecx. expr_call (
171
+ new_decl_span,
172
+ blackbox_call_expr. clone ( ) ,
173
+ thin_vec ! [ primal_call. clone( ) ] ,
174
+ ) ;
174
175
175
176
// create ::core::hint::black_box(grad_arr, tang_y));
176
177
let black_box1 = ecx. expr_call (
@@ -188,26 +189,17 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
188
189
thin_vec ! [ unsafe_block_with_zeroed_call. clone( ) ] ,
189
190
) ;
190
191
191
- let primal_call = gen_primal_call ( ecx, span, primal, sig, idents) ;
192
192
193
193
let mut body = ecx. block ( span, ThinVec :: new ( ) ) ;
194
194
body. stmts . push ( ecx. stmt_semi ( primal_call) ) ;
195
- // body.stmts.push(ecx.stmt_expr (black_box0));
196
- // body.stmts.push(ecx.stmt_expr (black_box1));
197
- //body.stmts.push(ecx.stmt_expr (black_box2));
195
+ body. stmts . push ( ecx. stmt_semi ( black_box0) ) ;
196
+ body. stmts . push ( ecx. stmt_semi ( black_box1) ) ;
197
+ //body.stmts.push(ecx.stmt_semi (black_box2));
198
198
body. stmts . push ( ecx. stmt_expr ( loop_expr) ) ;
199
199
body
200
200
}
201
201
202
202
fn gen_primal_call ( ecx : & ExtCtxt < ' _ > , span : Span , primal : Ident , sig : & ast:: FnSig , idents : Vec < Ident > ) -> P < ast:: Expr > {
203
- //pub struct Param {
204
- // pub attrs: AttrVec,
205
- // pub ty: P<Ty>,
206
- // pub pat: P<Pat>,
207
- // pub id: NodeId,
208
- // pub span: Span,
209
- // pub is_placeholder: bool,
210
- //}
211
203
let primal_call_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
212
204
let args = idents. iter ( ) . map ( |arg| {
213
205
ecx. expr_path ( ecx. path_ident ( span, * arg) )
@@ -228,16 +220,14 @@ fn gen_primal_call(ecx: &ExtCtxt<'_>, span: Span, primal: Ident, sig: &ast::FnSi
228
220
// activity.
229
221
fn gen_enzyme_decl ( _ecx : & ExtCtxt < ' _ > , sig : & ast:: FnSig , x : & AutoDiffAttrs , span : Span , _sig_span : Span )
230
222
-> ( ast:: FnSig , Vec < String > , Vec < String > , Vec < Ident > ) {
231
- let decl: P < ast:: FnDecl > = sig. decl . clone ( ) ;
232
- assert ! ( decl. inputs. len( ) == x. input_activity. len( ) ) ;
233
- assert ! ( decl. output. has_ret( ) == x. has_ret_activity( ) ) ;
234
- let mut d_decl = decl. clone ( ) ;
223
+ assert ! ( sig. decl. inputs. len( ) == x. input_activity. len( ) ) ;
224
+ assert ! ( sig. decl. output. has_ret( ) == x. has_ret_activity( ) ) ;
225
+ let mut d_decl = sig. decl . clone ( ) ;
235
226
let mut d_inputs = Vec :: new ( ) ;
236
227
let mut new_inputs = Vec :: new ( ) ;
237
228
let mut old_names = Vec :: new ( ) ;
238
229
let mut idents = Vec :: new ( ) ;
239
- for ( arg, activity) in decl. inputs . iter ( ) . zip ( x. input_activity . iter ( ) ) {
240
- //dbg!(&arg);
230
+ for ( arg, activity) in sig. decl . inputs . iter ( ) . zip ( x. input_activity . iter ( ) ) {
241
231
d_inputs. push ( arg. clone ( ) ) ;
242
232
match activity {
243
233
DiffActivity :: Duplicated | DiffActivity :: Dual => {
@@ -273,7 +263,42 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span
273
263
//idents.push(ident);
274
264
d_inputs. push ( shadow_arg) ;
275
265
}
276
- _ => { } ,
266
+ _ => { dbg ! ( & activity) ; } ,
267
+ }
268
+ }
269
+
270
+ // If we return a scalar in the primal and the scalar is active,
271
+ // then add it as last arg to the inputs.
272
+ if x. mode == DiffMode :: Reverse {
273
+ match x. ret_activity {
274
+ DiffActivity :: Active => {
275
+ let ty = match d_decl. output {
276
+ rustc_ast:: FnRetTy :: Ty ( ref ty) => ty. clone ( ) ,
277
+ rustc_ast:: FnRetTy :: Default ( span) => {
278
+ panic ! ( "Did not expect Default ret ty: {:?}" , span) ;
279
+ }
280
+ } ;
281
+ let name = "dret" . to_string ( ) ;
282
+ let ident = Ident :: from_str_and_span ( & name, ty. span ) ;
283
+ let shadow_arg = ast:: Param {
284
+ attrs : ThinVec :: new ( ) ,
285
+ ty : ty. clone ( ) ,
286
+ pat : P ( ast:: Pat {
287
+ id : ast:: DUMMY_NODE_ID ,
288
+ kind : PatKind :: Ident ( BindingAnnotation :: NONE ,
289
+ ident,
290
+ None ,
291
+ ) ,
292
+ span : ty. span ,
293
+ tokens : None ,
294
+ } ) ,
295
+ id : ast:: DUMMY_NODE_ID ,
296
+ span : ty. span ,
297
+ is_placeholder : false ,
298
+ } ;
299
+ d_inputs. push ( shadow_arg) ;
300
+ }
301
+ _ => { }
277
302
}
278
303
}
279
304
d_decl. inputs = d_inputs. into ( ) ;
0 commit comments