Skip to content

Commit dee82b3

Browse files
committed
wire up propper Rust error handler
1 parent d1c94af commit dee82b3

File tree

3 files changed

+36
-4
lines changed

3 files changed

+36
-4
lines changed

compiler/rustc_codegen_llvm/messages.ftl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ codegen_llvm_prepare_thin_lto_module_with_llvm_err = failed to prepare thin LTO
6060
codegen_llvm_run_passes = failed to run LLVM passes
6161
codegen_llvm_run_passes_with_llvm_err = failed to run LLVM passes: {$llvm_err}
6262
63+
codegen_llvm_prepare_autodiff = failed to prepare AutoDiff: src: {$src}, target: {$target}, {$error}
64+
codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare AutoDiff: {$llvm_err}, src: {$src}, target: {$target}, {$error}
65+
6366
codegen_llvm_sanitizer_memtag_requires_mte =
6467
`-Zsanitizer=memtag` requires `-Ctarget-feature=+mte`
6568

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,7 @@ pub(crate) unsafe fn extract_return_type<'a>(
698698
pub(crate) unsafe fn enzyme_ad(
699699
llmod: &llvm::Module,
700700
llcx: &llvm::Context,
701+
diag_handler: &rustc_errors::Handler,
701702
item: AutoDiffItem,
702703
) -> Result<(), FatalError> {
703704
let autodiff_mode = item.attrs.mode;
@@ -710,8 +711,28 @@ pub(crate) unsafe fn enzyme_ad(
710711
// get target and source function
711712
let name = CString::new(rust_name.to_owned()).unwrap();
712713
let name2 = CString::new(rust_name2.clone()).unwrap();
713-
let src_fnc = llvm::LLVMGetNamedFunction(llmod, name.as_c_str().as_ptr()).unwrap();
714-
let target_fnc = llvm::LLVMGetNamedFunction(llmod, name2.as_ptr()).unwrap();
714+
let src_fnc_opt = llvm::LLVMGetNamedFunction(llmod, name.as_c_str().as_ptr());
715+
let src_fnc = match src_fnc_opt {
716+
Some(x) => x,
717+
None => {
718+
return Err(llvm_err(diag_handler, LlvmError::PrepareAutoDiff{
719+
src: rust_name.to_owned(),
720+
target: rust_name2.to_owned(),
721+
error: "could not find src function".to_owned(),
722+
}));
723+
}
724+
};
725+
let target_fnc_opt = llvm::LLVMGetNamedFunction(llmod, name2.as_ptr());
726+
let target_fnc = match target_fnc_opt {
727+
Some(x) => x,
728+
None => {
729+
return Err(llvm_err(diag_handler, LlvmError::PrepareAutoDiff{
730+
src: rust_name.to_owned(),
731+
target: rust_name2.to_owned(),
732+
error: "could not find target function".to_owned(),
733+
}));
734+
}
735+
};
715736
let src_num_args = llvm::LLVMCountParams(src_fnc);
716737
let target_num_args = llvm::LLVMCountParams(target_fnc);
717738
assert!(src_num_args <= target_num_args);
@@ -791,13 +812,14 @@ pub(crate) unsafe fn enzyme_ad(
791812

792813
pub(crate) unsafe fn differentiate(
793814
module: &ModuleCodegen<ModuleLlvm>,
794-
_cgcx: &CodegenContext<LlvmCodegenBackend>,
815+
cgcx: &CodegenContext<LlvmCodegenBackend>,
795816
diff_items: Vec<AutoDiffItem>,
796817
_typetrees: FxHashMap<String, DiffTypeTree>,
797818
_config: &ModuleConfig,
798819
) -> Result<(), FatalError> {
799820
let llmod = module.module_llvm.llmod();
800821
let llcx = &module.module_llvm.llcx;
822+
let diag_handler = cgcx.create_diag_handler();
801823

802824
llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymeStrictAliasing), 0);
803825

@@ -818,7 +840,7 @@ pub(crate) unsafe fn differentiate(
818840
}
819841

820842
for item in diff_items {
821-
let res = enzyme_ad(llmod, llcx, item);
843+
let res = enzyme_ad(llmod, llcx, &diag_handler, item);
822844
assert!(res.is_ok());
823845
}
824846

compiler/rustc_codegen_llvm/src/errors.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,12 @@ pub enum LlvmError<'a> {
172172
PrepareThinLtoModule,
173173
#[diag(codegen_llvm_parse_bitcode)]
174174
ParseBitcode,
175+
#[diag(codegen_llvm_prepare_autodiff)]
176+
PrepareAutoDiff {
177+
src: String,
178+
target: String,
179+
error: String,
180+
}
175181
}
176182

177183
pub(crate) struct WithLlvmError<'a>(pub LlvmError<'a>, pub String);
@@ -193,6 +199,7 @@ impl<EM: EmissionGuarantee> IntoDiagnostic<'_, EM> for WithLlvmError<'_> {
193199
}
194200
PrepareThinLtoModule => fluent::codegen_llvm_prepare_thin_lto_module_with_llvm_err,
195201
ParseBitcode => fluent::codegen_llvm_parse_bitcode_with_llvm_err,
202+
PrepareAutoDiff { .. } => fluent::codegen_llvm_prepare_autodiff_with_llvm_err,
196203
};
197204
let mut diag = self.0.into_diagnostic(sess);
198205
diag.set_primary_message(msg_with_llvm_err);

0 commit comments

Comments
 (0)