@@ -19,6 +19,8 @@ use std::string::String;
19
19
use thin_vec:: { thin_vec, ThinVec } ;
20
20
use std:: str:: FromStr ;
21
21
22
+ use rustc_ast:: AssocItemKind ;
23
+
22
24
#[ cfg( not( llvm_enzyme) ) ]
23
25
pub fn expand (
24
26
ecx : & mut ExtCtxt < ' _ > ,
@@ -82,30 +84,62 @@ pub fn expand(
82
84
ecx : & mut ExtCtxt < ' _ > ,
83
85
expand_span : Span ,
84
86
meta_item : & ast:: MetaItem ,
85
- item : Annotatable ,
87
+ mut item : Annotatable ,
86
88
) -> Vec < Annotatable > {
87
89
//check_builtin_macro_attribute(ecx, meta_item, sym::alloc_error_handler);
88
90
91
+ // first get the annotable item:
92
+ let ( sig, is_impl) : ( FnSig , bool ) = match & item {
93
+ Annotatable :: Item ( ref iitem) => {
94
+ let sig = match & iitem. kind {
95
+ ItemKind :: Fn ( box ast:: Fn { sig, .. } ) => sig,
96
+ _ => {
97
+ ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
98
+ return vec ! [ item] ;
99
+ }
100
+ } ;
101
+ ( sig. clone ( ) , false )
102
+ } ,
103
+ Annotatable :: ImplItem ( ref assoc_item) => {
104
+ let sig = match & assoc_item. kind {
105
+ ast:: AssocItemKind :: Fn ( box ast:: Fn { sig, .. } ) => sig,
106
+ _ => {
107
+ ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
108
+ return vec ! [ item] ;
109
+ }
110
+ } ;
111
+ ( sig. clone ( ) , true )
112
+ } ,
113
+ _ => {
114
+ dbg ! ( & item) ;
115
+ ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
116
+ return vec ! [ item] ;
117
+ }
118
+ } ;
119
+
89
120
let meta_item_vec: ThinVec < NestedMetaItem > = match meta_item. kind {
90
121
ast:: MetaItemKind :: List ( ref vec) => vec. clone ( ) ,
91
122
_ => {
92
123
ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
93
124
return vec ! [ item] ;
94
125
}
95
126
} ;
96
- // Allow using `#[autodiff(...)]` only on a Fn
97
- let ( has_ret, sig, sig_span) = if let Annotatable :: Item ( item) = & item
98
- && let ItemKind :: Fn ( box ast:: Fn { sig, .. } ) = & item. kind
99
- {
100
- ( sig. decl . output . has_ret ( ) , sig, ecx. with_call_site_ctxt ( sig. span ) )
101
- } else {
102
- ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
103
- return vec ! [ item] ;
104
- } ;
105
127
106
- // Now we know that item is a Item::Fn
107
- let mut orig_item: P < ast:: Item > = item. clone ( ) . expect_item ( ) ;
108
- let primal = orig_item. ident . clone ( ) ;
128
+ let has_ret = sig. decl . output . has_ret ( ) ;
129
+ let sig_span = ecx. with_call_site_ctxt ( sig. span ) ;
130
+
131
+ let ( vis, primal) = match & item {
132
+ Annotatable :: Item ( ref iitem) => {
133
+ ( iitem. vis . clone ( ) , iitem. ident . clone ( ) )
134
+ } ,
135
+ Annotatable :: ImplItem ( ref assoc_item) => {
136
+ ( assoc_item. vis . clone ( ) , assoc_item. ident . clone ( ) )
137
+ } ,
138
+ _ => {
139
+ ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
140
+ return vec ! [ item] ;
141
+ }
142
+ } ;
109
143
110
144
// create TokenStream from vec elemtents:
111
145
// meta_item doesn't have a .tokens field
@@ -154,12 +188,12 @@ pub fn expand(
154
188
let d_ident = first_ident ( & meta_item_vec[ 0 ] ) ;
155
189
156
190
// The first element of it is the name of the function to be generated
157
- let asdf = ItemKind :: Fn ( Box :: new ( ast:: Fn {
191
+ let asdf = Box :: new ( ast:: Fn {
158
192
defaultness : ast:: Defaultness :: Final ,
159
193
sig : d_sig,
160
194
generics : Generics :: default ( ) ,
161
195
body : Some ( d_body) ,
162
- } ) ) ;
196
+ } ) ;
163
197
let mut rustc_ad_attr =
164
198
P ( ast:: NormalAttr :: from_ident ( Ident :: with_dummy_span ( sym:: rustc_autodiff) ) ) ;
165
199
let ts2: Vec < TokenTree > = vec ! [
@@ -195,13 +229,6 @@ pub fn expand(
195
229
style : ast:: AttrStyle :: Outer ,
196
230
span,
197
231
} ;
198
- // don't add it multiple times:
199
- if !orig_item. attrs . iter ( ) . any ( |a| a. id == attr. id ) {
200
- orig_item. attrs . push ( attr. clone ( ) ) ;
201
- }
202
- if !orig_item. attrs . iter ( ) . any ( |a| a. id == inline_never. id ) {
203
- orig_item. attrs . push ( inline_never) ;
204
- }
205
232
206
233
// Now update for d_fn
207
234
rustc_ad_attr. item . args = rustc_ast:: AttrArgs :: Delimited ( rustc_ast:: DelimArgs {
@@ -210,13 +237,51 @@ pub fn expand(
210
237
tokens : ts,
211
238
} ) ;
212
239
attr. kind = ast:: AttrKind :: Normal ( rustc_ad_attr) ;
213
- let mut d_fn = ecx. item ( span, d_ident, thin_vec ! [ attr] , asdf) ;
214
240
215
- // Copy visibility from original function
216
- d_fn. vis = orig_item. vis . clone ( ) ;
241
+ // Don't add it multiple times:
242
+ let orig_annotatable: Annotatable = match item {
243
+ Annotatable :: Item ( ref mut iitem) => {
244
+ if !iitem. attrs . iter ( ) . any ( |a| a. id == attr. id ) {
245
+ iitem. attrs . push ( attr. clone ( ) ) ;
246
+ }
247
+ if !iitem. attrs . iter ( ) . any ( |a| a. id == inline_never. id ) {
248
+ iitem. attrs . push ( inline_never. clone ( ) ) ;
249
+ }
250
+ Annotatable :: Item ( iitem. clone ( ) )
251
+ } ,
252
+ Annotatable :: ImplItem ( ref mut assoc_item) => {
253
+ if !assoc_item. attrs . iter ( ) . any ( |a| a. id == attr. id ) {
254
+ assoc_item. attrs . push ( attr. clone ( ) ) ;
255
+ }
256
+ if !assoc_item. attrs . iter ( ) . any ( |a| a. id == inline_never. id ) {
257
+ assoc_item. attrs . push ( inline_never. clone ( ) ) ;
258
+ }
259
+ Annotatable :: ImplItem ( assoc_item. clone ( ) )
260
+ } ,
261
+ _ => {
262
+ panic ! ( "not supported" ) ;
263
+ }
264
+ } ;
265
+
266
+ let d_annotatable = if is_impl {
267
+ let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
268
+ let d_fn = P ( ast:: AssocItem {
269
+ attrs : thin_vec ! [ attr. clone( ) , inline_never] ,
270
+ id : ast:: DUMMY_NODE_ID ,
271
+ span,
272
+ vis,
273
+ ident : d_ident,
274
+ kind : assoc_item,
275
+ tokens : None ,
276
+ } ) ;
277
+ Annotatable :: ImplItem ( d_fn)
278
+ } else {
279
+ let mut d_fn = ecx. item ( span, d_ident, thin_vec ! [ attr. clone( ) ] , ItemKind :: Fn ( asdf) ) ;
280
+ d_fn. vis = vis;
281
+ Annotatable :: Item ( d_fn)
282
+ } ;
283
+ trace ! ( "Generated function: {:?}" , d_annotatable) ;
217
284
218
- let orig_annotatable = Annotatable :: Item ( orig_item) ;
219
- let d_annotatable = Annotatable :: Item ( d_fn) ;
220
285
return vec ! [ orig_annotatable, d_annotatable] ;
221
286
}
222
287
@@ -403,10 +468,16 @@ fn gen_primal_call(
403
468
primal : Ident ,
404
469
idents : Vec < Ident > ,
405
470
) -> P < ast:: Expr > {
406
- let primal_call_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
407
- let args = idents. iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
408
- let primal_call = ecx. expr_call ( span, primal_call_expr, args) ;
409
- primal_call
471
+ let has_self = idents. len ( ) > 0 && idents[ 0 ] . name == kw:: SelfLower ;
472
+ if has_self {
473
+ let args: ThinVec < _ > = idents[ 1 ..] . iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
474
+ let self_expr = ecx. expr_self ( span) ;
475
+ ecx. expr_method_call ( span, self_expr, primal, args. clone ( ) )
476
+ } else {
477
+ let args: ThinVec < _ > = idents. iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
478
+ let primal_call_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
479
+ ecx. expr_call ( span, primal_call_expr, args)
480
+ }
410
481
}
411
482
412
483
// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
@@ -427,7 +498,6 @@ fn gen_enzyme_decl(
427
498
let mut d_decl = sig. decl . clone ( ) ;
428
499
let mut d_inputs = Vec :: new ( ) ;
429
500
let mut new_inputs = Vec :: new ( ) ;
430
- //let mut old_names = Vec::new();
431
501
let mut idents = Vec :: new ( ) ;
432
502
let mut act_ret = ThinVec :: new ( ) ;
433
503
for ( arg, activity) in sig. decl . inputs . iter ( ) . zip ( x. input_activity . iter ( ) ) {
0 commit comments