3
3
//use crate::util::check_autodiff;
4
4
5
5
use crate :: errors;
6
- use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity , DiffMode } ;
6
+ use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity , DiffMode , valid_input_activity } ;
7
7
use rustc_ast:: ptr:: P ;
8
8
use rustc_ast:: token:: { Token , TokenKind } ;
9
9
use rustc_ast:: tokenstream:: * ;
@@ -80,7 +80,7 @@ pub fn expand(
80
80
dbg ! ( & x) ;
81
81
let span = ecx. with_def_site_ctxt ( expand_span) ;
82
82
83
- let ( d_sig, new_args, idents) = gen_enzyme_decl ( & sig, & x, span) ;
83
+ let ( d_sig, new_args, idents) = gen_enzyme_decl ( ecx , & sig, & x, span) ;
84
84
let new_decl_span = d_sig. span ;
85
85
let d_body = gen_enzyme_body (
86
86
ecx,
@@ -175,6 +175,26 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
175
175
ty
176
176
}
177
177
178
+ // TODO We should make this more robust to also
179
+ // accept aliases of f32 and f64
180
+ #[ cfg( llvm_enzyme) ]
181
+ fn is_float ( ty : & ast:: Ty ) -> bool {
182
+ match ty. kind {
183
+ TyKind :: Path ( _, ref path) => {
184
+ let last = path. segments . last ( ) . unwrap ( ) ;
185
+ last. ident . name == sym:: f32 || last. ident . name == sym:: f64
186
+ }
187
+ _ => false ,
188
+ }
189
+ }
190
+ #[ cfg( llvm_enzyme) ]
191
+ fn is_ptr_or_ref ( ty : & ast:: Ty ) -> bool {
192
+ match ty. kind {
193
+ TyKind :: Ptr ( _) | TyKind :: Ref ( _, _) => true ,
194
+ _ => false ,
195
+ }
196
+ }
197
+
178
198
// The body of our generated functions will consist of two black_Box calls.
179
199
// The first will call the primal function with the original arguments.
180
200
// The second will just take a tuple containing the new arguments.
@@ -259,6 +279,7 @@ fn gen_primal_call(
259
279
// activity.
260
280
#[ cfg( llvm_enzyme) ]
261
281
fn gen_enzyme_decl (
282
+ ecx : & ExtCtxt < ' _ > ,
262
283
sig : & ast:: FnSig ,
263
284
x : & AutoDiffAttrs ,
264
285
span : Span ,
@@ -273,31 +294,50 @@ fn gen_enzyme_decl(
273
294
let mut act_ret = ThinVec :: new ( ) ;
274
295
for ( arg, activity) in sig. decl . inputs . iter ( ) . zip ( x. input_activity . iter ( ) ) {
275
296
d_inputs. push ( arg. clone ( ) ) ;
297
+ if !valid_input_activity ( x. mode , * activity) {
298
+ ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffInvalidApplicationModeAct {
299
+ span,
300
+ mode : x. mode . to_string ( ) ,
301
+ act : activity. to_string ( )
302
+ } ) ;
303
+ }
276
304
match activity {
277
305
DiffActivity :: Active => {
278
- assert ! ( x . mode == DiffMode :: Reverse ) ;
306
+ assert ! ( is_float ( & arg . ty ) ) ;
279
307
act_ret. push ( arg. ty . clone ( ) ) ;
280
308
}
281
- DiffActivity :: Duplicated | DiffActivity :: Dual => {
309
+ DiffActivity :: Duplicated => {
310
+ assert ! ( is_ptr_or_ref( & arg. ty) ) ;
282
311
let mut shadow_arg = arg. clone ( ) ;
283
312
// We += into the shadow in reverse mode.
284
- // Otherwise copy mutability of the original argument.
285
- if activity == & DiffActivity :: Duplicated {
286
- shadow_arg. ty = P ( assure_mut_ref ( & arg. ty ) ) ;
287
- }
288
- // adjust name depending on mode
313
+ shadow_arg. ty = P ( assure_mut_ref ( & arg. ty ) ) ;
289
314
let old_name = if let PatKind :: Ident ( _, ident, _) = arg. pat . kind {
290
315
ident. name
291
316
} else {
292
317
dbg ! ( & shadow_arg. pat) ;
293
318
panic ! ( "not an ident?" ) ;
294
319
} ;
295
- let name: String = match x. mode {
296
- DiffMode :: Reverse => format ! ( "d{}" , old_name) ,
297
- DiffMode :: Forward => format ! ( "b{}" , old_name) ,
298
- _ => panic ! ( "unsupported mode: {}" , old_name) ,
320
+ let name: String = format ! ( "d{}" , old_name) ;
321
+ new_inputs. push ( name. clone ( ) ) ;
322
+ let ident = Ident :: from_str_and_span ( & name, shadow_arg. pat . span ) ;
323
+ shadow_arg. pat = P ( ast:: Pat {
324
+ // TODO: Check id
325
+ id : ast:: DUMMY_NODE_ID ,
326
+ kind : PatKind :: Ident ( BindingAnnotation :: NONE , ident, None ) ,
327
+ span : shadow_arg. pat . span ,
328
+ tokens : shadow_arg. pat . tokens . clone ( ) ,
329
+ } ) ;
330
+ d_inputs. push ( shadow_arg) ;
331
+ }
332
+ DiffActivity :: Dual => {
333
+ let mut shadow_arg = arg. clone ( ) ;
334
+ let old_name = if let PatKind :: Ident ( _, ident, _) = arg. pat . kind {
335
+ ident. name
336
+ } else {
337
+ dbg ! ( & shadow_arg. pat) ;
338
+ panic ! ( "not an ident?" ) ;
299
339
} ;
300
- dbg ! ( & name ) ;
340
+ let name : String = format ! ( "b{}" , old_name ) ;
301
341
new_inputs. push ( name. clone ( ) ) ;
302
342
let ident = Ident :: from_str_and_span ( & name, shadow_arg. pat . span ) ;
303
343
shadow_arg. pat = P ( ast:: Pat {
@@ -311,6 +351,7 @@ fn gen_enzyme_decl(
311
351
}
312
352
_ => {
313
353
dbg ! ( & activity) ;
354
+ panic ! ( "Not implemented" ) ;
314
355
}
315
356
}
316
357
if let PatKind :: Ident ( _, ident, _) = arg. pat . kind {
0 commit comments