Skip to content

Commit b5a3130

Browse files
authored
adding two more flags (#109)
1 parent 1e32009 commit b5a3130

File tree

2 files changed

+68
-19
lines changed

2 files changed

+68
-19
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,9 @@ pub(crate) fn run_pass_manager(
618618
let opt_level = config.opt_level.unwrap_or(config::OptLevel::No);
619619
// We will run this again with different values in the context of automatic differentiation.
620620
let first_run = true;
621-
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?;
621+
let noop = false;
622+
dbg!("running llvm pm opt pipeline");
623+
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run, noop)?;
622624
}
623625
debug!("lto done");
624626
Ok(())

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 65 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,11 @@ pub(crate) unsafe fn llvm_optimize(
550550
opt_level: config::OptLevel,
551551
opt_stage: llvm::OptStage,
552552
first_run: bool,
553+
noop: bool,
553554
) -> Result<(), FatalError> {
555+
if noop {
556+
return Ok(());
557+
}
554558
// Enzyme:
555559
// We want to simplify / optimize functions before AD.
556560
// However, benchmarks show that optimizations increasing the code size
@@ -724,6 +728,13 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
724728
let last_inst = LLVMRustGetLastInstruction(bb).unwrap();
725729
LLVMPositionBuilderAtEnd(builder, bb);
726730

731+
let safety_run_checks;
732+
if std::env::var("ENZYME_NO_SAFETY_CHECKS").is_ok() {
733+
safety_run_checks = false;
734+
} else {
735+
safety_run_checks = true;
736+
}
737+
727738
if inner_param_num == outer_param_num {
728739
call_args = outer_args;
729740
} else {
@@ -763,14 +774,18 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
763774
outer_pos += 3;
764775
inner_pos += 2;
765776

766-
// Now we assert if int1 <= int2
767-
let res = LLVMBuildICmp(
768-
builder,
769-
IntPredicate::IntULE as u32,
770-
outer_arg,
771-
next2_outer_arg,
772-
"safety_check".as_ptr() as *const c_char);
773-
safety_vals.push(res);
777+
778+
if safety_run_checks {
779+
780+
// Now we assert if int1 <= int2
781+
let res = LLVMBuildICmp(
782+
builder,
783+
IntPredicate::IntULE as u32,
784+
outer_arg,
785+
next2_outer_arg,
786+
"safety_check".as_ptr() as *const c_char);
787+
safety_vals.push(res);
788+
}
774789
}
775790
}
776791
}
@@ -782,17 +797,18 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
782797
// Now add the safety checks.
783798
if !safety_vals.is_empty() {
784799
dbg!("Adding safety checks");
800+
assert!(safety_run_checks);
785801
// first we create one bb per check and two more for the fail and success case.
786802
let fail_bb = LLVMAppendBasicBlockInContext(llcx, tgt, "ad_safety_fail".as_ptr() as *const c_char);
787803
let success_bb = LLVMAppendBasicBlockInContext(llcx, tgt, "ad_safety_success".as_ptr() as *const c_char);
788-
let mut err_bb = vec![];
789-
for i in 0..safety_vals.len() {
790-
let name: String = format!("ad_safety_err_{}", i);
791-
err_bb.push(LLVMAppendBasicBlockInContext(llcx, tgt, name.as_ptr() as *const c_char));
792-
}
793-
for (i, &val) in safety_vals.iter().enumerate() {
794-
LLVMBuildCondBr(builder, val, err_bb[i], fail_bb);
795-
LLVMPositionBuilderAtEnd(builder, err_bb[i]);
804+
for i in 1..safety_vals.len() {
805+
// 'or' all safety checks together
806+
// Doing some binary tree style or'ing here would be more efficient,
807+
// but I assume LLVM will opt it anyway
808+
let prev = safety_vals[i - 1];
809+
let curr = safety_vals[i];
810+
let res = llvm::LLVMBuildOr(builder, prev, curr, "safety_check".as_ptr() as *const c_char);
811+
safety_vals[i] = res;
796812
}
797813
LLVMBuildCondBr(builder, safety_vals.last().unwrap(), success_bb, fail_bb);
798814
LLVMPositionBuilderAtEnd(builder, fail_bb);
@@ -1194,7 +1210,31 @@ pub(crate) unsafe fn differentiate(
11941210
// disables vectorization and loop unrolling
11951211
first_run = true;
11961212
}
1197-
llvm_optimize(cgcx, &diag_handler, module, config, opt_level, opt_stage, first_run)?;
1213+
if std::env::var("ENZYME_ALT_PIPELINE").is_ok() {
1214+
dbg!("Running first postAD optimization");
1215+
first_run = true;
1216+
}
1217+
let noop = false;
1218+
llvm_optimize(cgcx, &diag_handler, module, config, opt_level, opt_stage, first_run, noop)?;
1219+
}
1220+
if std::env::var("ENZYME_ALT_PIPELINE").is_ok() {
1221+
dbg!("Running Second postAD optimization");
1222+
if let Some(opt_level) = config.opt_level {
1223+
let opt_stage = match cgcx.lto {
1224+
Lto::Fat => llvm::OptStage::PreLinkFatLTO,
1225+
Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
1226+
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
1227+
_ => llvm::OptStage::PreLinkNoLTO,
1228+
};
1229+
let mut first_run = false;
1230+
dbg!("Running Module Optimization after differentiation");
1231+
if std::env::var("ENZYME_NO_VEC_UNROLL").is_ok() {
1232+
// enables vectorization and loop unrolling
1233+
first_run = false;
1234+
}
1235+
let noop = false;
1236+
llvm_optimize(cgcx, &diag_handler, module, config, opt_level, opt_stage, first_run, noop)?;
1237+
}
11981238
}
11991239
}
12001240

@@ -1278,7 +1318,14 @@ pub(crate) unsafe fn optimize(
12781318
};
12791319
// Second run only relevant for AD
12801320
let first_run = true;
1281-
return llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run);
1321+
let noop;
1322+
if std::env::var("ENZYME_ALT_PIPELINE").is_ok() {
1323+
noop = true;
1324+
dbg!("Skipping PreAD optimization");
1325+
} else {
1326+
noop = false;
1327+
}
1328+
return llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run, noop);
12821329
}
12831330
Ok(())
12841331
}

0 commit comments

Comments
 (0)