diff --git a/src/stack_limiter/max_height.rs b/src/stack_limiter/max_height.rs index 7f01b516..7212bc8c 100644 --- a/src/stack_limiter/max_height.rs +++ b/src/stack_limiter/max_height.rs @@ -1,16 +1,10 @@ use super::resolve_func_type; use alloc::vec::Vec; -use parity_wasm::elements::{self, BlockType, Type}; +use parity_wasm::elements::{self, BlockType, Instruction, Type}; #[cfg(feature = "sign_ext")] use parity_wasm::elements::SignExtInstruction; -// The cost in stack items that should be charged per call of a function. This is -// is a static cost that is added to each function call. This makes sense because even -// if a function does not use any parameters or locals some stack space on the host -// machine might be consumed to hold some context. -const ACTIVATION_FRAME_COST: u32 = 2; - /// Control stack frame. #[derive(Debug)] struct Frame { @@ -41,8 +35,8 @@ struct Stack { } impl Stack { - fn new() -> Stack { - Stack { height: ACTIVATION_FRAME_COST, control_stack: Vec::new() } + fn new() -> Self { + Self { height: 0, control_stack: Vec::new() } } /// Returns current height of the value stack. @@ -121,58 +115,135 @@ impl Stack { } } -/// This function expects the function to be validated. -pub fn compute(func_idx: u32, module: &elements::Module) -> Result { - use parity_wasm::elements::Instruction::*; - - let func_section = module.function_section().ok_or("No function section")?; - let code_section = module.code_section().ok_or("No code section")?; - let type_section = module.type_section().ok_or("No type section")?; - - // Get a signature and a body of the specified function. - let func_sig_idx = func_section - .entries() - .get(func_idx as usize) - .ok_or("Function is not found in func section")? - .type_ref(); - let Type::Function(func_signature) = type_section - .types() - .get(func_sig_idx as usize) - .ok_or("Function is not found in func section")?; - let body = code_section - .bodies() - .get(func_idx as usize) - .ok_or("Function body for the index isn't found")?; - let instructions = body.code(); - - let mut stack = Stack::new(); - let mut max_height: u32 = 0; - let mut pc = 0; - - // Add implicit frame for the function. Breaks to this frame and execution of - // the last end should deal with this frame. - let func_arity = func_signature.results().len() as u32; - stack.push_frame(Frame { - is_polymorphic: false, - end_arity: func_arity, - branch_arity: func_arity, - start_height: 0, - }); - - loop { - if pc >= instructions.elements().len() { - break +/// This is a helper context that is used by [`MaxStackHeightCounter`]. +#[derive(Clone, Copy)] +pub(crate) struct MaxStackHeightCounterContext<'a> { + pub module: &'a elements::Module, + pub func_imports: u32, + pub func_section: &'a elements::FunctionSection, + pub code_section: &'a elements::CodeSection, + pub type_section: &'a elements::TypeSection, +} + +impl<'a> TryFrom<&'a elements::Module> for MaxStackHeightCounterContext<'a> { + type Error = &'static str; + + fn try_from(module: &'a elements::Module) -> Result { + Ok(Self { + module, + func_imports: module + .import_count(elements::ImportCountType::Function) + .try_into() + .map_err(|_| "Can't convert func imports count to u32")?, + func_section: module.function_section().ok_or("No function section")?, + code_section: module.code_section().ok_or("No code section")?, + type_section: module.type_section().ok_or("No type section")?, + }) + } +} + +/// This is a counter for the maximum stack height with the ability to take into account the +/// overhead that is added by the [`instrument_call!`] macro. +pub(crate) struct MaxStackHeightCounter<'a> { + context: MaxStackHeightCounterContext<'a>, + stack: Stack, + max_height: u32, + count_instrumented_calls: bool, +} + +impl<'a> MaxStackHeightCounter<'a> { + /// Creates a [`MaxStackHeightCounter`] from [`MaxStackHeightCounterContext`]. + pub fn new_with_context(context: MaxStackHeightCounterContext<'a>) -> Self { + Self { context, stack: Stack::new(), max_height: 0, count_instrumented_calls: false } + } + + /// Should the overhead of the [`instrument_call!`] macro be taken into account? + pub fn count_instrumented_calls(mut self, count_instrumented_calls: bool) -> Self { + self.count_instrumented_calls = count_instrumented_calls; + self + } + + /// Tries to calculate the maximum stack height for the `func_idx` defined in the wasm module. + pub fn compute_for_defined_func(&mut self, func_idx: u32) -> Result { + let MaxStackHeightCounterContext { func_section, code_section, type_section, .. } = + self.context; + + // Get a signature and a body of the specified function. + let func_sig_idx = func_section + .entries() + .get(func_idx as usize) + .ok_or("Function is not found in func section")? + .type_ref(); + let Type::Function(func_signature) = type_section + .types() + .get(func_sig_idx as usize) + .ok_or("Function is not found in func section")?; + let body = code_section + .bodies() + .get(func_idx as usize) + .ok_or("Function body for the index isn't found")?; + let instructions = body.code(); + + self.compute_for_raw_func(func_signature, instructions.elements()) + } + + /// Tries to calculate the maximum stack height for a raw function, which consists of + /// `func_signature` and `instructions`. + pub fn compute_for_raw_func( + &mut self, + func_signature: &elements::FunctionType, + instructions: &[Instruction], + ) -> Result { + // Add implicit frame for the function. Breaks to this frame and execution of + // the last end should deal with this frame. + let func_arity = func_signature.results().len() as u32; + self.stack.push_frame(Frame { + is_polymorphic: false, + end_arity: func_arity, + branch_arity: func_arity, + start_height: 0, + }); + + for instruction in instructions { + let maybe_instructions = + self.count_instrumented_calls + .then_some(instruction) + .and_then(|inst| match inst { + &Instruction::Call(idx) if idx >= self.context.func_imports => + Some(instrument_call!(idx, 0, 0, 0)), + _ => None, + }); + + if let Some(instructions) = maybe_instructions.as_ref() { + for instruction in instructions { + self.process_instruction(instruction, func_arity)?; + } + } else { + self.process_instruction(instruction, func_arity)?; + } } + Ok(self.max_height) + } + + /// This function processes all incoming instructions and updates the `self.max_height` field. + fn process_instruction( + &mut self, + opcode: &Instruction, + func_arity: u32, + ) -> Result<(), &'static str> { + use Instruction::*; + + let Self { stack, max_height, .. } = self; + let MaxStackHeightCounterContext { module, type_section, .. } = self.context; + // If current value stack is higher than maximal height observed so far, // save the new height. // However, we don't increase maximal value in unreachable code. - if stack.height() > max_height && !stack.frame(0)?.is_polymorphic { - max_height = stack.height(); + if stack.height() > *max_height && !stack.frame(0)?.is_polymorphic { + *max_height = stack.height(); } - let opcode = &instructions.elements()[pc]; - match opcode { Nop => {}, Block(ty) | Loop(ty) | If(ty) => { @@ -403,10 +474,9 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result Result { + MaxStackHeightCounter::new_with_context(module.try_into()?) + .count_instrumented_calls(true) + .compute_for_defined_func(func_idx) + } + fn parse_wat(source: &str) -> elements::Module { elements::deserialize_buffer(&wat::parse_str(source).expect("Failed to wat2wasm")) .expect("Failed to deserialize the module") @@ -437,7 +513,7 @@ mod tests { ); let height = compute(0, &module).unwrap(); - assert_eq!(height, 3 + ACTIVATION_FRAME_COST); + assert_eq!(height, 3); } #[test] @@ -454,7 +530,7 @@ mod tests { ); let height = compute(0, &module).unwrap(); - assert_eq!(height, 1 + ACTIVATION_FRAME_COST); + assert_eq!(height, 1); } #[test] @@ -472,7 +548,7 @@ mod tests { ); let height = compute(0, &module).unwrap(); - assert_eq!(height, ACTIVATION_FRAME_COST); + assert_eq!(height, 0); } #[test] @@ -501,7 +577,7 @@ mod tests { ); let height = compute(0, &module).unwrap(); - assert_eq!(height, 2 + ACTIVATION_FRAME_COST); + assert_eq!(height, 2); } #[test] @@ -525,7 +601,7 @@ mod tests { ); let height = compute(0, &module).unwrap(); - assert_eq!(height, 1 + ACTIVATION_FRAME_COST); + assert_eq!(height, 1); } #[test] @@ -547,7 +623,7 @@ mod tests { ); let height = compute(0, &module).unwrap(); - assert_eq!(height, 1 + ACTIVATION_FRAME_COST); + assert_eq!(height, 1); } #[test] @@ -573,6 +649,6 @@ mod tests { ); let height = compute(0, &module).unwrap(); - assert_eq!(height, 3 + ACTIVATION_FRAME_COST); + assert_eq!(height, 3); } } diff --git a/src/stack_limiter/mod.rs b/src/stack_limiter/mod.rs index 6ee8a750..b166085c 100644 --- a/src/stack_limiter/mod.rs +++ b/src/stack_limiter/mod.rs @@ -2,6 +2,7 @@ use alloc::{vec, vec::Vec}; use core::mem; +use max_height::{MaxStackHeightCounter, MaxStackHeightCounterContext}; use parity_wasm::{ builder, elements::{self, Instruction, Instructions, Type}, @@ -155,16 +156,27 @@ fn generate_stack_height_global(module: &mut elements::Module) -> u32 { /// /// Returns a vector with a stack cost for each function, including imports. fn compute_stack_costs(module: &elements::Module) -> Result, &'static str> { - let func_imports = module.import_count(elements::ImportCountType::Function); + let functions_space = module + .functions_space() + .try_into() + .map_err(|_| "Can't convert functions space to u32")?; + + // Don't create context when there are no functions (this will fail). + if functions_space == 0 { + return Ok(Vec::new()); + } - // TODO: optimize! - (0..module.functions_space()) + // This context already contains the module, number of imports and section references. + // So we can use it to optimize access to these objects. + let context: MaxStackHeightCounterContext = module.try_into()?; + + (0..functions_space) .map(|func_idx| { - if func_idx < func_imports { + if func_idx < context.func_imports { // We can't calculate stack_cost of the import functions. Ok(0) } else { - compute_stack_cost(func_idx as u32, module) + compute_stack_cost(func_idx, context) } }) .collect() @@ -173,17 +185,18 @@ fn compute_stack_costs(module: &elements::Module) -> Result, &'static s /// Stack cost of the given *defined* function is the sum of it's locals count (that is, /// number of arguments plus number of local variables) and the maximal stack /// height. -fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result { +fn compute_stack_cost( + func_idx: u32, + context: MaxStackHeightCounterContext, +) -> Result { // To calculate the cost of a function we need to convert index from // function index space to defined function spaces. - let func_imports = module.import_count(elements::ImportCountType::Function) as u32; let defined_func_idx = func_idx - .checked_sub(func_imports) + .checked_sub(context.func_imports) .ok_or("This should be a index of a defined function")?; - let code_section = - module.code_section().ok_or("Due to validation code section should exists")?; - let body = &code_section + let body = context + .code_section .bodies() .get(defined_func_idx as usize) .ok_or("Function body is out of bounds")?; @@ -194,7 +207,9 @@ fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result>, // Index in function space of this thunk. idx: Option, - callee_stack_cost: u32, } pub fn generate_thunks( @@ -19,36 +22,67 @@ pub fn generate_thunks( ) -> Result { // First, we need to collect all function indices that should be replaced by thunks let mut replacement_map: Map = { - let exports = module.export_section().map(|es| es.entries()).unwrap_or(&[]); - let elem_segments = module.elements_section().map(|es| es.entries()).unwrap_or(&[]); - let start_func_idx = module.start_section(); - - let exported_func_indices = exports.iter().filter_map(|entry| match entry.internal() { - Internal::Function(function_idx) => Some(*function_idx), - _ => None, - }); - let table_func_indices = - elem_segments.iter().flat_map(|segment| segment.members()).cloned(); - // Replacement map is at least export section size. let mut replacement_map: Map = Map::new(); - for func_idx in exported_func_indices - .chain(table_func_indices) - .chain(start_func_idx.into_iter()) + let mut peekable_iter = thunk_function_indexes(&module).peekable(); + let maybe_context: Option = if peekable_iter.peek().is_some() { - let callee_stack_cost = ctx.stack_cost(func_idx).ok_or("function index isn't found")?; + let module_ref = &module; + Some(module_ref.try_into()?) + } else { + None + }; + + for func_idx in peekable_iter { + let mut callee_stack_cost = + ctx.stack_cost(func_idx).ok_or("function index isn't found")?; // Don't generate a thunk if stack_cost of a callee is zero. if callee_stack_cost != 0 { - replacement_map.insert( + let signature = resolve_func_type(func_idx, &module)?.clone(); + + const CALLEE_STACK_COST_PLACEHOLDER: i32 = 1248163264; + let instrumented_call = instrument_call!( func_idx, - Thunk { - signature: resolve_func_type(func_idx, &module)?.clone(), - idx: None, - callee_stack_cost, - }, + CALLEE_STACK_COST_PLACEHOLDER, + ctx.stack_height_global_idx(), + ctx.stack_limit() ); + + // Thunk body consist of: + // - argument pushing + // - instrumented call + // - end + let mut thunk_body: Vec = + Vec::with_capacity(signature.params().len() + instrumented_call.len() + 1); + + for (arg_idx, _) in signature.params().iter().enumerate() { + thunk_body.push(Instruction::GetLocal(arg_idx as u32)); + } + thunk_body.extend_from_slice(&instrumented_call); + thunk_body.push(Instruction::End); + + // Update callee_stack_cost to charge for the thunk call itself + let context = + maybe_context.expect("MaxStackHeightCounterContext must be initialized"); + let thunk_cost = MaxStackHeightCounter::new_with_context(context) + .compute_for_raw_func(&signature, &thunk_body)?; + + callee_stack_cost = callee_stack_cost + .checked_add(thunk_cost) + .ok_or("overflow during callee_stack_cost calculation")?; + + // Update thunk body with new cost + for instruction in thunk_body + .iter_mut() + .filter(|i| **i == Instruction::I32Const(CALLEE_STACK_COST_PLACEHOLDER)) + { + *instruction = Instruction::I32Const(callee_stack_cost as i32); + } + + replacement_map + .insert(func_idx, Thunk { signature, body: Some(thunk_body), idx: None }); } } @@ -61,28 +95,11 @@ pub fn generate_thunks( let mut next_func_idx = module.functions_space() as u32; let mut mbuilder = builder::from_module(module); - for (func_idx, thunk) in replacement_map.iter_mut() { - let instrumented_call = instrument_call!( - *func_idx, - thunk.callee_stack_cost as i32, - ctx.stack_height_global_idx(), - ctx.stack_limit() - ); - // Thunk body consist of: - // - argument pushing - // - instrumented call - // - end - let mut thunk_body: Vec = - Vec::with_capacity(thunk.signature.params().len() + instrumented_call.len() + 1); - - for (arg_idx, _) in thunk.signature.params().iter().enumerate() { - thunk_body.push(elements::Instruction::GetLocal(arg_idx as u32)); - } - thunk_body.extend_from_slice(&instrumented_call); - thunk_body.push(elements::Instruction::End); - + for thunk in replacement_map.values_mut() { // TODO: Don't generate a signature, but find an existing one. + let thunk_body = thunk.body.take().expect("can't get thunk function body"); + mbuilder = mbuilder .function() // Signature of the thunk should match the original function signature. @@ -91,7 +108,7 @@ pub fn generate_thunks( .with_results(thunk.signature.results().to_vec()) .build() .body() - .with_instructions(elements::Instructions::new(thunk_body)) + .with_instructions(Instructions::new(thunk_body)) .build() .build(); @@ -133,3 +150,17 @@ pub fn generate_thunks( Ok(module) } + +fn thunk_function_indexes(module: &elements::Module) -> impl Iterator + '_ { + let exports = module.export_section().map(|es| es.entries()).unwrap_or(&[]); + let elem_segments = module.elements_section().map(|es| es.entries()).unwrap_or(&[]); + let start_func_idx = module.start_section(); + + let exported_func_indices = exports.iter().filter_map(|entry| match entry.internal() { + Internal::Function(function_idx) => Some(*function_idx), + _ => None, + }); + let table_func_indices = elem_segments.iter().flat_map(|segment| segment.members()).cloned(); + + exported_func_indices.chain(table_func_indices).chain(start_func_idx) +} diff --git a/tests/expectations/stack-height/empty_functions.wat b/tests/expectations/stack-height/empty_functions.wat index 8a13ffe5..293c5eab 100644 --- a/tests/expectations/stack-height/empty_functions.wat +++ b/tests/expectations/stack-height/empty_functions.wat @@ -36,7 +36,7 @@ ) (func (;2;) (type 0) global.get 0 - i32.const 2 + i32.const 4 i32.add global.set 0 global.get 0 @@ -47,7 +47,7 @@ end call 1 global.get 0 - i32.const 2 + i32.const 4 i32.sub global.set 0 ) diff --git a/tests/expectations/stack-height/global.wat b/tests/expectations/stack-height/global.wat index 4558b316..633346a4 100644 --- a/tests/expectations/stack-height/global.wat +++ b/tests/expectations/stack-height/global.wat @@ -18,7 +18,7 @@ local.get $tmp local.get $arg global.get 1 - i32.const 4 + i32.const 2 i32.add global.set 1 global.get 1 @@ -29,7 +29,7 @@ end call $i32.add global.get 1 - i32.const 4 + i32.const 2 i32.sub global.set 1 drop @@ -38,7 +38,7 @@ local.get 0 local.get 1 global.get 1 - i32.const 4 + i32.const 6 i32.add global.set 1 global.get 1 @@ -49,7 +49,7 @@ end call $i32.add global.get 1 - i32.const 4 + i32.const 6 i32.sub global.set 1 ) diff --git a/tests/expectations/stack-height/imports.wat b/tests/expectations/stack-height/imports.wat index 9b437190..b2c0e690 100644 --- a/tests/expectations/stack-height/imports.wat +++ b/tests/expectations/stack-height/imports.wat @@ -14,7 +14,7 @@ local.get 0 local.get 1 global.get 0 - i32.const 4 + i32.const 6 i32.add global.set 0 global.get 0 @@ -25,7 +25,7 @@ end call 2 global.get 0 - i32.const 4 + i32.const 6 i32.sub global.set 0 ) diff --git a/tests/expectations/stack-height/many_locals.wat b/tests/expectations/stack-height/many_locals.wat index a6801b37..1763b9cd 100644 --- a/tests/expectations/stack-height/many_locals.wat +++ b/tests/expectations/stack-height/many_locals.wat @@ -5,7 +5,7 @@ ) (func $main (;1;) (type 0) global.get 0 - i32.const 5 + i32.const 3 i32.add global.set 0 global.get 0 @@ -16,7 +16,7 @@ end call $one-group-many-locals global.get 0 - i32.const 5 + i32.const 3 i32.sub global.set 0 ) diff --git a/tests/expectations/stack-height/start.wat b/tests/expectations/stack-height/start.wat index 25988e3e..077fb3ea 100644 --- a/tests/expectations/stack-height/start.wat +++ b/tests/expectations/stack-height/start.wat @@ -24,24 +24,7 @@ i32.sub global.set 0 ) - (func (;4;) (type 1) - global.get 0 - i32.const 2 - i32.add - global.set 0 - global.get 0 - i32.const 1024 - i32.gt_u - if ;; label = @1 - unreachable - end - call 2 - global.get 0 - i32.const 2 - i32.sub - global.set 0 - ) (global (;0;) (mut i32) i32.const 0) - (export "call" (func 4)) + (export "call" (func 2)) (start 3) ) \ No newline at end of file diff --git a/tests/expectations/stack-height/table.wat b/tests/expectations/stack-height/table.wat index 12cf982a..2375e7a0 100644 --- a/tests/expectations/stack-height/table.wat +++ b/tests/expectations/stack-height/table.wat @@ -7,7 +7,7 @@ local.get 0 i32.const 0 global.get 0 - i32.const 4 + i32.const 2 i32.add global.set 0 global.get 0 @@ -18,7 +18,7 @@ end call $i32.add global.get 0 - i32.const 4 + i32.const 2 i32.sub global.set 0 drop @@ -31,7 +31,7 @@ (func (;3;) (type 1) (param i32) local.get 0 global.get 0 - i32.const 4 + i32.const 7 i32.add global.set 0 global.get 0 @@ -42,7 +42,7 @@ end call 1 global.get 0 - i32.const 4 + i32.const 7 i32.sub global.set 0 ) @@ -50,7 +50,7 @@ local.get 0 local.get 1 global.get 0 - i32.const 4 + i32.const 6 i32.add global.set 0 global.get 0 @@ -61,7 +61,7 @@ end call $i32.add global.get 0 - i32.const 4 + i32.const 6 i32.sub global.set 0 )