Skip to content

Commit da99fb8

Browse files
committed
Make get_host_value_return_as generic
Signed-off-by: Ludvig Liljenberg <[email protected]>
1 parent da3d827 commit da99fb8

File tree

3 files changed

+18
-95
lines changed

3 files changed

+18
-95
lines changed

src/hyperlight_guest/src/host_function_call.rs

Lines changed: 11 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
1414
limitations under the License.
1515
*/
1616

17+
use alloc::format;
1718
use alloc::string::ToString;
1819
use alloc::vec::Vec;
1920
use core::arch::global_asm;
@@ -39,94 +40,18 @@ pub enum OutBAction {
3940
Abort = 102,
4041
}
4142

42-
pub fn get_host_value_return_as_void() -> Result<()> {
43+
pub fn get_host_value_return_as<T: TryFrom<ReturnValue>>() -> Result<T> {
4344
let return_value = try_pop_shared_input_data_into::<ReturnValue>()
4445
.expect("Unable to deserialize a return value from host");
45-
if let ReturnValue::Void = return_value {
46-
Ok(())
47-
} else {
48-
Err(HyperlightGuestError::new(
49-
ErrorCode::GuestError,
50-
"Host return value was not void as expected".to_string(),
51-
))
52-
}
53-
}
54-
55-
pub fn get_host_value_return_as_int() -> Result<i32> {
56-
let return_value = try_pop_shared_input_data_into::<ReturnValue>()
57-
.expect("Unable to deserialize return value from host");
58-
59-
// check that return value is an int and return
60-
if let ReturnValue::Int(i) = return_value {
61-
Ok(i)
62-
} else {
63-
Err(HyperlightGuestError::new(
64-
ErrorCode::GuestError,
65-
"Host return value was not an int as expected".to_string(),
66-
))
67-
}
68-
}
69-
70-
pub fn get_host_value_return_as_uint() -> Result<u32> {
71-
let return_value = try_pop_shared_input_data_into::<ReturnValue>()
72-
.expect("Unable to deserialize return value from host");
73-
74-
// check that return value is an int and return
75-
if let ReturnValue::UInt(ui) = return_value {
76-
Ok(ui)
77-
} else {
78-
Err(HyperlightGuestError::new(
79-
ErrorCode::GuestError,
80-
"Host return value was not a uint as expected".to_string(),
81-
))
82-
}
83-
}
84-
85-
pub fn get_host_value_return_as_long() -> Result<i64> {
86-
let return_value = try_pop_shared_input_data_into::<ReturnValue>()
87-
.expect("Unable to deserialize return value from host");
88-
89-
// check that return value is an int and return
90-
if let ReturnValue::Long(l) = return_value {
91-
Ok(l)
92-
} else {
93-
Err(HyperlightGuestError::new(
94-
ErrorCode::GuestError,
95-
"Host return value was not a long as expected".to_string(),
96-
))
97-
}
98-
}
99-
100-
pub fn get_host_value_return_as_ulong() -> Result<u64> {
101-
let return_value = try_pop_shared_input_data_into::<ReturnValue>()
102-
.expect("Unable to deserialize return value from host");
103-
104-
// check that return value is an int and return
105-
if let ReturnValue::ULong(ul) = return_value {
106-
Ok(ul)
107-
} else {
108-
Err(HyperlightGuestError::new(
46+
T::try_from(return_value).map_err(|_| {
47+
HyperlightGuestError::new(
10948
ErrorCode::GuestError,
110-
"Host return value was not a ulong as expected".to_string(),
111-
))
112-
}
113-
}
114-
115-
// TODO: Make this generic, return a Result<T, ErrorCode>
116-
117-
pub fn get_host_value_return_as_vecbytes() -> Result<Vec<u8>> {
118-
let return_value = try_pop_shared_input_data_into::<ReturnValue>()
119-
.expect("Unable to deserialize return value from host");
120-
121-
// check that return value is an Vec<u8> and return
122-
if let ReturnValue::VecBytes(v) = return_value {
123-
Ok(v)
124-
} else {
125-
Err(HyperlightGuestError::new(
126-
ErrorCode::GuestError,
127-
"Host return value was not an VecBytes as expected".to_string(),
128-
))
129-
}
49+
format!(
50+
"Host return value was not a {} as expected",
51+
core::any::type_name::<T>()
52+
),
53+
)
54+
})
13055
}
13156

13257
// TODO: Make this generic, return a Result<T, ErrorCode> this should allow callers to call this function and get the result type they expect
@@ -194,7 +119,7 @@ pub fn print_output_as_guest_function(function_call: &FunctionCall) -> Result<Ve
194119
Some(Vec::from(&[ParameterValue::String(message.to_string())])),
195120
ReturnType::Int,
196121
)?;
197-
let res_i = get_host_value_return_as_int()?;
122+
let res_i = get_host_value_return_as::<i32>()?;
198123
Ok(get_flatbuffer_result(res_i))
199124
} else {
200125
Err(HyperlightGuestError::new(

src/tests/rust_guests/callbackguest/src/main.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ use hyperlight_guest::error::{HyperlightGuestError, Result};
3535
use hyperlight_guest::guest_function_definition::GuestFunctionDefinition;
3636
use hyperlight_guest::guest_function_register::register_function;
3737
use hyperlight_guest::host_function_call::{
38-
call_host_function, get_host_value_return_as_int, print_output_as_guest_function,
38+
call_host_function, get_host_value_return_as, print_output_as_guest_function,
3939
};
4040
use hyperlight_guest::logging::log_message;
4141

@@ -51,7 +51,7 @@ fn send_message_to_host_method(
5151
ReturnType::Int,
5252
)?;
5353

54-
let result = get_host_value_return_as_int()?;
54+
let result = get_host_value_return_as::<i32>()?;
5555

5656
Ok(get_flatbuffer_result(result))
5757
}

src/tests/rust_guests/simpleguest/src/main.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,7 @@ use hyperlight_guest::entrypoint::{abort_with_code, abort_with_code_and_message}
4545
use hyperlight_guest::error::{HyperlightGuestError, Result};
4646
use hyperlight_guest::guest_function_definition::GuestFunctionDefinition;
4747
use hyperlight_guest::guest_function_register::register_function;
48-
use hyperlight_guest::host_function_call::{
49-
call_host_function, get_host_value_return_as_int, get_host_value_return_as_ulong,
50-
};
48+
use hyperlight_guest::host_function_call::{call_host_function, get_host_value_return_as};
5149
use hyperlight_guest::memory::malloc;
5250
use hyperlight_guest::{logging, MIN_STACK_ADDRESS};
5351
use log::{error, LevelFilter};
@@ -94,7 +92,7 @@ fn print_output(message: &str) -> Result<Vec<u8>> {
9492
Some(Vec::from(&[ParameterValue::String(message.to_string())])),
9593
ReturnType::Int,
9694
)?;
97-
let result = get_host_value_return_as_int()?;
95+
let result = get_host_value_return_as::<i32>()?;
9896
Ok(get_flatbuffer_result(result))
9997
}
10098

@@ -674,7 +672,7 @@ fn violate_seccomp_filters(function_call: &FunctionCall) -> Result<Vec<u8>> {
674672
if function_call.parameters.is_none() {
675673
call_host_function("MakeGetpidSyscall", None, ReturnType::ULong)?;
676674

677-
let res = get_host_value_return_as_ulong()?;
675+
let res = get_host_value_return_as::<u64>()?;
678676

679677
Ok(get_flatbuffer_result(res))
680678
} else {
@@ -696,7 +694,7 @@ fn add(function_call: &FunctionCall) -> Result<Vec<u8>> {
696694
ReturnType::Int,
697695
)?;
698696

699-
let res = get_host_value_return_as_int()?;
697+
let res = get_host_value_return_as::<i32>()?;
700698

701699
Ok(get_flatbuffer_result(res))
702700
} else {
@@ -1115,7 +1113,7 @@ pub fn guest_dispatch_function(function_call: FunctionCall) -> Result<Vec<u8>> {
11151113
Some(Vec::from(&[ParameterValue::String(message.to_string())])),
11161114
ReturnType::Int,
11171115
)?;
1118-
let result = get_host_value_return_as_int()?;
1116+
let result = get_host_value_return_as::<i32>()?;
11191117
let function_name = function_call.function_name.clone();
11201118
let param_len = function_call.parameters.clone().unwrap_or_default().len();
11211119
let call_type = function_call.function_call_type().clone();

0 commit comments

Comments
 (0)