diff --git a/Cargo.lock b/Cargo.lock index 61de611a..d9bef42c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -119,6 +119,7 @@ dependencies = [ "log", "regex", "rustc-build-sysroot", + "tempfile", "thiserror 2.0.17", "tracing", "tracing-appender", diff --git a/Cargo.toml b/Cargo.toml index 1a506ce9..3dea6514 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ tracing = "0.1" compiletest_rs = { version = "0.11.0" } regex = { version = "1.11.1", default-features = false } rustc-build-sysroot = { workspace = true } +tempfile = { version = "3.13" } which = { version = "8.0.0", default-features = false, features = ["real-sys", "regex"] } [lints] diff --git a/src/linker.rs b/src/linker.rs index faea046b..7dc42336 100644 --- a/src/linker.rs +++ b/src/linker.rs @@ -206,6 +206,8 @@ enum InputType { MachO, /// Archive file. (.a) Archive, + /// IR file (.ll) + Ir, } impl std::fmt::Display for InputType { @@ -218,6 +220,7 @@ impl std::fmt::Display for InputType { Self::Elf => "elf", Self::MachO => "Mach-O", Self::Archive => "archive", + Self::Ir => "ir", } ) } @@ -507,24 +510,18 @@ where .create_module(c"linked_module") .ok_or(LinkerError::CreateModuleError)?; - // buffer used to perform file type detection - let mut buf = [0u8; 8]; for mut input in inputs { - let path = match input { - InputReader::File { path, .. } => path.into(), + let path = match &input { + InputReader::File { path, .. } => (*path).into(), InputReader::Buffer { name, .. } => PathBuf::from(format!("in_memory::{}", name)), }; - // determine whether the input is bitcode, ELF with embedded bitcode, an archive file - // or an invalid file - input - .read_exact(&mut buf) - .map_err(|e| LinkerError::IoError(path.clone(), e))?; + let in_type = detect_input_type(&mut input) + .ok_or_else(|| LinkerError::InvalidInputType(path.clone()))?; + input .rewind() .map_err(|e| LinkerError::IoError(path.clone(), e))?; - let in_type = - detect_input_type(&buf).ok_or_else(|| LinkerError::InvalidInputType(path.clone()))?; match in_type { InputType::Archive => { @@ -584,16 +581,33 @@ fn link_reader<'ctx>( .map_err(|e| LinkerError::IoError(path.to_owned(), e))?; // in_type is unknown when we're linking an item from an archive file let in_type = in_type - .or_else(|| detect_input_type(&data)) + .or_else(|| detect_input_type(reader.by_ref())) .ok_or_else(|| LinkerError::InvalidInputType(path.to_owned()))?; - let bitcode = match in_type { - InputType::Bitcode => data, - InputType::Elf => match llvm::find_embedded_bitcode(context, &data) { - Ok(Some(bitcode)) => bitcode, - Ok(None) => return Err(LinkerError::MissingBitcodeSection(path.to_owned())), - Err(e) => return Err(LinkerError::EmbeddedBitcodeError(e)), - }, + match in_type { + InputType::Bitcode => { + if !llvm::link_bitcode_buffer(context, module, &data) { + return Err(LinkerError::LinkModuleError(path.to_owned())); + } + } + InputType::Ir => { + let data = CString::new(data).unwrap(); + if !llvm::link_ir_buffer(context, module, &data) + .map_err(|_| LinkerError::LinkModuleError(path.to_owned()))? + { + return Err(LinkerError::LinkModuleError(path.to_owned())); + } + } + InputType::Elf => { + let bitcode = match llvm::find_embedded_bitcode(context, &data) { + Ok(Some(bitcode)) => bitcode, + Ok(None) => return Err(LinkerError::MissingBitcodeSection(path.to_owned())), + Err(e) => return Err(LinkerError::EmbeddedBitcodeError(e)), + }; + if !llvm::link_bitcode_buffer(context, module, &bitcode) { + return Err(LinkerError::LinkModuleError(path.to_owned())); + } + } // we need to handle this here since archive files could contain // mach-o files, eg somecrate.rlib containing lib.rmeta which is // mach-o on macos @@ -602,10 +616,6 @@ fn link_reader<'ctx>( InputType::Archive => panic!("nested archives not supported duh"), }; - if !llvm::link_bitcode_buffer(context, module, &bitcode) { - return Err(LinkerError::LinkModuleError(path.to_owned())); - } - Ok(()) } @@ -870,18 +880,23 @@ impl llvm::LLVMDiagnosticHandler for DiagnosticHandler { } } -fn detect_input_type(data: &[u8]) -> Option { - if data.len() < 8 { +fn detect_input_type(reader: &mut impl Read) -> Option { + let mut header = [0u8; 16]; + let bytes_read = reader.read(&mut header).ok()?; + + if bytes_read < 4 { return None; } - match &data[..4] { + match &header[..4] { b"\x42\x43\xC0\xDE" | b"\xDE\xC0\x17\x0b" => Some(InputType::Bitcode), b"\x7FELF" => Some(InputType::Elf), b"\xcf\xfa\xed\xfe" => Some(InputType::MachO), _ => { - if &data[..8] == b"!\x0A" { + if bytes_read >= 8 && &header[..8] == b"!\x0A" { Some(InputType::Archive) + } else if is_llvm_ir(&header[..bytes_read]) { + Some(InputType::Ir) } else { None } @@ -889,6 +904,22 @@ fn detect_input_type(data: &[u8]) -> Option { } } +fn is_llvm_ir(data: &[u8]) -> bool { + let trimmed = data.trim_ascii_start(); + + let prefixes: &[&[u8]] = &[ + b"; ModuleID", + b"target triple", + b"target datalayout", + b"source_filename", + b"target ", + b"define", + b"!llvm", + ]; + + prefixes.iter().any(|prefix| trimmed.starts_with(prefix)) +} + pub struct LinkerOutput { inner: MemoryBuffer, } diff --git a/src/llvm/mod.rs b/src/llvm/mod.rs index 5e644084..b84886c4 100644 --- a/src/llvm/mod.rs +++ b/src/llvm/mod.rs @@ -24,6 +24,7 @@ use llvm_sys::{ error::{ LLVMDisposeErrorMessage, LLVMGetErrorMessage, LLVMGetErrorTypeId, LLVMGetStringErrorTypeId, }, + ir_reader::LLVMParseIRInContext, linker::LLVMLinkModules2, object::{ LLVMCreateBinary, LLVMDisposeBinary, LLVMDisposeSectionIterator, LLVMGetSectionContents, @@ -141,6 +142,41 @@ pub(crate) fn link_bitcode_buffer<'ctx>( linked } +pub(crate) fn link_ir_buffer<'ctx>( + context: &'ctx LLVMContext, + module: &mut LLVMModule<'ctx>, + buffer: &CStr, +) -> Result { + let buffer_name = c"ir_buffer"; + let buffer = buffer.to_bytes(); + let mem_buffer = unsafe { + LLVMCreateMemoryBufferWithMemoryRange( + buffer.as_ptr().cast(), + buffer.len(), + buffer_name.as_ptr(), + 1, // LLVM internally sets RequiresTerminator=true + ) + }; + + let mut temp_module = ptr::null_mut(); + let (ret, message) = Message::with(|error_msg| unsafe { + LLVMParseIRInContext( + context.as_mut_ptr(), + mem_buffer, + &mut temp_module, + error_msg, + ) + }); + + if ret == 0 { + let linked = unsafe { LLVMLinkModules2(module.as_mut_ptr(), temp_module) } == 0; + Ok(linked) + } else { + unsafe { LLVMDisposeMemoryBuffer(mem_buffer) }; + Err(message.as_string_lossy().to_string()) + } +} + pub(crate) fn target_from_triple(triple: &CStr) -> Result { let mut target = ptr::null_mut(); let (ret, message) = Message::with(|message| unsafe { diff --git a/tests/ir_file_test.rs b/tests/ir_file_test.rs new file mode 100644 index 00000000..4de122a4 --- /dev/null +++ b/tests/ir_file_test.rs @@ -0,0 +1,100 @@ +#![expect(unused_crate_dependencies, reason = "used in lib/bin")] + +use std::{ + env, fs, + path::{Path, PathBuf}, + process::Command, +}; + +fn linker_path() -> PathBuf { + PathBuf::from(env!("CARGO_BIN_EXE_bpf-linker")) +} + +fn create_test_ir_file(dir: &Path, name: &str) -> PathBuf { + let ir_path = dir.join(format!("{}.ll", name)); + let ir_content = format!( + r#"; ModuleID = '{name}' +source_filename = "{name}" +target datalayout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128" +target triple = "bpf" + +define i32 @test_{name}(i32 %x) #0 {{ +entry: + %result = add i32 %x, 1 + ret i32 %result +}} + +attributes #0 = {{ noinline nounwind optnone }} + +!llvm.module.flags = !{{!0}} +!0 = !{{i32 1, !"wchar_size", i32 4}} +"# + ); + fs::write(&ir_path, ir_content).expect("Failed to write test IR file"); + ir_path +} + +#[test] +fn test_link_ir_file() { + let temp_dir = tempfile::tempdir().expect("Failed to create temp dir"); + let ir_file = create_test_ir_file(temp_dir.path(), "alessandro"); + let output_file = temp_dir.path().join("output.o"); + + let output = Command::new(linker_path()) + .arg("--export") + .arg(format!("test_{}", "alessandro")) + .arg(&ir_file) + .arg("-o") + .arg(&output_file) + .output() + .expect("Failed to execute bpf-linker"); + + if !output.status.success() { + eprintln!("stdout: {}", String::from_utf8_lossy(&output.stdout)); + eprintln!("stderr: {}", String::from_utf8_lossy(&output.stderr)); + panic!("bpf-linker failed with status: {}", output.status); + } + + assert!( + output_file.exists(), + "Output file should exist: {:?}", + output_file + ); + assert!( + output_file.metadata().unwrap().len() > 0, + "Output file should not be empty" + ); +} + +#[test] +fn test_invalid_ir_file() { + let temp_dir = tempfile::tempdir().expect("Failed to create temp dir"); + + let valid_ir_file = create_test_ir_file(temp_dir.path(), "alessandro"); + + let valid_content = fs::read_to_string(valid_ir_file).expect("Failed to read valid IR file"); + + // Corrupting IR content + let invalid_content = + valid_content.replace("; ModuleID = 'alessandro'", ": ModuleXX = 'corrupted'"); + + let invalid_ir_file = temp_dir.path().join("corrupted.ll"); + + fs::write(&invalid_ir_file, invalid_content).expect("Failed to write invalid IR file"); + + let output_file = temp_dir.path().join("output.o"); + + let output = Command::new(linker_path()) + .arg(&invalid_ir_file) + .arg("-o") + .arg(&output_file) + .output() + .expect("Failed to execute bpf-linker"); + + // Should fail with corrupted IR + assert!( + !output.status.success(), + "bpf-linker should fail with corrupted IR. stderr: {}", + String::from_utf8_lossy(&output.stderr) + ); +}