Skip to content

Commit cd58b7f

Browse files
committed
Support loading a wasm module/component via direct mapping
This adds support for directly mapping a host buffer containing a wasm module/component into the guest, enabling the use of mmap() on the host to share a single module/component across multiple sandboxes. Signed-off-by: Lucy Menon <[email protected]>
1 parent e33a667 commit cd58b7f

File tree

4 files changed

+164
-14
lines changed

4 files changed

+164
-14
lines changed

src/hyperlight_wasm/src/sandbox/wasm_sandbox.rs

Lines changed: 68 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
use std::path::Path;
1818

1919
use hyperlight_host::func::call_ctx::MultiUseGuestCallContext;
20+
use hyperlight_host::mem::memory_region::{MemoryRegion, MemoryRegionFlags, MemoryRegionType};
2021
use hyperlight_host::sandbox::Callable;
2122
use hyperlight_host::sandbox_state::sandbox::{EvolvableSandbox, Sandbox};
2223
use hyperlight_host::sandbox_state::transition::MultiUseContextCallback;
@@ -59,34 +60,87 @@ impl WasmSandbox {
5960
/// Before you can call guest functions in the sandbox, you must call
6061
/// this function and use the returned value to call guest functions.
6162
pub fn load_module(self, file: impl AsRef<Path>) -> Result<LoadedWasmSandbox> {
62-
let wasm_bytes = std::fs::read(file)?;
63-
self.load_module_inner(wasm_bytes)
63+
let func = Box::new(move |call_ctx: &mut MultiUseGuestCallContext| {
64+
if let Ok(len) = call_ctx.map_file_cow(file.as_ref(), 0x1_0000_0000) {
65+
call_ctx.call("LoadWasmModulePhys", (0x1_0000_0000u64, len))
66+
} else {
67+
let wasm_bytes = std::fs::read(file)?;
68+
Self::load_module_from_buffer_transition_func(wasm_bytes)(call_ctx)
69+
}
70+
});
71+
self.load_module_inner(func)
6472
}
6573

66-
/// Load a Wasm module from a buffer of bytes into the sandbox and return a `LoadedWasmSandbox`
67-
/// able to execute code in the loaded Wasm Module.
74+
/// Load a Wasm module that is currently present in a buffer in
75+
/// host memory, by mapping the host memory directly into the
76+
/// sandbox.
6877
///
69-
/// Before you can call guest functions in the sandbox, you must call
70-
/// this function and use the returned value to call guest functions.
71-
pub fn load_module_from_buffer(self, buffer: &[u8]) -> Result<LoadedWasmSandbox> {
72-
self.load_module_inner(buffer.to_vec())
78+
/// Depending on the host platform, there are likely alignment
79+
/// requirements of at least one page for base and len
80+
///
81+
/// # Safety
82+
/// It is the caller's responsibility to ensure that the host side
83+
/// of the region remains intact and is not written to until the
84+
/// produced LoadedWasmSandbox is discarded or devolved.
85+
pub unsafe fn load_module_by_mapping(
86+
self,
87+
base: *mut libc::c_void,
88+
len: usize,
89+
) -> Result<LoadedWasmSandbox> {
90+
let func = Box::new(move |call_ctx: &mut MultiUseGuestCallContext| {
91+
let guest_base: usize = 0x1_0000_0000;
92+
let rgn = MemoryRegion {
93+
host_region: base as usize..base.wrapping_add(len) as usize,
94+
guest_region: guest_base..guest_base + len,
95+
flags: MemoryRegionFlags::READ | MemoryRegionFlags::EXECUTE,
96+
region_type: MemoryRegionType::Heap,
97+
};
98+
if let Ok(()) = unsafe { call_ctx.map_region(&rgn) } {
99+
call_ctx.call("LoadWasmModulePhys", (0x1_0000_0000u64, len as u64))
100+
} else {
101+
let wasm_bytes =
102+
unsafe { std::slice::from_raw_parts(base as *const u8, len).to_vec() };
103+
Self::load_module_from_buffer_transition_func(wasm_bytes)(call_ctx)
104+
}
105+
});
106+
self.load_module_inner(func)
73107
}
74108

75-
fn load_module_inner(mut self, wasm_bytes: Vec<u8>) -> Result<LoadedWasmSandbox> {
76-
let func = Box::new(move |call_ctx: &mut MultiUseGuestCallContext| {
77-
let len = wasm_bytes.len() as i32;
78-
let res: i32 = call_ctx.call("LoadWasmModule", (wasm_bytes, len))?;
109+
// todo: take a slice rather than a vec (requires somewhat
110+
// refactoring the flatbuffers stuff maybe)
111+
fn load_module_from_buffer_transition_func(
112+
buffer: Vec<u8>,
113+
) -> impl FnOnce(&mut MultiUseGuestCallContext) -> Result<()> {
114+
move |call_ctx: &mut MultiUseGuestCallContext| {
115+
let len = buffer.len() as i32;
116+
let res: i32 = call_ctx.call("LoadWasmModule", (buffer, len))?;
79117
if res != 0 {
80118
return Err(new_error!(
81119
"LoadWasmModule Failed with error code {:?}",
82120
res
83121
));
84122
}
85123
Ok(())
86-
});
124+
}
125+
}
87126

88-
let transition_func = MultiUseContextCallback::from(func);
127+
/// Load a Wasm module from a buffer of bytes into the sandbox and return a `LoadedWasmSandbox`
128+
/// able to execute code in the loaded Wasm Module.
129+
///
130+
/// Before you can call guest functions in the sandbox, you must call
131+
/// this function and use the returned value to call guest functions.
132+
pub fn load_module_from_buffer(self, buffer: &[u8]) -> Result<LoadedWasmSandbox> {
133+
// TODO: get rid of this clone
134+
let func = Self::load_module_from_buffer_transition_func(buffer.to_vec());
89135

136+
self.load_module_inner(func)
137+
}
138+
139+
fn load_module_inner<F: FnOnce(&mut MultiUseGuestCallContext) -> Result<()>>(
140+
mut self,
141+
func: F,
142+
) -> Result<LoadedWasmSandbox> {
143+
let transition_func = MultiUseContextCallback::from(func);
90144
match self.inner.take() {
91145
Some(sbox) => {
92146
let new_sbox: MultiUseSandbox = sbox.evolve(transition_func)?;

src/wasm_runtime/src/component.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
use alloc::string::ToString;
1818
use alloc::vec;
1919
use alloc::vec::Vec;
20+
use core::ptr::NonNull;
2021
use core::result::Result::*;
2122

2223
use hyperlight_common::flatbuffer_wrappers::function_call::FunctionCall;
@@ -72,13 +73,46 @@ fn load_wasm_module(function_call: &FunctionCall) -> Result<Vec<u8>> {
7273
}
7374
}
7475

76+
fn load_wasm_module_phys(function_call: &FunctionCall) -> Result<Vec<u8>> {
77+
if let (ParameterValue::ULong(ref phys), ParameterValue::ULong(ref len), Some(ref engine)) = (
78+
&function_call.parameters.as_ref().unwrap()[0],
79+
&function_call.parameters.as_ref().unwrap()[1],
80+
&*CUR_ENGINE.lock(),
81+
) {
82+
use hyperlight_guest_bin::paging;
83+
// TODO: make sure this address is sensible
84+
let virt = *phys as *mut u8;
85+
let component = unsafe {
86+
paging::map_region(*phys, virt, *len + 4096);
87+
Component::deserialize_raw(
88+
engine,
89+
NonNull::new_unchecked(core::ptr::slice_from_raw_parts_mut(virt, *len as usize)),
90+
)?
91+
};
92+
let mut store = Store::new(engine, ());
93+
let instance = (*CUR_LINKER.lock())
94+
.as_ref()
95+
.unwrap()
96+
.instantiate(&mut store, &component)?;
97+
*CUR_STORE.lock() = Some(store);
98+
*CUR_INSTANCE.lock() = Some(instance);
99+
Ok(get_flatbuffer_result::<()>(()))
100+
} else {
101+
Err(HyperlightGuestError::new(
102+
ErrorCode::GuestFunctionParameterTypeMismatch,
103+
"Invalid parameters passed to LoadWasmModulePhys".to_string(),
104+
))
105+
}
106+
}
107+
75108
#[no_mangle]
76109
pub extern "C" fn hyperlight_main() {
77110
let mut config = Config::new();
78111
config.memory_reservation(0);
79112
config.memory_guard_size(0);
80113
config.memory_reservation_for_growth(0);
81114
config.guard_before_linear_memory(false);
115+
config.with_custom_code_memory(Some(alloc::sync::Arc::new(crate::platform::WasmtimeCodeMemory {})));
82116
let engine = Engine::new(&config).unwrap();
83117
let linker = Linker::new(&engine);
84118
*CUR_ENGINE.lock() = Some(engine);
@@ -98,6 +132,12 @@ pub extern "C" fn hyperlight_main() {
98132
ReturnType::Int,
99133
load_wasm_module as usize,
100134
));
135+
register_function(GuestFunctionDefinition::new(
136+
"LoadWasmModulePhys".to_string(),
137+
vec![ParameterType::ULong, ParameterType::ULong],
138+
ReturnType::Void,
139+
load_wasm_module_phys as usize,
140+
));
101141
}
102142

103143
#[no_mangle]

src/wasm_runtime/src/module.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use alloc::string::ToString;
1818
use alloc::vec::Vec;
1919
use alloc::{format, vec};
2020
use core::ops::Deref;
21+
use core::ptr::NonNull;
2122

2223
use hyperlight_common::flatbuffer_wrappers::function_call::FunctionCall;
2324
use hyperlight_common::flatbuffer_wrappers::function_types::{
@@ -93,6 +94,7 @@ fn init_wasm_runtime() -> Result<Vec<u8>> {
9394
config.memory_guard_size(0);
9495
config.memory_reservation_for_growth(0);
9596
config.guard_before_linear_memory(false);
97+
config.with_custom_code_memory(Some(alloc::sync::Arc::new(crate::platform::WasmtimeCodeMemory {})));
9698
let engine = Engine::new(&config)?;
9799
let mut linker = Linker::new(&engine);
98100
wasip1::register_handlers(&mut linker)?;
@@ -138,6 +140,32 @@ fn load_wasm_module(function_call: &FunctionCall) -> Result<Vec<u8>> {
138140
}
139141
}
140142

143+
fn load_wasm_module_phys(function_call: &FunctionCall) -> Result<Vec<u8>> {
144+
if let (ParameterValue::ULong(ref phys), ParameterValue::ULong(ref len), Some(ref engine)) = (
145+
&function_call.parameters.as_ref().unwrap()[0],
146+
&function_call.parameters.as_ref().unwrap()[1],
147+
&*CUR_ENGINE.lock(),
148+
) {
149+
use hyperlight_guest_bin::paging;
150+
// TODO: make sure this address is sensible
151+
let virt = *phys as *mut u8;
152+
let module = unsafe {
153+
paging::map_region(*phys, virt, *len + 4096);
154+
Module::deserialize_raw(
155+
engine,
156+
NonNull::new_unchecked(core::ptr::slice_from_raw_parts_mut(virt, *len as usize)),
157+
)?
158+
};
159+
*CUR_MODULE.lock() = Some(module);
160+
Ok(get_flatbuffer_result::<()>(()))
161+
} else {
162+
Err(HyperlightGuestError::new(
163+
ErrorCode::GuestFunctionParameterTypeMismatch,
164+
"Invalid parameters passed to LoadWasmModulePhys".to_string(),
165+
))
166+
}
167+
}
168+
141169
#[no_mangle]
142170
#[allow(clippy::fn_to_numeric_cast)] // GuestFunctionDefinition expects a function pointer as i64
143171
pub extern "C" fn hyperlight_main() {
@@ -161,4 +189,10 @@ pub extern "C" fn hyperlight_main() {
161189
ReturnType::Int,
162190
load_wasm_module as usize,
163191
));
192+
register_function(GuestFunctionDefinition::new(
193+
"LoadWasmModulePhys".to_string(),
194+
vec![ParameterType::ULong, ParameterType::ULong],
195+
ReturnType::Void,
196+
load_wasm_module_phys as usize,
197+
));
164198
}

src/wasm_runtime/src/platform.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,25 @@ pub extern "C" fn wasmtime_tls_get() -> *mut u8 {
145145
pub extern "C" fn wasmtime_tls_set(ptr: *mut u8) {
146146
FAKE_TLS.store(ptr, Ordering::Release)
147147
}
148+
149+
pub struct WasmtimeCodeMemory {}
150+
// TODO: Actually change the page tables for W^X
151+
impl wasmtime::CustomCodeMemory for WasmtimeCodeMemory {
152+
fn required_alignment(&self) -> usize {
153+
unsafe { hyperlight_guest_bin::OS_PAGE_SIZE as usize }
154+
}
155+
fn publish_executable(
156+
&self,
157+
_ptr: *const u8,
158+
_len: usize,
159+
) -> core::result::Result<(), wasmtime::Error> {
160+
Ok(())
161+
}
162+
fn unpublish_executable(
163+
&self,
164+
_ptr: *const u8,
165+
_len: usize,
166+
) -> core::result::Result<(), wasmtime::Error> {
167+
Ok(())
168+
}
169+
}

0 commit comments

Comments
 (0)