diff --git a/src/llvm/di.rs b/src/llvm/di.rs
index dbaa0f29..18abdd11 100644
--- a/src/llvm/di.rs
+++ b/src/llvm/di.rs
@@ -1,6 +1,7 @@
 use std::{
     borrow::Cow,
     collections::{HashMap, HashSet, hash_map::DefaultHasher},
+    ffi::CStr,
     hash::Hasher as _,
     io::Write as _,
     marker::PhantomData,
@@ -8,7 +9,7 @@ use std::{
 };
 
 use gimli::{DW_TAG_pointer_type, DW_TAG_structure_type, DW_TAG_variant_part};
-use llvm_sys::{core::*, debuginfo::*, prelude::*};
+use llvm_sys::{LLVMTypeKind, core::*, debuginfo::*, prelude::*};
 use tracing::{Level, span, trace, warn};
 
 use super::types::{
@@ -268,6 +269,9 @@ impl<'ctx> DISanitizer<'ctx> {
     pub(crate) fn run(mut self, exported_symbols: &HashSet<Cow<'_, [u8]>>) {
         let module = self.module;
 
+        // Create debug info for extern functions first
+        self.create_extern_debug_info();
+
         self.replace_operands = self.fix_subprogram_linkage(exported_symbols);
 
         for value in module.globals_iter() {
@@ -319,6 +323,11 @@ impl<'ctx> DISanitizer<'ctx> {
                 continue;
             }
 
+            // Skip declarations: they have no basic blocks.
+            if unsafe { LLVMCountBasicBlocks(function.value_ref) } == 0 {
+                continue;
+            }
+
             // Skip functions that don't have subprograms.
             let Some(mut subprogram) = function.subprogram(self.context) else {
                 continue;
@@ -389,6 +398,247 @@ impl<'ctx> DISanitizer<'ctx> {
 
         replace
     }
+
+    /// Attaches synthesized debug info to function declarations that have none.
+    ///
+    /// Every function without basic blocks and without a `DISubprogram` gets a
+    /// `DISubroutineType` derived from its LLVM signature, a declaration-only
+    /// `DISubprogram` named after the symbol, and one `argN` parameter variable
+    /// per argument stored as the subprogram's retained nodes. The `DIFile` is
+    /// taken from the first function that already has a subprogram with a
+    /// compile unit; when there is none, a warning is logged and nothing is
+    /// created.
+    fn create_extern_debug_info(&mut self) {
+        let Some((_, di_file)) = self.get_compile_unit_and_file() else {
+            warn!("No compile unit found, skipping extern debug info creation");
+            return;
+        };
+
+        // Collected up front; presumably to avoid iterating the module while modifying it.
+        let functions: Vec<_> = self.module.functions_iter().collect();
+
+        for function in functions {
+            let mut func = unsafe { Function::from_value_ref(function) };
+
+            if func.subprogram(self.context).is_some() {
+                continue;
+            }
+
+            // Only declarations (functions without basic blocks) are handled here.
+            if unsafe { LLVMCountBasicBlocks(function) } > 0 {
+                continue;
+            }
+
+            let name = func.name();
+            let func_type = unsafe { LLVMGlobalGetValueType(function) };
+            let return_type = unsafe { LLVMGetReturnType(func_type) };
+            let param_count = unsafe { LLVMCountParamTypes(func_type) } as usize;
+
+            let mut param_types = vec![ptr::null_mut(); param_count];
+            if param_count > 0 {
+                unsafe { LLVMGetParamTypes(func_type, param_types.as_mut_ptr()) };
+            }
+
+            // Slot 0 of a subroutine type holds the return type; the parameter
+            // types follow in order.
+            let mut di_types = Vec::with_capacity(param_count + 1);
+            di_types.push(self.create_di_type_from_llvm_type(return_type, di_file));
+            for param_type in param_types {
+                di_types.push(self.create_di_type_from_llvm_type(param_type, di_file));
+            }
+
+            let di_subroutine_type = unsafe {
+                LLVMDIBuilderCreateSubroutineType(
+                    self.builder,
+                    di_file,
+                    di_types.as_mut_ptr(),
+                    di_types.len() as u32,
+                    0,
+                )
+            };
+
+            // Declaration-only subprogram: no line information, not local to
+            // the unit, not a definition.
+            let subprogram = unsafe {
+                LLVMDIBuilderCreateFunction(
+                    self.builder,
+                    di_file, // scope
+                    name.as_ptr().cast(),
+                    name.len(),
+                    name.as_ptr().cast(), // linkage name
+                    name.len(),
+                    di_file,
+                    0, // line
+                    di_subroutine_type,
+                    0, // local to unit
+                    0, // definition
+                    0, // scope line
+                    LLVMDIFlagPrototyped,
+                    1, // optimized
+                )
+            };
+
+            let mut di_subprogram = unsafe {
+                DISubprogram::from_value_ref(LLVMMetadataAsValue(self.context, subprogram))
+            };
+
+            // One `argN` variable per parameter, kept as the subprogram's
+            // retained nodes.
+            if param_count > 0 {
+                let mut param_vars = Vec::with_capacity(param_count);
+                for (idx, &di_param_type) in di_types[1..].iter().enumerate() {
+                    let param_name = format!("arg{}", idx);
+                    let di_param_var = unsafe {
+                        LLVMDIBuilderCreateParameterVariable(
+                            self.builder,
+                            subprogram, // scope
+                            param_name.as_ptr().cast(),
+                            param_name.len(),
+                            (idx + 1) as u32, // argument number, starting at 1
+                            di_file,
+                            0, // line
+                            di_param_type,
+                            1, // always preserve
+                            0, // flags
+                        )
+                    };
+                    param_vars.push(di_param_var);
+                }
+
+                let retained_nodes = unsafe {
+                    LLVMMDNodeInContext2(self.context, param_vars.as_mut_ptr(), param_vars.len())
+                };
+                di_subprogram.set_retained_nodes(retained_nodes);
+            }
+
+            unsafe { LLVMDIBuilderFinalizeSubprogram(self.builder, subprogram) };
+
+            func.set_subprogram(&di_subprogram);
+        }
+    }
+
+    /// Builds a coarse `DIType` for `llvm_type`.
+    ///
+    /// Integers become a basic type chosen by bit width, pointers become a
+    /// 64-bit pointer to the `void` placeholder, and structs become a forward
+    /// declaration named after the LLVM struct (`struct` when it has no name).
+    /// `void` and every other type kind map to the `void` placeholder itself.
+    fn create_di_type_from_llvm_type(
+        &mut self,
+        llvm_type: LLVMTypeRef,
+        di_file: LLVMMetadataRef,
+    ) -> LLVMMetadataRef {
+        match unsafe { LLVMGetTypeKind(llvm_type) } {
+            LLVMTypeKind::LLVMIntegerTypeKind => self.create_di_basic_int(llvm_type, di_file),
+            LLVMTypeKind::LLVMPointerTypeKind => {
+                // Pointee types are not recovered: every pointer is a `void *`.
+                let pointee = self.create_di_void_type();
+                unsafe {
+                    LLVMDIBuilderCreatePointerType(
+                        self.builder,
+                        pointee,
+                        64, // size in bits; BPF is 64-bit
+                        0,  // align
+                        0,  // address space
+                        c"".as_ptr(),
+                        0,
+                    )
+                }
+            }
+            LLVMTypeKind::LLVMStructTypeKind => {
+                // A forward declaration is enough for an extern signature; the
+                // struct layout is not described.
+                let name_ptr = unsafe { LLVMGetStructName(llvm_type) };
+                let struct_name = if name_ptr.is_null() {
+                    c"struct"
+                } else {
+                    // SAFETY: `name_ptr` is non-null and LLVM returns a
+                    // NUL-terminated string.
+                    unsafe { CStr::from_ptr(name_ptr) }
+                };
+
+                unsafe {
+                    LLVMDIBuilderCreateStructType(
+                        self.builder,
+                        ptr::null_mut(), // scope
+                        struct_name.as_ptr(),
+                        struct_name.to_bytes().len(),
+                        di_file,
+                        0,                 // line
+                        0,                 // size (opaque)
+                        0,                 // align
+                        LLVMDIFlagFwdDecl, // forward decl
+                        ptr::null_mut(),   // derived from
+                        ptr::null_mut(),   // elements
+                        0,                 // element count
+                        0,                 // runtime lang
+                        ptr::null_mut(),   // vtable
+                        c"".as_ptr(),
+                        0,
+                    )
+                }
+            }
+            // `void` and every kind not handled above.
+            _ => self.create_di_void_type(),
+        }
+    }
+
+    /// Creates the zero-sized basic type named `void` used as a placeholder.
+    ///
+    /// NOTE(review): LLVM front ends commonly describe `void` with a null type
+    /// reference instead of a basic type; confirm consumers accept this form.
+    fn create_di_void_type(&self) -> LLVMMetadataRef {
+        unsafe { LLVMDIBuilderCreateBasicType(self.builder, c"void".as_ptr(), 4, 0, 0, 0) }
+    }
+
+    /// Creates a basic integer `DIType` whose name and DWARF encoding are
+    /// picked from the bit width of `llvm_type`.
+    ///
+    /// LLVM integer types carry no signedness, so the mapping is fixed: 1 is
+    /// `bool`, 8/16/64 are unsigned (`u8`, `u16`, `u64`), 32 is signed (`i32`)
+    /// and any other width is a signed `int`.
+    fn create_di_basic_int(
+        &mut self,
+        llvm_type: LLVMTypeRef,
+        _di_file: LLVMMetadataRef,
+    ) -> LLVMMetadataRef {
+        // DWARF base type encodings (DW_ATE_*).
+        const DW_ATE_BOOLEAN: u32 = 0x02;
+        const DW_ATE_SIGNED: u32 = 0x05;
+        const DW_ATE_UNSIGNED: u32 = 0x07;
+
+        let width = unsafe { LLVMGetIntTypeWidth(llvm_type) };
+        let (name, encoding) = match width {
+            1 => (c"bool", DW_ATE_BOOLEAN),
+            8 => (c"u8", DW_ATE_UNSIGNED),
+            16 => (c"u16", DW_ATE_UNSIGNED),
+            32 => (c"i32", DW_ATE_SIGNED),
+            64 => (c"u64", DW_ATE_UNSIGNED),
+            _ => (c"int", DW_ATE_SIGNED),
+        };
+
+        unsafe {
+            LLVMDIBuilderCreateBasicType(
+                self.builder,
+                name.as_ptr(),
+                name.to_bytes().len(),
+                width as u64,
+                encoding,
+                0, // flags
+            )
+        }
+    }
+
+    /// Returns the compile unit and file of the first function whose
+    /// subprogram has a compile unit, or `None` when no function has one.
+    fn get_compile_unit_and_file(&self) -> Option<(LLVMMetadataRef, LLVMMetadataRef)> {
+        self.module.functions_iter().find_map(|function| {
+            let func = unsafe { Function::from_value_ref(function) };
+            let subprogram = func.subprogram(self.context)?;
+            let unit = subprogram.unit()?;
+            Some((unit, subprogram.file()))
+        })
+    }
 }
 
 #[derive(Clone, Debug, Eq, PartialEq)]
diff --git a/src/llvm/mod.rs b/src/llvm/mod.rs
index 5e644084..81c9e220 100644
--- a/src/llvm/mod.rs
+++ b/src/llvm/mod.rs
@@ -16,10 +16,11 @@ use llvm_sys::{
     LLVMAttributeFunctionIndex, LLVMLinkage, LLVMVisibility,
     bit_reader::LLVMParseBitcodeInContext2,
     core::{
-        LLVMCreateMemoryBufferWithMemoryRange, LLVMDisposeMemoryBuffer, LLVMDisposeMessage,
-        LLVMGetEnumAttributeKindForName, LLVMGetMDString, LLVMGetModuleInlineAsm, LLVMGetTarget,
-        LLVMGetValueName2, LLVMRemoveEnumAttributeAtIndex, LLVMSetLinkage, LLVMSetModuleInlineAsm2,
-        LLVMSetVisibility,
+        LLVMCountBasicBlocks, LLVMCreateMemoryBufferWithMemoryRange, LLVMDisposeMemoryBuffer,
+        LLVMDisposeMessage, LLVMGetEnumAttributeKindForName, LLVMGetMDString,
+        LLVMGetModuleInlineAsm, LLVMGetTarget, LLVMGetValueName2, LLVMIsAFunction,
+        LLVMIsAGlobalVariable, LLVMIsDeclaration, LLVMRemoveEnumAttributeAtIndex, LLVMSetLinkage,
+        LLVMSetModuleInlineAsm2, LLVMSetSection, LLVMSetVisibility,
     },
     error::{
         LLVMDisposeErrorMessage, LLVMGetErrorMessage, LLVMGetErrorTypeId, LLVMGetStringErrorTypeId,
@@ -41,6 +42,6 @@ use llvm_sys::{
         LLVMCreatePassBuilderOptions, LLVMDisposePassBuilderOptions, LLVMRunPasses,
     },
 };
-use tracing::{debug, error};
+use tracing::{debug, error, info};
 pub(crate) use types::{
     context::{InstalledDiagnosticHandler, LLVMContext},
@@ -261,6 +262,24 @@ pub(crate) fn internalize(
     export_symbols: &HashSet<Cow<'_, [u8]>>,
 ) {
     if !name.starts_with(b"llvm.") && !export_symbols.contains(name) {
+        // Undefined functions and global variables are not internalized: they
+        // keep external linkage and default visibility and are placed in the
+        // `.ksyms` section.
+        let undefined_kind = if unsafe { !LLVMIsAFunction(value).is_null() } {
+            (unsafe { LLVMCountBasicBlocks(value) } == 0).then_some("function")
+        } else if unsafe { !LLVMIsAGlobalVariable(value).is_null() } {
+            (unsafe { LLVMIsDeclaration(value) } != 0).then_some("global variable")
+        } else {
+            None
+        };
+        if let Some(kind) = undefined_kind {
+            unsafe { LLVMSetSection(value, c".ksyms".as_ptr()) };
+            unsafe { LLVMSetLinkage(value, LLVMLinkage::LLVMExternalLinkage) };
+            unsafe { LLVMSetVisibility(value, LLVMVisibility::LLVMDefaultVisibility) };
+            info!("not internalizing undefined {kind} {}", String::from_utf8_lossy(name));
+            return;
+        }
+
         unsafe { LLVMSetLinkage(value, LLVMLinkage::LLVMInternalLinkage) };
         unsafe { LLVMSetVisibility(value, LLVMVisibility::LLVMDefaultVisibility) };
     }
diff --git a/tests/assembly/extern_linkage.rs b/tests/assembly/extern_linkage.rs
new file mode 100644
index 00000000..06aa0341
--- /dev/null
+++ b/tests/assembly/extern_linkage.rs
@@ -0,0 +1,43 @@
+// assembly-output: bpf-linker
+// compile-flags: --crate-type cdylib -C link-arg=--emit=llvm-ir
+
+#![no_std]
+
+// aux-build: loop-panic-handler.rs
+extern crate loop_panic_handler;
+
+// Extern declarations
+extern "C" {
+    fn bpf_kfunc_call_test_acquire(arg: *mut u64) -> *mut u64;
+    fn bpf_kfunc_call_test_release(arg: *mut u64);
+    static bpf_prog_active: u32;
+    static CONFIG_HZ: u64;
+}
+
+#[no_mangle]
+#[link_section = "tc"]
+pub fn test_extern_symbols() -> u64 {
+    unsafe {
+        let mut val: u64 = 42;
+        let ptr = bpf_kfunc_call_test_acquire(&mut val as *mut u64);
+        bpf_kfunc_call_test_release(ptr);
+
+        let active = core::ptr::read_volatile(&bpf_prog_active);
+        let hz = core::ptr::read_volatile(&CONFIG_HZ);
+        active as u64 + hz
+    }
+}
+
+
+// Verify extern variables: external, not internal
+// CHECK: @bpf_prog_active = external{{.*}}global i32{{.*}}section ".ksyms"
+// CHECK: @CONFIG_HZ = external{{.*}}global i64{{.*}}section ".ksyms"
+// CHECK-NOT: @bpf_prog_active = internal
+// CHECK-NOT: @CONFIG_HZ = internal
+// Verify extern functions preserve linkage/calling convention/function signature
+// CHECK: declare ptr @bpf_kfunc_call_test_acquire(ptr){{.*}}section ".ksyms"
+// CHECK: declare void @bpf_kfunc_call_test_release(ptr){{.*}}section ".ksyms"
+// CHECK-NOT: declare internal{{.*}}@bpf_kfunc_call_test_acquire unnamed_addr #0
+// CHECK-NOT: declare internal{{.*}}@bpf_kfunc_call_test_release unnamed_addr #0
+// CHECK-NOT: declare{{.*}}fastcc{{.*}}@bpf_kfunc_call_test_acquire unnamed_addr #0
+// CHECK-NOT: declare{{.*}}fastcc{{.*}}@bpf_kfunc_call_test_release unnamed_addr #0