diff --git a/src/wasm_runtime/src/module.rs b/src/wasm_runtime/src/module.rs index e3fa630..a030082 100644 --- a/src/wasm_runtime/src/module.rs +++ b/src/wasm_runtime/src/module.rs @@ -120,6 +120,39 @@ fn init_wasm_runtime() -> Result> { Ok(get_flatbuffer_result::(0)) } +// Helper function to handle common module loading logic after module creation +fn instantiate_module(module: Module, engine: &Engine) -> Result<()> { + let linker = CUR_LINKER.lock(); + let linker = linker.deref().as_ref().ok_or(HyperlightGuestError::new( + ErrorCode::GuestError, + "impossible: wasm runtime has no valid linker".to_string(), + ))?; + + let mut store = Store::new(engine, ()); + let instance = linker.instantiate(&mut store, &module)?; + + // Module must export malloc and free as these + // are used to marshal param/return values + if instance.get_func(&mut store, "malloc").is_none() { + return Err(HyperlightGuestError::new( + ErrorCode::GuestError, + "WASM module must export 'malloc' function".to_string(), + )); + } + + if instance.get_func(&mut store, "free").is_none() { + return Err(HyperlightGuestError::new( + ErrorCode::GuestError, + "WASM module must export 'free' function".to_string(), + )); + } + + *CUR_MODULE.lock() = Some(module); + *CUR_STORE.lock() = Some(store); + *CUR_INSTANCE.lock() = Some(instance); + Ok(()) +} + fn load_wasm_module(function_call: &FunctionCall) -> Result> { if let ( ParameterValue::VecBytes(ref wasm_bytes), @@ -130,19 +163,8 @@ fn load_wasm_module(function_call: &FunctionCall) -> Result> { &function_call.parameters.as_ref().unwrap()[1], &*CUR_ENGINE.lock(), ) { - let linker = CUR_LINKER.lock(); - let linker = linker.deref().as_ref().ok_or(HyperlightGuestError::new( - ErrorCode::GuestError, - "impossible: wasm runtime has no valid linker".to_string(), - ))?; - let module = unsafe { Module::deserialize(engine, wasm_bytes)? }; - let mut store = Store::new(engine, ()); - let instance = linker.instantiate(&mut store, &module)?; - - *CUR_MODULE.lock() = Some(module); - *CUR_STORE.lock() = Some(store); - *CUR_INSTANCE.lock() = Some(instance); + instantiate_module(module, engine)?; Ok(get_flatbuffer_result::(0)) } else { Err(HyperlightGuestError::new( @@ -158,19 +180,8 @@ fn load_wasm_module_phys(function_call: &FunctionCall) -> Result> { &function_call.parameters.as_ref().unwrap()[1], &*CUR_ENGINE.lock(), ) { - let linker = CUR_LINKER.lock(); - let linker = linker.deref().as_ref().ok_or(HyperlightGuestError::new( - ErrorCode::GuestError, - "impossible: wasm runtime has no valid linker".to_string(), - ))?; - let module = unsafe { Module::deserialize_raw(engine, platform::map_buffer(*phys, *len))? }; - let mut store = Store::new(engine, ()); - let instance = linker.instantiate(&mut store, &module)?; - - *CUR_MODULE.lock() = Some(module); - *CUR_STORE.lock() = Some(store); - *CUR_INSTANCE.lock() = Some(instance); + instantiate_module(module, engine)?; Ok(get_flatbuffer_result::<()>(())) } else { Err(HyperlightGuestError::new(