@@ -12,7 +12,7 @@ use tracing::debug;
1212use crate :: builder:: { Builder , PlaceRef , UNNAMED } ;
1313use crate :: context:: SimpleCx ;
1414use crate :: declare:: declare_simple_fn;
15- use crate :: llvm:: { self , Metadata , TRUE , Type , Value } ;
15+ use crate :: llvm:: { self , TRUE , Type , Value } ;
1616
1717pub ( crate ) fn adjust_activity_to_abi < ' tcx > (
1818 tcx : TyCtxt < ' tcx > ,
@@ -143,9 +143,9 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
143143 cx : & SimpleCx < ' ll > ,
144144 builder : & mut Builder < ' _ , ' ll , ' tcx > ,
145145 width : u32 ,
146- args : & mut Vec < & ' ll llvm :: Value > ,
146+ args : & mut Vec < & ' ll Value > ,
147147 inputs : & [ DiffActivity ] ,
148- outer_args : & [ & ' ll llvm :: Value ] ,
148+ outer_args : & [ & ' ll Value ] ,
149149) {
150150 debug ! ( "matching autodiff arguments" ) ;
151151 // We now handle the issue that Rust level arguments not always match the llvm-ir level
@@ -157,32 +157,36 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
157157 let mut outer_pos: usize = 0 ;
158158 let mut activity_pos = 0 ;
159159
160- let enzyme_const = cx. create_metadata ( b"enzyme_const" ) ;
161- let enzyme_out = cx. create_metadata ( b"enzyme_out" ) ;
162- let enzyme_dup = cx. create_metadata ( b"enzyme_dup" ) ;
163- let enzyme_dupv = cx. create_metadata ( b"enzyme_dupv" ) ;
164- let enzyme_dupnoneed = cx. create_metadata ( b"enzyme_dupnoneed" ) ;
165- let enzyme_dupnoneedv = cx. create_metadata ( b"enzyme_dupnoneedv" ) ;
160+ // We used to use llvm's metadata to instruct enzyme how to differentiate a function.
161+ // In debug mode we would use incremental compilation which caused the metadata to be
162+ // dropped. This is prevented by now using named globals, which are also understood
163+ // by Enzyme.
164+ let global_const = cx. declare_global ( "enzyme_const" , cx. type_ptr ( ) ) ;
165+ let global_out = cx. declare_global ( "enzyme_out" , cx. type_ptr ( ) ) ;
166+ let global_dup = cx. declare_global ( "enzyme_dup" , cx. type_ptr ( ) ) ;
167+ let global_dupv = cx. declare_global ( "enzyme_dupv" , cx. type_ptr ( ) ) ;
168+ let global_dupnoneed = cx. declare_global ( "enzyme_dupnoneed" , cx. type_ptr ( ) ) ;
169+ let global_dupnoneedv = cx. declare_global ( "enzyme_dupnoneedv" , cx. type_ptr ( ) ) ;
166170
167171 while activity_pos < inputs. len ( ) {
168172 let diff_activity = inputs[ activity_pos as usize ] ;
169173 // Duplicated arguments received a shadow argument, into which enzyme will write the
170174 // gradient.
171- let ( activity, duplicated) : ( & Metadata , bool ) = match diff_activity {
175+ let ( activity, duplicated) : ( & Value , bool ) = match diff_activity {
172176 DiffActivity :: None => panic ! ( "not a valid input activity" ) ,
173- DiffActivity :: Const => ( enzyme_const , false ) ,
174- DiffActivity :: Active => ( enzyme_out , false ) ,
175- DiffActivity :: ActiveOnly => ( enzyme_out , false ) ,
176- DiffActivity :: Dual => ( enzyme_dup , true ) ,
177- DiffActivity :: Dualv => ( enzyme_dupv , true ) ,
178- DiffActivity :: DualOnly => ( enzyme_dupnoneed , true ) ,
179- DiffActivity :: DualvOnly => ( enzyme_dupnoneedv , true ) ,
180- DiffActivity :: Duplicated => ( enzyme_dup , true ) ,
181- DiffActivity :: DuplicatedOnly => ( enzyme_dupnoneed , true ) ,
182- DiffActivity :: FakeActivitySize ( _) => ( enzyme_const , false ) ,
177+ DiffActivity :: Const => ( global_const , false ) ,
178+ DiffActivity :: Active => ( global_out , false ) ,
179+ DiffActivity :: ActiveOnly => ( global_out , false ) ,
180+ DiffActivity :: Dual => ( global_dup , true ) ,
181+ DiffActivity :: Dualv => ( global_dupv , true ) ,
182+ DiffActivity :: DualOnly => ( global_dupnoneed , true ) ,
183+ DiffActivity :: DualvOnly => ( global_dupnoneedv , true ) ,
184+ DiffActivity :: Duplicated => ( global_dup , true ) ,
185+ DiffActivity :: DuplicatedOnly => ( global_dupnoneed , true ) ,
186+ DiffActivity :: FakeActivitySize ( _) => ( global_const , false ) ,
183187 } ;
184188 let outer_arg = outer_args[ outer_pos] ;
185- args. push ( cx . get_metadata_value ( activity) ) ;
189+ args. push ( activity) ;
186190 if matches ! ( diff_activity, DiffActivity :: Dualv ) {
187191 let next_outer_arg = outer_args[ outer_pos + 1 ] ;
188192 let elem_bytes_size: u64 = match inputs[ activity_pos + 1 ] {
@@ -242,7 +246,7 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
242246 assert_eq ! ( cx. type_kind( next_outer_ty3) , TypeKind :: Integer ) ;
243247 args. push ( next_outer_arg2) ;
244248 }
245- args. push ( cx . get_metadata_value ( enzyme_const ) ) ;
249+ args. push ( global_const ) ;
246250 args. push ( next_outer_arg) ;
247251 outer_pos += 2 + 2 * iterations;
248252 activity_pos += 2 ;
@@ -351,13 +355,13 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
351355 let mut args = Vec :: with_capacity ( num_args as usize + 1 ) ;
352356 args. push ( fn_to_diff) ;
353357
354- let enzyme_primal_ret = cx. create_metadata ( b "enzyme_primal_return") ;
358+ let global_primal_ret = cx. declare_global ( "enzyme_primal_return" , cx . type_ptr ( ) ) ;
355359 if matches ! ( attrs. ret_activity, DiffActivity :: Dual | DiffActivity :: Active ) {
356- args. push ( cx . get_metadata_value ( enzyme_primal_ret ) ) ;
360+ args. push ( global_primal_ret ) ;
357361 }
358362 if attrs. width > 1 {
359- let enzyme_width = cx. create_metadata ( b "enzyme_width") ;
360- args. push ( cx . get_metadata_value ( enzyme_width ) ) ;
363+ let global_width = cx. declare_global ( "enzyme_width" , cx . type_ptr ( ) ) ;
364+ args. push ( global_width ) ;
361365 args. push ( cx. get_const_int ( cx. type_i64 ( ) , attrs. width as u64 ) ) ;
362366 }
363367
0 commit comments