1
1
#![ allow( unused_imports) ]
2
- #![ allow( unused_variables) ]
3
- #![ allow( unused_mut) ]
4
2
//use crate::util::check_builtin_macro_attribute;
5
3
//use crate::util::check_autodiff;
6
4
@@ -20,12 +18,25 @@ use rustc_span::Symbol;
20
18
use std:: string:: String ;
21
19
use thin_vec:: { thin_vec, ThinVec } ;
22
20
21
+ #[ cfg( llvm_enzyme) ]
23
22
fn first_ident ( x : & NestedMetaItem ) -> rustc_span:: symbol:: Ident {
24
23
let segments = & x. meta_item ( ) . unwrap ( ) . path . segments ;
25
24
assert ! ( segments. len( ) == 1 ) ;
26
25
segments[ 0 ] . ident
27
26
}
28
27
28
+ #[ cfg( not( llvm_enzyme) ) ]
29
+ pub fn expand (
30
+ ecx : & mut ExtCtxt < ' _ > ,
31
+ _expand_span : Span ,
32
+ meta_item : & ast:: MetaItem ,
33
+ item : Annotatable ,
34
+ ) -> Vec < Annotatable > {
35
+ ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffSupportNotBuild { span : meta_item. span } ) ;
36
+ return vec ! [ item] ;
37
+ }
38
+
39
+ #[ cfg( llvm_enzyme) ]
29
40
pub fn expand (
30
41
ecx : & mut ExtCtxt < ' _ > ,
31
42
expand_span : Span ,
@@ -45,24 +56,16 @@ pub fn expand(
45
56
let primal = orig_item. ident . clone ( ) ;
46
57
47
58
// Allow using `#[autodiff(...)]` only on a Fn
48
- let ( fn_item , has_ret, sig, sig_span) = if let Annotatable :: Item ( item) = & item
59
+ let ( has_ret, sig, sig_span) = if let Annotatable :: Item ( item) = & item
49
60
&& let ItemKind :: Fn ( box ast:: Fn { sig, .. } ) = & item. kind
50
61
{
51
- ( item , sig. decl . output . has_ret ( ) , sig, ecx. with_call_site_ctxt ( sig. span ) )
62
+ ( sig. decl . output . has_ret ( ) , sig, ecx. with_call_site_ctxt ( sig. span ) )
52
63
} else {
53
64
ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
54
65
return vec ! [ item] ;
55
66
} ;
56
67
// create TokenStream from vec elemtents:
57
68
// meta_item doesn't have a .tokens field
58
- let ts: Vec < Token > = meta_item_vec. clone ( ) [ 1 ..]
59
- . iter ( )
60
- . map ( |x| {
61
- let val = first_ident ( x) ;
62
- let t = Token :: from_ast_ident ( val) ;
63
- t
64
- } )
65
- . collect ( ) ;
66
69
let comma: Token = Token :: new ( TokenKind :: Comma , Span :: default ( ) ) ;
67
70
let mut ts: Vec < TokenTree > = vec ! [ ] ;
68
71
for t in meta_item_vec. clone ( ) [ 1 ..] . iter ( ) {
@@ -77,18 +80,15 @@ pub fn expand(
77
80
dbg ! ( & x) ;
78
81
let span = ecx. with_def_site_ctxt ( expand_span) ;
79
82
80
- let ( d_sig, old_names , new_args, idents) = gen_enzyme_decl ( & sig, & x, span) ;
83
+ let ( d_sig, new_args, idents) = gen_enzyme_decl ( & sig, & x, span) ;
81
84
let new_decl_span = d_sig. span ;
82
85
let d_body = gen_enzyme_body (
83
86
ecx,
84
87
primal,
85
- & old_names,
86
88
& new_args,
87
89
span,
88
90
sig_span,
89
91
new_decl_span,
90
- & sig,
91
- & d_sig,
92
92
idents,
93
93
) ;
94
94
let d_ident = meta_item_vec[ 0 ] . meta_item ( ) . unwrap ( ) . path . segments [ 0 ] . ident ;
@@ -102,7 +102,7 @@ pub fn expand(
102
102
} ) ) ;
103
103
let mut rustc_ad_attr =
104
104
P ( ast:: NormalAttr :: from_ident ( Ident :: with_dummy_span ( sym:: rustc_autodiff) ) ) ;
105
- let mut attr: ast:: Attribute = ast:: Attribute {
105
+ let attr: ast:: Attribute = ast:: Attribute {
106
106
kind : ast:: AttrKind :: Normal ( rustc_ad_attr. clone ( ) ) ,
107
107
id : ast:: AttrId :: from_u32 ( 0 ) ,
108
108
style : ast:: AttrStyle :: Outer ,
@@ -116,7 +116,7 @@ pub fn expand(
116
116
delim : rustc_ast:: token:: Delimiter :: Parenthesis ,
117
117
tokens : ts,
118
118
} ) ;
119
- let mut attr2: ast:: Attribute = ast:: Attribute {
119
+ let attr2: ast:: Attribute = ast:: Attribute {
120
120
kind : ast:: AttrKind :: Normal ( rustc_ad_attr) ,
121
121
id : ast:: AttrId :: from_u32 ( 0 ) ,
122
122
style : ast:: AttrStyle :: Outer ,
@@ -131,6 +131,7 @@ pub fn expand(
131
131
}
132
132
133
133
// shadow arguments must be mutable references or ptrs, because Enzyme will write into them.
134
+ #[ cfg( llvm_enzyme) ]
134
135
fn assure_mut_ref ( ty : & ast:: Ty ) -> ast:: Ty {
135
136
let mut ty = ty. clone ( ) ;
136
137
match ty. kind {
@@ -152,37 +153,21 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
152
153
// The second will just take a tuple containing the new arguments.
153
154
// This way we surpress rustc from optimizing any argument away.
154
155
// The last line will 'loop {}', to match the return type of the new function
156
+ #[ cfg( llvm_enzyme) ]
155
157
fn gen_enzyme_body (
156
158
ecx : & ExtCtxt < ' _ > ,
157
159
primal : Ident ,
158
- old_names : & [ String ] ,
159
160
new_names : & [ String ] ,
160
161
span : Span ,
161
162
sig_span : Span ,
162
163
new_decl_span : Span ,
163
- sig : & ast:: FnSig ,
164
- d_sig : & ast:: FnSig ,
165
164
idents : Vec < Ident > ,
166
165
) -> P < ast:: Block > {
167
166
let blackbox_path = ecx. std_path ( & [ Symbol :: intern ( "hint" ) , Symbol :: intern ( "black_box" ) ] ) ;
168
- let zeroed_path = ecx. std_path ( & [ Symbol :: intern ( "mem" ) , Symbol :: intern ( "zeroed" ) ] ) ;
169
167
let empty_loop_block = ecx. block ( span, ThinVec :: new ( ) ) ;
170
168
let loop_expr = ecx. expr_loop ( span, empty_loop_block) ;
171
-
172
169
let blackbox_call_expr = ecx. expr_path ( ecx. path ( span, blackbox_path) ) ;
173
- let zeroed_call_expr = ecx. expr_path ( ecx. path ( span, zeroed_path) ) ;
174
-
175
- let mem_zeroed_call: Stmt =
176
- ecx. stmt_expr ( ecx. expr_call ( span, zeroed_call_expr. clone ( ) , thin_vec ! [ ] ) ) ;
177
- let unsafe_block_with_zeroed_call: P < ast:: Expr > = ecx. expr_block ( P ( ast:: Block {
178
- stmts : thin_vec ! [ mem_zeroed_call] ,
179
- id : ast:: DUMMY_NODE_ID ,
180
- rules : ast:: BlockCheckMode :: Unsafe ( ast:: UserProvided ) ,
181
- span : sig_span,
182
- tokens : None ,
183
- could_be_bare_literal : false ,
184
- } ) ) ;
185
- let primal_call = gen_primal_call ( ecx, span, primal, sig, idents) ;
170
+ let primal_call = gen_primal_call ( ecx, span, primal, idents) ;
186
171
// create ::core::hint::black_box(array(arr));
187
172
let black_box_primal_call =
188
173
ecx. expr_call ( new_decl_span, blackbox_call_expr. clone ( ) , thin_vec ! [ primal_call. clone( ) ] ) ;
@@ -207,11 +192,11 @@ fn gen_enzyme_body(
207
192
body
208
193
}
209
194
195
+ #[ cfg( llvm_enzyme) ]
210
196
fn gen_primal_call (
211
197
ecx : & ExtCtxt < ' _ > ,
212
198
span : Span ,
213
199
primal : Ident ,
214
- sig : & ast:: FnSig ,
215
200
idents : Vec < Ident > ,
216
201
) -> P < ast:: Expr > {
217
202
let primal_call_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
@@ -226,17 +211,18 @@ fn gen_primal_call(
226
211
// zero-initialized by Enzyme). Active arguments are not handled yet.
227
212
// Each argument of the primal function (and the return type if existing) must be annotated with an
228
213
// activity.
214
+ #[ cfg( llvm_enzyme) ]
229
215
fn gen_enzyme_decl (
230
216
sig : & ast:: FnSig ,
231
217
x : & AutoDiffAttrs ,
232
218
span : Span ,
233
- ) -> ( ast:: FnSig , Vec < String > , Vec < String > , Vec < Ident > ) {
219
+ ) -> ( ast:: FnSig , Vec < String > , Vec < Ident > ) {
234
220
assert ! ( sig. decl. inputs. len( ) == x. input_activity. len( ) ) ;
235
221
assert ! ( sig. decl. output. has_ret( ) == x. has_ret_activity( ) ) ;
236
222
let mut d_decl = sig. decl . clone ( ) ;
237
223
let mut d_inputs = Vec :: new ( ) ;
238
224
let mut new_inputs = Vec :: new ( ) ;
239
- let mut old_names = Vec :: new ( ) ;
225
+ // let mut old_names = Vec::new();
240
226
let mut idents = Vec :: new ( ) ;
241
227
let mut act_ret = ThinVec :: new ( ) ;
242
228
for ( arg, activity) in sig. decl . inputs . iter ( ) . zip ( x. input_activity . iter ( ) ) {
@@ -256,7 +242,7 @@ fn gen_enzyme_decl(
256
242
dbg ! ( & shadow_arg. pat) ;
257
243
panic ! ( "not an ident?" ) ;
258
244
} ;
259
- old_names. push ( old_name. to_string ( ) ) ;
245
+ // old_names.push(old_name.to_string());
260
246
let name: String = match x. mode {
261
247
DiffMode :: Reverse => format ! ( "d{}" , old_name) ,
262
248
DiffMode :: Forward => format ! ( "b{}" , old_name) ,
@@ -320,7 +306,7 @@ fn gen_enzyme_decl(
320
306
// return type. This might require changing the return type to a
321
307
// tuple.
322
308
if act_ret. len ( ) > 0 {
323
- let mut ret_ty = match d_decl. output {
309
+ let ret_ty = match d_decl. output {
324
310
FnRetTy :: Ty ( ref ty) => {
325
311
act_ret. insert ( 0 , ty. clone ( ) ) ;
326
312
let kind = TyKind :: Tup ( act_ret) ;
@@ -339,5 +325,5 @@ fn gen_enzyme_decl(
339
325
}
340
326
341
327
let d_sig = FnSig { header : sig. header . clone ( ) , decl : d_decl, span } ;
342
- ( d_sig, old_names , new_inputs, idents)
328
+ ( d_sig, new_inputs, idents)
343
329
}
0 commit comments