@@ -72,20 +72,15 @@ pub fn expand(
72
72
ts. push ( TokenTree :: Token ( t, Spacing :: Joint ) ) ;
73
73
ts. push ( TokenTree :: Token ( comma. clone ( ) , Spacing :: Alone ) ) ;
74
74
}
75
- dbg ! ( & ts) ;
76
75
let ts: TokenStream = TokenStream :: from_iter ( ts) ;
77
- dbg ! ( & ts) ;
78
76
79
77
let x: AutoDiffAttrs = AutoDiffAttrs :: from_ast ( & meta_item_vec, has_ret) ;
80
78
dbg ! ( & x) ;
81
- //let span = ecx.with_def_site_ctxt(sig_span);
82
79
let span = ecx. with_def_site_ctxt ( expand_span) ;
83
- //let span = ecx.with_def_site_ctxt(fn_item.span);
84
80
85
- let ( d_sig, old_names, new_args) = gen_enzyme_decl ( ecx, & sig, & x, span, sig_span) ;
81
+ let ( d_sig, old_names, new_args, idents ) = gen_enzyme_decl ( ecx, & sig, & x, span, sig_span) ;
86
82
let new_decl_span = d_sig. span ;
87
- //let d_body = gen_enzyme_body(ecx, primal, &old_names, &new_args, span, span);
88
- let d_body = gen_enzyme_body ( ecx, primal, & old_names, & new_args, span, sig_span, new_decl_span) ;
83
+ let d_body = gen_enzyme_body ( ecx, primal, & old_names, & new_args, span, sig_span, new_decl_span, & sig, & d_sig, idents) ;
89
84
let d_ident = meta_item_vec[ 0 ] . meta_item ( ) . unwrap ( ) . path . segments [ 0 ] . ident ;
90
85
91
86
// The first element of it is the name of the function to be generated
@@ -147,14 +142,13 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
147
142
// The second will just take the shadow arguments.
148
143
// The third will (unsafely) call std::mem::zeroed(), to match the return type of the new function
149
144
// (whatever that might be). This way we surpress rustc from optimizing anyt argument away.
150
- fn gen_enzyme_body ( ecx : & ExtCtxt < ' _ > , primal : Ident , old_names : & [ String ] , new_names : & [ String ] , span : Span , sig_span : Span , new_decl_span : Span ) -> P < ast:: Block > {
145
+ fn gen_enzyme_body ( ecx : & ExtCtxt < ' _ > , primal : Ident , old_names : & [ String ] , new_names : & [ String ] , span : Span , sig_span : Span , new_decl_span : Span , sig : & ast :: FnSig , d_sig : & ast :: FnSig , idents : Vec < Ident > ) -> P < ast:: Block > {
151
146
let blackbox_path = ecx. std_path ( & [ Symbol :: intern ( "hint" ) , Symbol :: intern ( "black_box" ) ] ) ;
152
147
let zeroed_path = ecx. std_path ( & [ Symbol :: intern ( "mem" ) , Symbol :: intern ( "zeroed" ) ] ) ;
153
148
let empty_loop_block = ecx. block ( span, ThinVec :: new ( ) ) ;
154
149
let loop_expr = ecx. expr_loop ( span, empty_loop_block) ;
155
150
156
151
157
- let primal_call_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
158
152
let blackbox_call_expr = ecx. expr_path ( ecx. path ( span, blackbox_path) ) ;
159
153
let zeroed_call_expr = ecx. expr_path ( ecx. path ( span, zeroed_path) ) ;
160
154
@@ -172,18 +166,11 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
172
166
could_be_bare_literal : false ,
173
167
} ) ) ;
174
168
// create ::core::hint::black_box(array(arr));
175
- let primal_call = ecx. expr_call (
176
- new_decl_span,
177
- primal_call_expr,
178
- old_names. iter ( ) . map ( |name| {
179
- ecx. expr_path ( ecx. path_ident ( new_decl_span, Ident :: from_str ( name) ) )
180
- } ) . collect ( ) ,
181
- ) ;
182
- let black_box0 = ecx. expr_call (
183
- new_decl_span,
184
- blackbox_call_expr. clone ( ) ,
185
- thin_vec ! [ primal_call. clone( ) ] ,
186
- ) ;
169
+ //let black_box0 = ecx.expr_call(
170
+ // new_decl_span,
171
+ // blackbox_call_expr.clone(),
172
+ // thin_vec![primal_call.clone()],
173
+ //);
187
174
188
175
// create ::core::hint::black_box(grad_arr, tang_y));
189
176
let black_box1 = ecx. expr_call (
@@ -201,30 +188,54 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
201
188
thin_vec ! [ unsafe_block_with_zeroed_call. clone( ) ] ,
202
189
) ;
203
190
191
+ let primal_call = gen_primal_call ( ecx, span, primal, sig, idents) ;
192
+
204
193
let mut body = ecx. block ( span, ThinVec :: new ( ) ) ;
205
- // body.stmts.push(ecx.stmt_expr (primal_call));
194
+ body. stmts . push ( ecx. stmt_semi ( primal_call) ) ;
206
195
//body.stmts.push(ecx.stmt_expr(black_box0));
207
196
//body.stmts.push(ecx.stmt_expr(black_box1));
208
- body. stmts . push ( ecx. stmt_expr ( black_box2) ) ;
197
+ // body.stmts.push(ecx.stmt_expr(black_box2));
209
198
body. stmts . push ( ecx. stmt_expr ( loop_expr) ) ;
210
199
body
211
200
}
212
201
202
+ fn gen_primal_call ( ecx : & ExtCtxt < ' _ > , span : Span , primal : Ident , sig : & ast:: FnSig , idents : Vec < Ident > ) -> P < ast:: Expr > {
203
+ //pub struct Param {
204
+ // pub attrs: AttrVec,
205
+ // pub ty: P<Ty>,
206
+ // pub pat: P<Pat>,
207
+ // pub id: NodeId,
208
+ // pub span: Span,
209
+ // pub is_placeholder: bool,
210
+ //}
211
+ let primal_call_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
212
+ let args = idents. iter ( ) . map ( |arg| {
213
+ ecx. expr_path ( ecx. path_ident ( span, * arg) )
214
+ } ) . collect ( ) ;
215
+ let primal_call = ecx. expr_call (
216
+ span,
217
+ primal_call_expr,
218
+ args,
219
+ ) ;
220
+ primal_call
221
+ }
222
+
213
223
// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
214
224
// be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
215
225
// Active arguments must be scalars. Their shadow argument is added to the return type (and will be
216
226
// zero-initialized by Enzyme). Active arguments are not handled yet.
217
227
// Each argument of the primal function (and the return type if existing) must be annotated with an
218
228
// activity.
219
229
fn gen_enzyme_decl ( _ecx : & ExtCtxt < ' _ > , sig : & ast:: FnSig , x : & AutoDiffAttrs , span : Span , _sig_span : Span )
220
- -> ( ast:: FnSig , Vec < String > , Vec < String > ) {
230
+ -> ( ast:: FnSig , Vec < String > , Vec < String > , Vec < Ident > ) {
221
231
let decl: P < ast:: FnDecl > = sig. decl . clone ( ) ;
222
232
assert ! ( decl. inputs. len( ) == x. input_activity. len( ) ) ;
223
233
assert ! ( decl. output. has_ret( ) == x. has_ret_activity( ) ) ;
224
234
let mut d_decl = decl. clone ( ) ;
225
235
let mut d_inputs = Vec :: new ( ) ;
226
236
let mut new_inputs = Vec :: new ( ) ;
227
237
let mut old_names = Vec :: new ( ) ;
238
+ let mut idents = Vec :: new ( ) ;
228
239
for ( arg, activity) in decl. inputs . iter ( ) . zip ( x. input_activity . iter ( ) ) {
229
240
//dbg!(&arg);
230
241
d_inputs. push ( arg. clone ( ) ) ;
@@ -234,6 +245,7 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span
234
245
shadow_arg. ty = P ( assure_mut_ref ( & arg. ty ) ) ;
235
246
// adjust name depending on mode
236
247
let old_name = if let PatKind :: Ident ( _, ident, _) = shadow_arg. pat . kind {
248
+ idents. push ( ident. clone ( ) ) ;
237
249
ident. name
238
250
} else {
239
251
dbg ! ( & shadow_arg. pat) ;
@@ -247,17 +259,18 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span
247
259
} ;
248
260
dbg ! ( & name) ;
249
261
new_inputs. push ( name. clone ( ) ) ;
262
+ let ident = Ident :: from_str_and_span ( & name, shadow_arg. pat . span ) ;
250
263
shadow_arg. pat = P ( ast:: Pat {
251
264
// TODO: Check id
252
265
id : ast:: DUMMY_NODE_ID ,
253
266
kind : PatKind :: Ident ( BindingAnnotation :: NONE ,
254
- Ident :: from_str_and_span ( & name , shadow_arg . pat . span ) ,
267
+ ident ,
255
268
None ,
256
269
) ,
257
270
span : shadow_arg. pat . span ,
258
271
tokens : shadow_arg. pat . tokens . clone ( ) ,
259
272
} ) ;
260
-
273
+ //idents.push(ident);
261
274
d_inputs. push ( shadow_arg) ;
262
275
}
263
276
_ => { } ,
@@ -269,5 +282,5 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, sig: &ast::FnSig, x: &AutoDiffAttrs, span
269
282
decl : d_decl,
270
283
span,
271
284
} ;
272
- ( d_sig, old_names, new_inputs)
285
+ ( d_sig, old_names, new_inputs, idents )
273
286
}
0 commit comments