@@ -550,7 +550,11 @@ pub(crate) unsafe fn llvm_optimize(
550
550
opt_level : config:: OptLevel ,
551
551
opt_stage : llvm:: OptStage ,
552
552
first_run : bool ,
553
+ noop : bool ,
553
554
) -> Result < ( ) , FatalError > {
555
+ if noop {
556
+ return Ok ( ( ) ) ;
557
+ }
554
558
// Enzyme:
555
559
// We want to simplify / optimize functions before AD.
556
560
// However, benchmarks show that optimizations increasing the code size
@@ -724,6 +728,13 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
724
728
let last_inst = LLVMRustGetLastInstruction ( bb) . unwrap ( ) ;
725
729
LLVMPositionBuilderAtEnd ( builder, bb) ;
726
730
731
+ let safety_run_checks;
732
+ if std:: env:: var ( "ENZYME_NO_SAFETY_CHECKS" ) . is_ok ( ) {
733
+ safety_run_checks = false ;
734
+ } else {
735
+ safety_run_checks = true ;
736
+ }
737
+
727
738
if inner_param_num == outer_param_num {
728
739
call_args = outer_args;
729
740
} else {
@@ -763,14 +774,18 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
763
774
outer_pos += 3 ;
764
775
inner_pos += 2 ;
765
776
766
- // Now we assert if int1 <= int2
767
- let res = LLVMBuildICmp (
768
- builder,
769
- IntPredicate :: IntULE as u32 ,
770
- outer_arg,
771
- next2_outer_arg,
772
- "safety_check" . as_ptr ( ) as * const c_char ) ;
773
- safety_vals. push ( res) ;
777
+
778
+ if safety_run_checks {
779
+
780
+ // Now we assert if int1 <= int2
781
+ let res = LLVMBuildICmp (
782
+ builder,
783
+ IntPredicate :: IntULE as u32 ,
784
+ outer_arg,
785
+ next2_outer_arg,
786
+ "safety_check" . as_ptr ( ) as * const c_char ) ;
787
+ safety_vals. push ( res) ;
788
+ }
774
789
}
775
790
}
776
791
}
@@ -782,17 +797,18 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
782
797
// Now add the safety checks.
783
798
if !safety_vals. is_empty ( ) {
784
799
dbg ! ( "Adding safety checks" ) ;
800
+ assert ! ( safety_run_checks) ;
785
801
// first we create one bb per check and two more for the fail and success case.
786
802
let fail_bb = LLVMAppendBasicBlockInContext ( llcx, tgt, "ad_safety_fail" . as_ptr ( ) as * const c_char ) ;
787
803
let success_bb = LLVMAppendBasicBlockInContext ( llcx, tgt, "ad_safety_success" . as_ptr ( ) as * const c_char ) ;
788
- let mut err_bb = vec ! [ ] ;
789
- for i in 0 ..safety_vals . len ( ) {
790
- let name : String = format ! ( "ad_safety_err_{}" , i ) ;
791
- err_bb . push ( LLVMAppendBasicBlockInContext ( llcx , tgt , name . as_ptr ( ) as * const c_char ) ) ;
792
- }
793
- for ( i , & val ) in safety_vals . iter ( ) . enumerate ( ) {
794
- LLVMBuildCondBr ( builder, val , err_bb [ i ] , fail_bb ) ;
795
- LLVMPositionBuilderAtEnd ( builder , err_bb [ i] ) ;
804
+ for i in 1 ..safety_vals . len ( ) {
805
+ // 'or' all safety checks together
806
+ // Doing some binary tree style or'ing here would be more efficient,
807
+ // but I assume LLVM will opt it anyway
808
+ let prev = safety_vals [ i - 1 ] ;
809
+ let curr = safety_vals [ i ] ;
810
+ let res = llvm :: LLVMBuildOr ( builder, prev , curr , "safety_check" . as_ptr ( ) as * const c_char ) ;
811
+ safety_vals [ i] = res ;
796
812
}
797
813
LLVMBuildCondBr ( builder, safety_vals. last ( ) . unwrap ( ) , success_bb, fail_bb) ;
798
814
LLVMPositionBuilderAtEnd ( builder, fail_bb) ;
@@ -1194,7 +1210,31 @@ pub(crate) unsafe fn differentiate(
1194
1210
// disables vectorization and loop unrolling
1195
1211
first_run = true ;
1196
1212
}
1197
- llvm_optimize ( cgcx, & diag_handler, module, config, opt_level, opt_stage, first_run) ?;
1213
+ if std:: env:: var ( "ENZYME_ALT_PIPELINE" ) . is_ok ( ) {
1214
+ dbg ! ( "Running first postAD optimization" ) ;
1215
+ first_run = true ;
1216
+ }
1217
+ let noop = false ;
1218
+ llvm_optimize ( cgcx, & diag_handler, module, config, opt_level, opt_stage, first_run, noop) ?;
1219
+ }
1220
+ if std:: env:: var ( "ENZYME_ALT_PIPELINE" ) . is_ok ( ) {
1221
+ dbg ! ( "Running Second postAD optimization" ) ;
1222
+ if let Some ( opt_level) = config. opt_level {
1223
+ let opt_stage = match cgcx. lto {
1224
+ Lto :: Fat => llvm:: OptStage :: PreLinkFatLTO ,
1225
+ Lto :: Thin | Lto :: ThinLocal => llvm:: OptStage :: PreLinkThinLTO ,
1226
+ _ if cgcx. opts . cg . linker_plugin_lto . enabled ( ) => llvm:: OptStage :: PreLinkThinLTO ,
1227
+ _ => llvm:: OptStage :: PreLinkNoLTO ,
1228
+ } ;
1229
+ let mut first_run = false ;
1230
+ dbg ! ( "Running Module Optimization after differentiation" ) ;
1231
+ if std:: env:: var ( "ENZYME_NO_VEC_UNROLL" ) . is_ok ( ) {
1232
+ // enables vectorization and loop unrolling
1233
+ first_run = false ;
1234
+ }
1235
+ let noop = false ;
1236
+ llvm_optimize ( cgcx, & diag_handler, module, config, opt_level, opt_stage, first_run, noop) ?;
1237
+ }
1198
1238
}
1199
1239
}
1200
1240
@@ -1278,7 +1318,14 @@ pub(crate) unsafe fn optimize(
1278
1318
} ;
1279
1319
// Second run only relevant for AD
1280
1320
let first_run = true ;
1281
- return llvm_optimize ( cgcx, dcx, module, config, opt_level, opt_stage, first_run) ;
1321
+ let noop;
1322
+ if std:: env:: var ( "ENZYME_ALT_PIPELINE" ) . is_ok ( ) {
1323
+ noop = true ;
1324
+ dbg ! ( "Skipping PreAD optimization" ) ;
1325
+ } else {
1326
+ noop = false ;
1327
+ }
1328
+ return llvm_optimize ( cgcx, dcx, module, config, opt_level, opt_stage, first_run, noop) ;
1282
1329
}
1283
1330
Ok ( ( ) )
1284
1331
}
0 commit comments