@@ -18,7 +18,7 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
18
18
19
19
pub fn expand (
20
20
ecx : & mut ExtCtxt < ' _ > ,
21
- _span : Span ,
21
+ expand_span : Span ,
22
22
meta_item : & ast:: MetaItem ,
23
23
item : Annotatable ,
24
24
) -> Vec < Annotatable > {
@@ -40,7 +40,7 @@ pub fn expand(
40
40
let ( fn_item, has_ret, sig, sig_span) = if let Annotatable :: Item ( item) = & item
41
41
&& let ItemKind :: Fn ( box ast:: Fn { sig, .. } ) = & item. kind
42
42
{
43
- ( item, sig. decl . output . has_ret ( ) , sig, ecx. with_def_site_ctxt ( sig. span ) )
43
+ ( item, sig. decl . output . has_ret ( ) , sig, ecx. with_call_site_ctxt ( sig. span ) )
44
44
} else {
45
45
ecx. sess
46
46
. dcx ( )
@@ -49,10 +49,14 @@ pub fn expand(
49
49
} ;
50
50
let x: AutoDiffAttrs = AutoDiffAttrs :: from_ast ( & meta_item_vec, has_ret) ;
51
51
dbg ! ( & x) ;
52
- let span = ecx. with_def_site_ctxt ( fn_item. span ) ;
52
+ //let span = ecx.with_def_site_ctxt(sig_span);
53
+ let span = ecx. with_def_site_ctxt ( expand_span) ;
54
+ //let span = ecx.with_def_site_ctxt(fn_item.span);
53
55
54
56
let ( d_sig, old_names, new_args) = gen_enzyme_decl ( ecx, & sig, & x, span, sig_span) ;
55
- let d_body = gen_enzyme_body ( ecx, primal, & old_names, & new_args, span, sig_span) ;
57
+ let new_decl_span = d_sig. span ;
58
+ //let d_body = gen_enzyme_body(ecx, primal, &old_names, &new_args, span, span);
59
+ let d_body = gen_enzyme_body ( ecx, primal, & old_names, & new_args, span, sig_span, new_decl_span) ;
56
60
let d_ident = meta_item_vec[ 0 ] . meta_item ( ) . unwrap ( ) . path . segments [ 0 ] . ident ;
57
61
58
62
// The first element of it is the name of the function to be generated
@@ -92,7 +96,7 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
92
96
// The second will just take the shadow arguments.
93
97
// The third will (unsafely) call std::mem::zeroed(), to match the return type of the new function
94
98
// (whatever that might be). This way we surpress rustc from optimizing anyt argument away.
95
- fn gen_enzyme_body ( ecx : & ExtCtxt < ' _ > , primal : Ident , old_names : & [ String ] , new_names : & [ String ] , span : Span , sig_span : Span ) -> P < ast:: Block > {
99
+ fn gen_enzyme_body ( ecx : & ExtCtxt < ' _ > , primal : Ident , old_names : & [ String ] , new_names : & [ String ] , span : Span , sig_span : Span , new_decl_span : Span ) -> P < ast:: Block > {
96
100
let blackbox_path = ecx. std_path ( & [ Symbol :: intern ( "hint" ) , Symbol :: intern ( "black_box" ) ] ) ;
97
101
let zeroed_path = ecx. std_path ( & [ Symbol :: intern ( "mem" ) , Symbol :: intern ( "zeroed" ) ] ) ;
98
102
let empty_loop_block = ecx. block ( span, ThinVec :: new ( ) ) ;
@@ -118,14 +122,14 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
118
122
} ) ) ;
119
123
// create ::core::hint::black_box(array(arr));
120
124
let primal_call = ecx. expr_call (
121
- span ,
122
- primal_call_expr. clone ( ) ,
125
+ new_decl_span ,
126
+ primal_call_expr,
123
127
old_names. iter ( ) . map ( |name| {
124
- ecx. expr_path ( ecx. path_ident ( span , Ident :: from_str ( name) ) )
128
+ ecx. expr_path ( ecx. path_ident ( new_decl_span , Ident :: from_str ( name) ) )
125
129
} ) . collect ( ) ,
126
130
) ;
127
131
let black_box0 = ecx. expr_call (
128
- sig_span ,
132
+ new_decl_span ,
129
133
blackbox_call_expr. clone ( ) ,
130
134
thin_vec ! [ primal_call. clone( ) ] ,
131
135
) ;
@@ -140,17 +144,17 @@ fn gen_enzyme_body(ecx: &ExtCtxt<'_>, primal: Ident, old_names: &[String], new_n
140
144
) ;
141
145
142
146
// create ::core::hint::black_box(unsafe { ::core::mem::zeroed() })
143
- let _black_box2 = ecx. expr_call (
147
+ let black_box2 = ecx. expr_call (
144
148
sig_span,
145
149
blackbox_call_expr. clone ( ) ,
146
150
thin_vec ! [ unsafe_block_with_zeroed_call. clone( ) ] ,
147
151
) ;
148
152
149
153
let mut body = ecx. block ( span, ThinVec :: new ( ) ) ;
150
- body. stmts . push ( ecx. stmt_expr ( primal_call) ) ;
154
+ // body.stmts.push(ecx.stmt_expr(primal_call));
151
155
//body.stmts.push(ecx.stmt_expr(black_box0));
152
156
//body.stmts.push(ecx.stmt_expr(black_box1));
153
- // body.stmts.push(ecx.stmt_expr(black_box2));
157
+ body. stmts . push ( ecx. stmt_expr ( black_box2) ) ;
154
158
body. stmts . push ( ecx. stmt_expr ( loop_expr) ) ;
155
159
body
156
160
}
0 commit comments