diff --git a/crates/wasmtime/src/runtime/code.rs b/crates/wasmtime/src/runtime/code.rs index 952f44f7791f..0f824526807c 100644 --- a/crates/wasmtime/src/runtime/code.rs +++ b/crates/wasmtime/src/runtime/code.rs @@ -1,3 +1,4 @@ +use crate::module::GlobalTrapRegistryHandle; use crate::{code_memory::CodeMemory, type_registry::TypeCollection}; use alloc::sync::Arc; use wasmtime_environ::ModuleTypes; @@ -33,18 +34,20 @@ pub struct CodeObject { /// This is either a `ModuleTypes` or a `ComponentTypes` depending on the /// top-level creator of this code. types: Types, + + /// Handle for dropping registration in the global trap registry. + #[allow(dead_code)] + trap_registry_handle: GlobalTrapRegistryHandle, } impl CodeObject { pub fn new(mmap: Arc, signatures: TypeCollection, types: Types) -> CodeObject { - // The corresponding unregister for this is below in `Drop for - // CodeObject`. - crate::module::register_code(&mmap); - + let trap_registry_handle = GlobalTrapRegistryHandle::register_code(mmap.clone()); CodeObject { mmap, signatures, types, + trap_registry_handle, } } @@ -66,12 +69,6 @@ impl CodeObject { } } -impl Drop for CodeObject { - fn drop(&mut self) { - crate::module::unregister_code(&self.mmap); - } -} - pub enum Types { Module(ModuleTypes), #[cfg(feature = "component-model")] diff --git a/crates/wasmtime/src/runtime/code_memory.rs b/crates/wasmtime/src/runtime/code_memory.rs index d0a2f85c00c2..dac6251a739d 100644 --- a/crates/wasmtime/src/runtime/code_memory.rs +++ b/crates/wasmtime/src/runtime/code_memory.rs @@ -8,7 +8,7 @@ use core::ops::Range; use object::SectionFlags; use object::endian::Endianness; use object::read::{Object, ObjectSection, elf::ElfFile64}; -use wasmtime_environ::{Trap, lookup_trap_code, obj}; +use wasmtime_environ::obj; /// Management of executable memory within a `MmapVec` /// @@ -418,12 +418,6 @@ impl CodeMemory { self.debug_registration = Some(reg); Ok(()) } - - /// Looks up the given offset within this module's text section and returns - /// the trap code associated with that instruction, if there is one. - pub fn lookup_trap_code(&self, text_offset: usize) -> Option { - lookup_trap_code(self.trap_data(), text_offset) - } } /// Returns the range of `inner` within `outer`, such that `outer[range]` is the diff --git a/crates/wasmtime/src/runtime/module/registry.rs b/crates/wasmtime/src/runtime/module/registry.rs index 3367167a87e3..c881e2b67e46 100644 --- a/crates/wasmtime/src/runtime/module/registry.rs +++ b/crates/wasmtime/src/runtime/module/registry.rs @@ -9,8 +9,9 @@ use crate::sync::{OnceLock, RwLock}; use crate::{FrameInfo, Module, code_memory::CodeMemory}; use alloc::collections::btree_map::{BTreeMap, Entry}; use alloc::sync::Arc; +use core::ops::Range; use core::ptr::NonNull; -use wasmtime_environ::VMSharedTypeIndex; +use wasmtime_environ::{Trap, VMSharedTypeIndex, lookup_trap_code}; /// Used for registering modules with a store. /// @@ -236,65 +237,104 @@ impl LoadedCode { } } -// This is the global code registry that stores information for all loaded code -// objects that are currently in use by any `Store` in the current process. +// This is the global trap data registry that stores trap section pointer +// for all loaded code objects that are currently in use by any +// `Store` in the current process. // // The purpose of this map is to be called from signal handlers to determine // whether a program counter is a wasm trap or not. Specifically macOS has // no contextual information about the thread available, hence the necessity // for global state rather than using thread local state. // -// This is similar to `ModuleRegistry` except that it has less information and +// This is similar to `ModuleRegistry` except that it stores only trap data and // supports removal. Any time anything is registered with a `ModuleRegistry` // it is also automatically registered with the singleton global module // registry. When a `ModuleRegistry` is destroyed then all of its entries // are removed from the global registry. -fn global_code() -> &'static RwLock { - static GLOBAL_CODE: OnceLock> = OnceLock::new(); - GLOBAL_CODE.get_or_init(Default::default) +fn global_trap_registry() -> &'static RwLock { + static GLOBAL_TRAP_REGISTRY: OnceLock> = OnceLock::new(); + GLOBAL_TRAP_REGISTRY.get_or_init(Default::default) } -type GlobalRegistry = BTreeMap)>; +type GlobalTrapRegistry = BTreeMap)>; /// Find which registered region of code contains the given program counter, and -/// what offset that PC is within that module's code. -pub fn lookup_code(pc: usize) -> Option<(Arc, usize)> { - let all_modules = global_code().read(); - let (_end, (start, module)) = all_modules.range(pc..).next()?; +/// lookup trap for given PC offset within that module code. +pub fn lookup_trap_for_pc(pc: usize) -> Option { + let all_modules = global_trap_registry().read(); + let (_end, (start, _count, trap_range)) = all_modules.range(pc..).next()?; + let trap_data: &[u8]; + unsafe { + // GlobalTrapRegistryHandle ensures that CodeMemory is not dropped + // before unregistration and we're holding registry RwLock here. + trap_data = std::slice::from_raw_parts( + trap_range.start as *const u8, + trap_range.end - trap_range.start, + ); + } let text_offset = pc.checked_sub(*start)?; - Some((module.clone(), text_offset)) + lookup_trap_code(trap_data, text_offset) } -/// Registers a new region of code. -/// -/// Must not have been previously registered and must be `unregister`'d to -/// prevent leaking memory. -/// -/// This is required to enable traps to work correctly since the signal handler -/// will lookup in the `GLOBAL_CODE` list to determine which a particular pc -/// is a trap or not. -pub fn register_code(code: &Arc) { - let text = code.text(); - if text.is_empty() { - return; +pub struct GlobalTrapRegistryHandle { + code: Arc, +} + +impl GlobalTrapRegistryHandle { + /// Registers a new region of code. + /// + /// Multiple `CodeMemory` pointing to the same code might be registered + /// if trap_data pointers are also equal. + /// + /// Returns handle that automatically unregisters this region when dropped. + /// + /// This is required to enable traps to work correctly since the signal handler + /// will lookup in the `GLOBAL_CODE` list to determine which a particular pc + /// is a trap or not. + pub fn register_code(code: Arc) -> Self { + let text = code.text(); + if !text.is_empty() { + let start = text.as_ptr() as usize; + let end = start + text.len() - 1; + let trap_data = code.trap_data().as_ptr_range(); + let trap_range = Range { + start: trap_data.start as usize, + end: trap_data.end as usize, + }; + + let mut locked = global_trap_registry().write(); + let prev = locked.get_mut(&end); + if let Some(prev) = prev { + // Assert that trap_range is equal to previously added entry. + assert_eq!(trap_range, prev.2); + + // Increment usage count. + prev.1 += 1; + } else { + locked.insert(end, (start, 1, trap_range)); + } + } + + Self { code } } - let start = text.as_ptr() as usize; - let end = start + text.len() - 1; - let prev = global_code().write().insert(end, (start, code.clone())); - assert!(prev.is_none()); } -/// Unregisters a code mmap from the global map. -/// -/// Must have been previously registered with `register`. -pub fn unregister_code(code: &Arc) { - let text = code.text(); - if text.is_empty() { - return; +impl Drop for GlobalTrapRegistryHandle { + fn drop(&mut self) { + let text = self.code.text(); + if !text.is_empty() { + let end = (text.as_ptr() as usize) + text.len() - 1; + + let mut locked = global_trap_registry().write(); + let prev = locked.get_mut(&end).unwrap(); + + // Decrement usage count and remove if needed. + prev.1 -= 1; + if prev.1 == 0 { + locked.remove(&end); + } + } } - let end = (text.as_ptr() as usize) + text.len() - 1; - let code = global_code().write().remove(&end); - assert!(code.is_some()); } #[test] diff --git a/crates/wasmtime/src/runtime/vm/sys/unix/machports.rs b/crates/wasmtime/src/runtime/vm/sys/unix/machports.rs index 25bd49ede33c..0e7e618806e1 100644 --- a/crates/wasmtime/src/runtime/vm/sys/unix/machports.rs +++ b/crates/wasmtime/src/runtime/vm/sys/unix/machports.rs @@ -40,7 +40,7 @@ clippy::cast_possible_truncation )] -use crate::runtime::module::lookup_code; +use crate::runtime::module::lookup_trap_for_pc; use crate::runtime::vm::sys::traphandlers::wasmtime_longjmp; use crate::runtime::vm::traphandlers::{TrapRegisters, tls}; use mach2::exc::*; @@ -404,11 +404,7 @@ unsafe fn handle_exception(request: &mut ExceptionRequest) -> bool { // pointer value and if `MAP` changes happen after we read our entry that's // ok since they won't invalidate our entry. let (pc, fp) = get_pc_and_fp(&thread_state); - let Some((code, text_offset)) = lookup_code(pc as usize) else { - return false; - }; - - let Some(trap) = code.lookup_trap_code(text_offset) else { + let Some(trap) = lookup_trap_for_pc(pc as usize) else { return false; }; diff --git a/crates/wasmtime/src/runtime/vm/traphandlers.rs b/crates/wasmtime/src/runtime/vm/traphandlers.rs index 54e05a089d03..aab447a4823c 100644 --- a/crates/wasmtime/src/runtime/vm/traphandlers.rs +++ b/crates/wasmtime/src/runtime/vm/traphandlers.rs @@ -15,7 +15,7 @@ mod signals; #[cfg(all(has_native_signals))] pub use self::signals::*; -use crate::runtime::module::lookup_code; +use crate::runtime::module::lookup_trap_for_pc; use crate::runtime::store::{ExecutorRef, StoreOpaque}; use crate::runtime::vm::sys::traphandlers; use crate::runtime::vm::{InterpreterRef, VMContext, VMStoreContext, f32x4, f64x2, i8x16}; @@ -721,16 +721,10 @@ impl CallThreadState { } } - // If this fault wasn't in wasm code, then it's not our problem - let Some((code, text_offset)) = lookup_code(regs.pc) else { - return TrapTest::NotWasm; - }; - - // If the fault was at a location that was not marked as potentially - // trapping, then that's a bug in Cranelift/Winch/etc. Don't try to - // catch the trap and pretend this isn't wasm so the program likely - // aborts. - let Some(trap) = code.lookup_trap_code(text_offset) else { + // If this fault wasn't in wasm code, then it's not our problem. + // (or it was in wasm code but at a location that was not marked as + // potentially trapping, then that's a bug in Cranelift/Winch/etc.) + let Some(trap) = lookup_trap_for_pc(regs.pc) else { return TrapTest::NotWasm; };