@@ -91,7 +91,7 @@ pub fn expand(
91
91
new_decl_span,
92
92
idents,
93
93
) ;
94
- let d_ident = meta_item_vec[ 0 ] . meta_item ( ) . unwrap ( ) . path . segments [ 0 ] . ident ;
94
+ let d_ident = first_ident ( & meta_item_vec[ 0 ] ) ;
95
95
96
96
// The first element of it is the name of the function to be generated
97
97
let asdf = ItemKind :: Fn ( Box :: new ( ast:: Fn {
@@ -102,11 +102,12 @@ pub fn expand(
102
102
} ) ) ;
103
103
let mut rustc_ad_attr =
104
104
P ( ast:: NormalAttr :: from_ident ( Ident :: with_dummy_span ( sym:: rustc_autodiff) ) ) ;
105
- let attr: ast:: Attribute = ast:: Attribute {
105
+ let mut attr: ast:: Attribute = ast:: Attribute {
106
106
kind : ast:: AttrKind :: Normal ( rustc_ad_attr. clone ( ) ) ,
107
- id : ast:: AttrId :: from_u32 ( 0 ) ,
107
+ //id: ast::DUMMY_TR_ID,
108
+ id : ast:: AttrId :: from_u32 ( 12341 ) , // TODO: fix
108
109
style : ast:: AttrStyle :: Outer ,
109
- span : span ,
110
+ span,
110
111
} ;
111
112
orig_item. attrs . push ( attr. clone ( ) ) ;
112
113
@@ -116,21 +117,15 @@ pub fn expand(
116
117
delim : rustc_ast:: token:: Delimiter :: Parenthesis ,
117
118
tokens : ts,
118
119
} ) ;
119
- let attr2: ast:: Attribute = ast:: Attribute {
120
- kind : ast:: AttrKind :: Normal ( rustc_ad_attr) ,
121
- id : ast:: AttrId :: from_u32 ( 0 ) ,
122
- style : ast:: AttrStyle :: Outer ,
123
- span : span,
124
- } ;
125
- let attr_vec: rustc_ast:: AttrVec = thin_vec ! [ attr2] ;
126
- let d_fn = ecx. item ( span, d_ident, attr_vec, asdf) ;
120
+ attr. kind = ast:: AttrKind :: Normal ( rustc_ad_attr) ;
121
+ let d_fn = ecx. item ( span, d_ident, thin_vec ! [ attr] , asdf) ;
127
122
128
- let orig_annotatable = Annotatable :: Item ( orig_item. clone ( ) ) ;
123
+ let orig_annotatable = Annotatable :: Item ( orig_item) ;
129
124
let d_annotatable = Annotatable :: Item ( d_fn) ;
130
125
return vec ! [ orig_annotatable, d_annotatable] ;
131
126
}
132
127
133
- // shadow arguments must be mutable references or ptrs, because Enzyme will write into them.
128
+ // shadow arguments in reverse mode must be mutable references or ptrs, because Enzyme will write into them.
134
129
#[ cfg( llvm_enzyme) ]
135
130
fn assure_mut_ref ( ty : & ast:: Ty ) -> ast:: Ty {
136
131
let mut ty = ty. clone ( ) ;
@@ -165,6 +160,25 @@ fn gen_enzyme_body(
165
160
) -> P < ast:: Block > {
166
161
let blackbox_path = ecx. std_path ( & [ Symbol :: intern ( "hint" ) , Symbol :: intern ( "black_box" ) ] ) ;
167
162
let empty_loop_block = ecx. block ( span, ThinVec :: new ( ) ) ;
163
+ let noop = ast:: InlineAsm {
164
+ template : vec ! [ ast:: InlineAsmTemplatePiece :: String ( "NOP" . to_string( ) ) ] ,
165
+ template_strs : Box :: new ( [ ] ) ,
166
+ operands : vec ! [ ] ,
167
+ clobber_abis : vec ! [ ] ,
168
+ options : ast:: InlineAsmOptions :: PURE & ast:: InlineAsmOptions :: NOMEM ,
169
+ line_spans : vec ! [ ] ,
170
+ } ;
171
+ let noop_expr = ecx. expr_asm ( span, P ( noop) ) ;
172
+ let unsf = ast:: BlockCheckMode :: Unsafe ( ast:: UnsafeSource :: CompilerGenerated ) ;
173
+ let unsf_block = ast:: Block {
174
+ stmts : thin_vec ! [ ecx. stmt_semi( noop_expr) ] ,
175
+ id : ast:: DUMMY_NODE_ID ,
176
+ tokens : None ,
177
+ rules : unsf,
178
+ span,
179
+ could_be_bare_literal : false ,
180
+ } ;
181
+ let unsf_expr = ecx. expr_block ( P ( unsf_block) ) ;
168
182
let loop_expr = ecx. expr_loop ( span, empty_loop_block) ;
169
183
let blackbox_call_expr = ecx. expr_path ( ecx. path ( span, blackbox_path) ) ;
170
184
let primal_call = gen_primal_call ( ecx, span, primal, idents) ;
@@ -185,7 +199,7 @@ fn gen_enzyme_body(
185
199
) ;
186
200
187
201
let mut body = ecx. block ( span, ThinVec :: new ( ) ) ;
188
- body. stmts . push ( ecx. stmt_semi ( primal_call ) ) ;
202
+ body. stmts . push ( ecx. stmt_semi ( unsf_expr ) ) ;
189
203
body. stmts . push ( ecx. stmt_semi ( black_box_primal_call) ) ;
190
204
body. stmts . push ( ecx. stmt_semi ( black_box_remaining_args) ) ;
191
205
body. stmts . push ( ecx. stmt_expr ( loop_expr) ) ;
@@ -234,15 +248,18 @@ fn gen_enzyme_decl(
234
248
}
235
249
DiffActivity :: Duplicated | DiffActivity :: Dual => {
236
250
let mut shadow_arg = arg. clone ( ) ;
237
- shadow_arg. ty = P ( assure_mut_ref ( & arg. ty ) ) ;
251
+ // We += into the shadow in reverse mode.
252
+ // Otherwise copy mutability of the original argument.
253
+ if activity == & DiffActivity :: Duplicated {
254
+ shadow_arg. ty = P ( assure_mut_ref ( & arg. ty ) ) ;
255
+ }
238
256
// adjust name depending on mode
239
257
let old_name = if let PatKind :: Ident ( _, ident, _) = arg. pat . kind {
240
258
ident. name
241
259
} else {
242
260
dbg ! ( & shadow_arg. pat) ;
243
261
panic ! ( "not an ident?" ) ;
244
262
} ;
245
- //old_names.push(old_name.to_string());
246
263
let name: String = match x. mode {
247
264
DiffMode :: Reverse => format ! ( "d{}" , old_name) ,
248
265
DiffMode :: Forward => format ! ( "b{}" , old_name) ,
0 commit comments