1
1
#![ allow( unused_imports) ]
2
+ #![ allow( unused_variables) ]
2
3
//use crate::util::check_builtin_macro_attribute;
3
4
//use crate::util::check_autodiff;
4
5
@@ -32,9 +33,7 @@ pub fn expand(
32
33
return vec ! [ item] ;
33
34
}
34
35
} ;
35
- let input = item. clone ( ) ;
36
36
let orig_item: P < ast:: Item > = item. clone ( ) . expect_item ( ) ;
37
- let mut d_item: P < ast:: Item > = item. clone ( ) . expect_item ( ) ;
38
37
let primal = orig_item. ident . clone ( ) ;
39
38
40
39
// Allow using `#[autodiff(...)]` only on a Fn
@@ -46,29 +45,27 @@ pub fn expand(
46
45
ecx. sess
47
46
. dcx ( )
48
47
. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
49
- return vec ! [ input ] ;
48
+ return vec ! [ item ] ;
50
49
} ;
51
50
let x: AutoDiffAttrs = AutoDiffAttrs :: from_ast ( & meta_item_vec, has_ret) ;
52
51
dbg ! ( & x) ;
53
52
let span = ecx. with_def_site_ctxt ( fn_item. span ) ;
54
53
55
- let ( d_decl , old_names, new_args) = gen_enzyme_decl ( ecx, & sig. decl , & x, span, sig_span) ;
54
+ let ( d_sig , old_names, new_args) = gen_enzyme_decl ( ecx, & sig, & x, span, sig_span) ;
56
55
let d_body = gen_enzyme_body ( ecx, primal, & old_names, & new_args, span, sig_span) ;
57
- let meta_item_name = meta_item_vec[ 0 ] . meta_item ( ) . unwrap ( ) ;
58
- d_item. ident = meta_item_name. path . segments [ 0 ] . ident ;
59
- // update d_item
60
- if let ItemKind :: Fn ( box ast:: Fn { sig, body, .. } ) = & mut d_item. kind {
61
- * sig. decl = d_decl;
62
- * body = Some ( d_body) ;
63
- } else {
64
- ecx. sess
65
- . dcx ( )
66
- . emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
67
- return vec ! [ input] ;
68
- }
56
+ let d_ident = meta_item_vec[ 0 ] . meta_item ( ) . unwrap ( ) . path . segments [ 0 ] . ident ;
57
+
58
+ // The first element of it is the name of the function to be generated
59
+ let asdf = ItemKind :: Fn ( Box :: new ( ast:: Fn {
60
+ defaultness : ast:: Defaultness :: Final ,
61
+ sig : d_sig,
62
+ generics : Generics :: default ( ) ,
63
+ body : Some ( d_body) ,
64
+ } ) ) ;
65
+ let d_fn = ecx. item ( span, d_ident, rustc_ast:: AttrVec :: default ( ) , asdf) ;
69
66
70
67
let orig_annotatable = Annotatable :: Item ( orig_item. clone ( ) ) ;
71
- let d_annotatable = Annotatable :: Item ( d_item . clone ( ) ) ;
68
+ let d_annotatable = Annotatable :: Item ( d_fn ) ;
72
69
return vec ! [ orig_annotatable, d_annotatable] ;
73
70
}
74
71
@@ -98,6 +95,9 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
98
95
fn gen_enzyme_body ( ecx : & ExtCtxt < ' _ > , primal : Ident , old_names : & [ String ] , new_names : & [ String ] , span : Span , sig_span : Span ) -> P < ast:: Block > {
99
96
let blackbox_path = ecx. std_path ( & [ Symbol :: intern ( "hint" ) , Symbol :: intern ( "black_box" ) ] ) ;
100
97
let zeroed_path = ecx. std_path ( & [ Symbol :: intern ( "mem" ) , Symbol :: intern ( "zeroed" ) ] ) ;
98
+ let empty_loop_block = ecx. block ( span, ThinVec :: new ( ) ) ;
99
+ let loop_expr = ecx. expr_loop ( span, empty_loop_block) ;
100
+
101
101
102
102
let primal_call_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
103
103
let blackbox_call_expr = ecx. expr_path ( ecx. path ( span, blackbox_path) ) ;
@@ -117,13 +117,18 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
117
117
could_be_bare_literal : false ,
118
118
} ) ) ;
119
119
// create ::core::hint::black_box(array(arr));
120
- let _primal_call = ecx. expr_call (
120
+ let primal_call = ecx. expr_call (
121
121
span,
122
122
primal_call_expr. clone ( ) ,
123
123
old_names. iter ( ) . map ( |name| {
124
124
ecx. expr_path ( ecx. path_ident ( span, Ident :: from_str ( name) ) )
125
125
} ) . collect ( ) ,
126
126
) ;
127
+ let black_box0 = ecx. expr_call (
128
+ sig_span,
129
+ blackbox_call_expr. clone ( ) ,
130
+ thin_vec ! [ primal_call. clone( ) ] ,
131
+ ) ;
127
132
128
133
// create ::core::hint::black_box(grad_arr, tang_y));
129
134
let black_box1 = ecx. expr_call (
@@ -135,15 +140,18 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
135
140
) ;
136
141
137
142
// create ::core::hint::black_box(unsafe { ::core::mem::zeroed() })
138
- let black_box2 = ecx. expr_call (
143
+ let _black_box2 = ecx. expr_call (
139
144
sig_span,
140
145
blackbox_call_expr. clone ( ) ,
141
146
thin_vec ! [ unsafe_block_with_zeroed_call. clone( ) ] ,
142
147
) ;
143
148
144
149
let mut body = ecx. block ( span, ThinVec :: new ( ) ) ;
145
- body. stmts . push ( ecx. stmt_expr ( black_box1) ) ;
146
- body. stmts . push ( ecx. stmt_expr ( black_box2) ) ;
150
+ body. stmts . push ( ecx. stmt_expr ( primal_call) ) ;
151
+ //body.stmts.push(ecx.stmt_expr(black_box0));
152
+ //body.stmts.push(ecx.stmt_expr(black_box1));
153
+ //body.stmts.push(ecx.stmt_expr(black_box2));
154
+ body. stmts . push ( ecx. stmt_expr ( loop_expr) ) ;
147
155
body
148
156
}
149
157
@@ -153,16 +161,17 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
153
161
// zero-initialized by Enzyme). Active arguments are not handled yet.
154
162
// Each argument of the primal function (and the return type if existing) must be annotated with an
155
163
// activity.
156
- fn gen_enzyme_decl ( _ecx : & ExtCtxt < ' _ > , decl : & ast:: FnDecl , x : & AutoDiffAttrs , _span : Span , _sig_span : Span )
157
- -> ( ast:: FnDecl , Vec < String > , Vec < String > ) {
164
+ fn gen_enzyme_decl ( _ecx : & ExtCtxt < ' _ > , sig : & ast:: FnSig , x : & AutoDiffAttrs , span : Span , _sig_span : Span )
165
+ -> ( ast:: FnSig , Vec < String > , Vec < String > ) {
166
+ let decl: P < ast:: FnDecl > = sig. decl . clone ( ) ;
158
167
assert ! ( decl. inputs. len( ) == x. input_activity. len( ) ) ;
159
168
assert ! ( decl. output. has_ret( ) == x. has_ret_activity( ) ) ;
160
169
let mut d_decl = decl. clone ( ) ;
161
170
let mut d_inputs = Vec :: new ( ) ;
162
171
let mut new_inputs = Vec :: new ( ) ;
163
172
let mut old_names = Vec :: new ( ) ;
164
173
for ( arg, activity) in decl. inputs . iter ( ) . zip ( x. input_activity . iter ( ) ) {
165
- dbg ! ( & arg) ;
174
+ // dbg!(&arg);
166
175
d_inputs. push ( arg. clone ( ) ) ;
167
176
match activity {
168
177
DiffActivity :: Duplicated => {
@@ -200,5 +209,10 @@ fn gen_enzyme_decl(_ecx: &ExtCtxt<'_>, decl: &ast::FnDecl, x: &AutoDiffAttrs, _s
200
209
}
201
210
}
202
211
d_decl. inputs = d_inputs. into ( ) ;
203
- ( d_decl, old_names, new_inputs)
212
+ let d_sig = FnSig {
213
+ header : sig. header . clone ( ) ,
214
+ decl : d_decl,
215
+ span,
216
+ } ;
217
+ ( d_sig, old_names, new_inputs)
204
218
}
0 commit comments