diff --git a/changelog/pmikolajczyk-nit-4364.md b/changelog/pmikolajczyk-nit-4364.md new file mode 100644 index 0000000000..aa12cdef72 --- /dev/null +++ b/changelog/pmikolajczyk-nit-4364.md @@ -0,0 +1,3 @@ +### Internal + - Move the server side of the validation communication protocol from `jit` to `validation` crate. + - Add client side implementation. Add tests. diff --git a/crates/jit/src/lib.rs b/crates/jit/src/lib.rs index 31401724b9..1f4976061f 100644 --- a/crates/jit/src/lib.rs +++ b/crates/jit/src/lib.rs @@ -18,7 +18,6 @@ mod caller_env; pub mod machine; mod prepare; pub mod program; -pub mod socket; pub mod stylus_backend; mod test; mod wasip1_stub; diff --git a/crates/jit/src/main.rs b/crates/jit/src/main.rs index 60c12c7d79..7294897011 100644 --- a/crates/jit/src/main.rs +++ b/crates/jit/src/main.rs @@ -4,12 +4,8 @@ use arbutil::Color; use clap::Parser; use eyre::Result; -use jit::{ - machine::Escape, - run, - socket::{report_error, report_success}, - Opts, -}; +use jit::{machine::Escape, run, Opts}; +use validation::transfer::{send_failure_response, send_successful_response}; use wasmer::FrameInfo; fn main() -> Result<()> { @@ -34,7 +30,7 @@ fn main() -> Result<()> { println!("{message}") } if let Some(mut socket) = result.socket { - report_error(&mut socket, message); + send_failure_response(&mut socket, &message)?; } if opts.validator.require_success { std::process::exit(1); @@ -47,7 +43,11 @@ fn main() -> Result<()> { ) } if let Some(mut socket) = result.socket { - report_success(&mut socket, &result.new_state, &result.memory_used); + send_successful_response( + &mut socket, + &result.new_state.into(), + result.memory_used.bytes().0 as u64, + )?; } } Ok(()) diff --git a/crates/jit/src/prepare.rs b/crates/jit/src/prepare.rs index 25ca882878..e116c4c484 100644 --- a/crates/jit/src/prepare.rs +++ b/crates/jit/src/prepare.rs @@ -3,26 +3,10 @@ use crate::machine::WasmEnv; use eyre::Ok; -use std::env; use std::fs::File; use std::io::BufReader; use std::path::Path; -use validation::ValidationInput; - -// local_target matches rawdb.LocalTarget() on the go side. -// While generating json_inputs file, one should make sure user_wasms map -// has entry for the system's arch that jit validation is being run on -pub fn local_target() -> String { - if env::consts::OS == "linux" { - match env::consts::ARCH { - "aarch64" => "arm64".to_string(), - "x86_64" => "amd64".to_string(), - _ => "host".to_string(), - } - } else { - "host".to_string() - } -} +use validation::{local_target, ValidationInput}; pub fn prepare_env_from_json(json_inputs: &Path, debug: bool) -> eyre::Result { let file = File::open(json_inputs)?; @@ -54,7 +38,7 @@ pub fn prepare_env_from_json(json_inputs: &Path, debug: bool) -> eyre::Result(reader: &mut BufReader) -> Result { - let mut buf = [0; 1]; - reader.read_exact(&mut buf).map(|_| u8::from_be_bytes(buf)) -} - -pub fn read_u32(reader: &mut BufReader) -> Result { - let mut buf = [0; 4]; - reader.read_exact(&mut buf).map(|_| u32::from_be_bytes(buf)) -} - -pub fn read_u64(reader: &mut BufReader) -> Result { - let mut buf = [0; 8]; - reader.read_exact(&mut buf).map(|_| u64::from_be_bytes(buf)) -} - -pub fn read_bytes32(reader: &mut BufReader) -> Result { - let mut buf = [0u8; 32]; - reader.read_exact(&mut buf).map(|_| buf.into()) -} - -pub fn read_bytes(reader: &mut BufReader) -> Result, io::Error> { - let size = read_u64(reader)?; - let mut buf = vec![0; size as usize]; - reader.read_exact(&mut buf)?; - Ok(buf) -} - -pub fn read_boxed_slice(reader: &mut BufReader) -> Result, io::Error> { - Ok(Vec::into_boxed_slice(read_bytes(reader)?)) -} - -pub fn write_u8(writer: &mut BufWriter, data: u8) -> Result<(), io::Error> { - let buf = [data; 1]; - writer.write_all(&buf) -} - -pub fn write_u64(writer: &mut BufWriter, data: u64) -> Result<(), io::Error> { - let buf = data.to_be_bytes(); - writer.write_all(&buf) -} - -pub fn write_bytes32(writer: &mut BufWriter, data: &Bytes32) -> Result<(), io::Error> { - writer.write_all(data.as_slice()) -} - -pub fn write_bytes(writer: &mut BufWriter, data: &[u8]) -> Result<(), io::Error> { - write_u64(writer, data.len() as u64)?; - writer.write_all(data) -} - -macro_rules! check { - ($expr:expr) => {{ - if let Err(comms_error) = $expr { - eprintln!("Failed to send results to Go: {comms_error}"); - panic!("Communication failure"); - } - }}; -} - -pub fn report_success( - writer: &mut BufWriter, - new_state: &GlobalState, - memory_used: &Pages, -) { - check!(write_u8(writer, SUCCESS)); - check!(write_u64(writer, new_state.inbox_position)); - check!(write_u64(writer, new_state.position_within_message)); - check!(write_bytes32(writer, &new_state.last_block_hash)); - check!(write_bytes32(writer, &new_state.last_send_root)); - check!(write_u64(writer, memory_used.bytes().0 as u64)); - check!(writer.flush()); -} - -pub fn report_error(writer: &mut BufWriter, error: String) { - check!(write_u8(writer, FAILURE)); - check!(write_bytes(writer, &error.into_bytes())); - check!(writer.flush()); -} diff --git a/crates/jit/src/wavmio.rs b/crates/jit/src/wavmio.rs index 5c867c777c..d40a6d15f7 100644 --- a/crates/jit/src/wavmio.rs +++ b/crates/jit/src/wavmio.rs @@ -4,7 +4,6 @@ use crate::{ caller_env::JitEnv, machine::{Escape, MaybeEscape, WasmEnv, WasmEnvMut}, - socket, }; use arbutil::{Color, PreimageType}; use caller_env::{GuestPtr, MemAccess}; @@ -14,6 +13,8 @@ use std::{ net::TcpStream, time::Instant, }; +use validation::local_target; +use validation::transfer::receive_validation_input; /// Reads 32-bytes of global state. pub fn get_global_state_bytes32(mut env: WasmEnvMut, idx: u32, out_ptr: GuestPtr) -> MaybeEscape { @@ -281,49 +282,27 @@ fn ready_hostio(env: &mut WasmEnv) -> MaybeEscape { socket.set_nodelay(true)?; let mut reader = BufReader::new(socket.try_clone()?); - let stream = &mut reader; + let input = receive_validation_input(&mut reader)?; - let inbox_position = socket::read_u64(stream)?; - let position_within_message = socket::read_u64(stream)?; - let last_block_hash = socket::read_bytes32(stream)?; - let last_send_root = socket::read_bytes32(stream)?; + env.small_globals = [input.start_state.batch, input.start_state.pos_in_batch]; + env.large_globals = [input.start_state.block_hash, input.start_state.send_root]; - env.small_globals = [inbox_position, position_within_message]; - env.large_globals = [last_block_hash, last_send_root]; - - while socket::read_u8(stream)? == socket::ANOTHER { - let position = socket::read_u64(stream)?; - let message = socket::read_bytes(stream)?; - env.sequencer_messages.insert(position, message); + for batch in input.batch_info { + env.sequencer_messages.insert(batch.number, batch.data); } - while socket::read_u8(stream)? == socket::ANOTHER { - let position = socket::read_u64(stream)?; - let message = socket::read_bytes(stream)?; - env.delayed_messages.insert(position, message); + if input.has_delayed_msg { + env.delayed_messages + .insert(input.delayed_msg_nr, input.delayed_msg); } - - let preimage_types = socket::read_u32(stream)?; - for _ in 0..preimage_types { - let preimage_ty = PreimageType::try_from(socket::read_u8(stream)?) - .map_err(|e| Escape::Failure(e.to_string()))?; - let map = env.preimages.entry(preimage_ty).or_default(); - let preimage_count = socket::read_u32(stream)?; - for _ in 0..preimage_count { - let hash = socket::read_bytes32(stream)?; - let preimage = socket::read_bytes(stream)?; - map.insert(hash, preimage); + for (preimage_type, preimages) in input.preimages { + let preimage_map = env.preimages.entry(preimage_type).or_default(); + for (hash, preimage) in preimages { + preimage_map.insert(hash, preimage); } } - - let programs_count = socket::read_u32(stream)?; - for _ in 0..programs_count { - let module_hash = socket::read_bytes32(stream)?; - let module_asm = socket::read_boxed_slice(stream)?; - env.module_asms.insert(module_hash, module_asm.into()); - } - - if socket::read_u8(stream)? != socket::READY { - return Escape::hostio("failed to parse global state"); + for (module_hash, module_asm) in &input.user_wasms[local_target()] { + env.module_asms + .insert(*module_hash, module_asm.as_vec().into()); } let writer = BufWriter::new(socket); diff --git a/crates/validation/src/lib.rs b/crates/validation/src/lib.rs index 8df5707ed6..dec2d185e6 100644 --- a/crates/validation/src/lib.rs +++ b/crates/validation/src/lib.rs @@ -7,8 +7,27 @@ use std::{ io::{self, BufRead}, }; +pub mod transfer; + +pub type PreimageMap = HashMap>>; + +pub const TARGET_ARM_64: &str = "arm64"; +pub const TARGET_AMD_64: &str = "amd64"; +pub const TARGET_HOST: &str = "host"; + +/// Counterpart to Go `rawdb.LocalTarget()`. +pub fn local_target() -> &'static str { + if cfg!(all(target_os = "linux", target_arch = "aarch64")) { + TARGET_ARM_64 + } else if cfg!(all(target_os = "linux", target_arch = "x86_64")) { + TARGET_AMD_64 + } else { + TARGET_HOST + } +} + /// Counterpart to Go `validator.GoGlobalState`. -#[derive(Clone, Debug, Serialize, Deserialize, Default)] +#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "PascalCase")] pub struct GoGlobalState { #[serde(with = "As::")] @@ -20,7 +39,7 @@ pub struct GoGlobalState { } /// Counterpart to Go `validator.server_api.BatchInfoJson`. -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Default, PartialEq, Eq, Deserialize)] #[serde(rename_all = "PascalCase")] pub struct BatchInfo { pub number: u64, @@ -34,7 +53,7 @@ pub struct BatchInfo { /// /// Note: The wrapped `Vec` is already `Base64` decoded before /// `from(Vec)` is called by `serde`. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct UserWasm(Vec); impl UserWasm { @@ -60,7 +79,7 @@ impl TryFrom> for UserWasm { } /// Counterpart to Go `validator.server_api.InputJSON`. -#[derive(Clone, Debug, Deserialize)] +#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize)] #[serde(rename_all = "PascalCase")] pub struct ValidationInput { pub id: u64, @@ -70,7 +89,7 @@ pub struct ValidationInput { rename = "PreimagesB64", with = "As::, HashMap>>" )] - pub preimages: HashMap>>, + pub preimages: PreimageMap, pub batch_info: Vec, #[serde(rename = "DelayedMsgB64", with = "As::")] pub delayed_msg: Vec, @@ -86,4 +105,11 @@ impl ValidationInput { pub fn from_reader(mut reader: R) -> io::Result { Ok(serde_json::from_reader(&mut reader)?) } + + pub fn delayed_msg(&self) -> Option { + self.has_delayed_msg.then(|| BatchInfo { + number: self.delayed_msg_nr, + data: self.delayed_msg.clone(), + }) + } } diff --git a/crates/validation/src/transfer/markers.rs b/crates/validation/src/transfer/markers.rs new file mode 100644 index 0000000000..781fbbd08c --- /dev/null +++ b/crates/validation/src/transfer/markers.rs @@ -0,0 +1,5 @@ +pub const SUCCESS: u8 = 0x0; +pub const FAILURE: u8 = 0x1; +// pub const PREIMAGE: u8 = 0x2; // legacy, not used +pub const ANOTHER: u8 = 0x3; +pub const READY: u8 = 0x4; diff --git a/crates/validation/src/transfer/mod.rs b/crates/validation/src/transfer/mod.rs new file mode 100644 index 0000000000..d2b71d830c --- /dev/null +++ b/crates/validation/src/transfer/mod.rs @@ -0,0 +1,13 @@ +use std::io; + +mod markers; +mod primitives; +mod receiver; +mod sender; +#[cfg(test)] +mod tests; + +pub use receiver::*; +pub use sender::*; + +pub type IOResult = Result; diff --git a/crates/validation/src/transfer/primitives.rs b/crates/validation/src/transfer/primitives.rs new file mode 100644 index 0000000000..74fc033076 --- /dev/null +++ b/crates/validation/src/transfer/primitives.rs @@ -0,0 +1,54 @@ +use crate::transfer::IOResult; +use arbutil::Bytes32; +use std::io::{Read, Write}; + +pub fn read_u8(reader: &mut impl Read) -> IOResult { + let mut buf = [0; 1]; + reader.read_exact(&mut buf).map(|_| u8::from_be_bytes(buf)) +} + +pub fn write_u8(writer: &mut impl Write, data: u8) -> IOResult<()> { + let buf = [data; 1]; + writer.write_all(&buf) +} + +pub fn read_u32(reader: &mut impl Read) -> IOResult { + let mut buf = [0; 4]; + reader.read_exact(&mut buf).map(|_| u32::from_be_bytes(buf)) +} + +pub fn write_u32(writer: &mut impl Write, data: u32) -> IOResult<()> { + let buf = data.to_be_bytes(); + writer.write_all(&buf) +} + +pub fn read_u64(reader: &mut impl Read) -> IOResult { + let mut buf = [0; 8]; + reader.read_exact(&mut buf).map(|_| u64::from_be_bytes(buf)) +} + +pub fn write_u64(writer: &mut impl Write, data: u64) -> IOResult<()> { + let buf = data.to_be_bytes(); + writer.write_all(&buf) +} + +pub fn read_bytes32(reader: &mut impl Read) -> IOResult { + let mut buf = [0u8; 32]; + reader.read_exact(&mut buf).map(|_| buf.into()) +} + +pub fn write_bytes32(writer: &mut impl Write, data: &Bytes32) -> IOResult<()> { + writer.write_all(data.as_slice()) +} + +pub fn read_bytes(reader: &mut impl Read) -> IOResult> { + let size = read_u64(reader)?; + let mut buf = vec![0; size as usize]; + reader.read_exact(&mut buf)?; + Ok(buf) +} + +pub fn write_bytes(writer: &mut impl Write, data: &[u8]) -> IOResult<()> { + write_u64(writer, data.len() as u64)?; + writer.write_all(data) +} diff --git a/crates/validation/src/transfer/receiver.rs b/crates/validation/src/transfer/receiver.rs new file mode 100644 index 0000000000..313ab2b490 --- /dev/null +++ b/crates/validation/src/transfer/receiver.rs @@ -0,0 +1,116 @@ +use crate::transfer::primitives::{read_bytes, read_bytes32, read_u32, read_u64, read_u8}; +use crate::transfer::{markers, IOResult}; +use crate::{local_target, BatchInfo, GoGlobalState, PreimageMap, UserWasm, ValidationInput}; +use arbutil::{Bytes32, PreimageType}; +use io::Error; +use std::collections::HashMap; +use std::io; +use std::io::ErrorKind::InvalidData; +use std::io::Read; + +pub fn receive_validation_input(reader: &mut impl Read) -> IOResult { + let start_state = receive_global_state(reader)?; + let inbox = receive_batches(reader)?; + let delayed_message = receive_delayed_message(reader)?.unwrap_or_default(); + let preimages = receive_preimages(reader)?; + let user_wasms = receive_user_wasms(reader)?; + ensure_readiness(reader)?; + + Ok(ValidationInput { + has_delayed_msg: !delayed_message.data.is_empty(), + delayed_msg_nr: delayed_message.number, + preimages, + batch_info: inbox, + delayed_msg: delayed_message.data, + start_state, + user_wasms: HashMap::from([(local_target().to_string(), user_wasms)]), + ..Default::default() + }) +} + +pub fn receive_response(reader: &mut impl Read) -> IOResult> { + match read_u8(reader)? { + markers::SUCCESS => { + let new_state = receive_global_state(reader)?; + let memory_used = read_u64(reader)?; + Ok(Ok((new_state, memory_used))) + } + markers::FAILURE => { + let error_bytes = read_bytes(reader)?; + let error_message = String::from_utf8_lossy(&error_bytes).to_string(); + Ok(Err(error_message)) + } + other => Ok(Err(format!("unexpected response byte: {other}"))), + } +} + +fn receive_global_state(reader: &mut impl Read) -> IOResult { + let inbox_position = read_u64(reader)?; + let position_within_message = read_u64(reader)?; + let last_block_hash = read_bytes32(reader)?; + let last_send_root = read_bytes32(reader)?; + Ok(GoGlobalState { + block_hash: last_block_hash, + send_root: last_send_root, + batch: inbox_position, + pos_in_batch: position_within_message, + }) +} + +fn receive_batches(reader: &mut impl Read) -> IOResult> { + let mut batches = vec![]; + while read_u8(reader)? == markers::ANOTHER { + let number = read_u64(reader)?; + let data = read_bytes(reader)?; + batches.push(BatchInfo { number, data }); + } + Ok(batches) +} + +fn receive_delayed_message(reader: &mut impl Read) -> IOResult> { + match &receive_batches(reader)?[..] { + [] => Ok(None), + [batch_info] => Ok(Some(batch_info.clone())), + _ => Err(Error::new(InvalidData, "multiple delayed batches")), + } +} + +fn receive_preimages(reader: &mut impl Read) -> IOResult { + let preimage_types = read_u32(reader)?; + let mut preimages = PreimageMap::with_capacity(preimage_types as usize); + for _ in 0..preimage_types { + let preimage_ty = PreimageType::try_from(read_u8(reader)?) + .map_err(|e| Error::new(InvalidData, e.to_string()))?; + let map = preimages.entry(preimage_ty).or_default(); + let preimage_count = read_u32(reader)?; + for _ in 0..preimage_count { + let hash = read_bytes32(reader)?; + let preimage = read_bytes(reader)?; + map.insert(hash, preimage); + } + } + Ok(preimages) +} + +fn receive_user_wasms(reader: &mut impl Read) -> IOResult> { + let programs_count = read_u32(reader)?; + let mut user_wasms = HashMap::with_capacity(programs_count as usize); + for _ in 0..programs_count { + let module_hash = read_bytes32(reader)?; + let module_asm = read_bytes(reader)?; + user_wasms.insert(module_hash, UserWasm(module_asm)); + } + Ok(user_wasms) +} + +fn ensure_readiness(reader: &mut impl Read) -> IOResult<()> { + let byte = read_u8(reader)?; + if byte == markers::READY { + Ok(()) + } else { + Err(Error::new( + InvalidData, + format!("expected READY byte, got {byte}"), + )) + } +} diff --git a/crates/validation/src/transfer/sender.rs b/crates/validation/src/transfer/sender.rs new file mode 100644 index 0000000000..c4d49ef398 --- /dev/null +++ b/crates/validation/src/transfer/sender.rs @@ -0,0 +1,96 @@ +use crate::transfer::primitives::{write_bytes, write_bytes32, write_u32, write_u64, write_u8}; +use crate::transfer::{markers, IOResult}; +use crate::{local_target, BatchInfo, GoGlobalState, PreimageMap, UserWasm, ValidationInput}; +use arbutil::Bytes32; +use std::collections::HashMap; +use std::io::ErrorKind::InvalidData; +use std::io::{Error, Write}; + +pub fn send_validation_input(writer: &mut impl Write, input: &ValidationInput) -> IOResult<()> { + send_global_state(writer, &input.start_state)?; + send_batches(writer, &input.batch_info)?; + if let Some(batch) = input.delayed_msg() { + send_batches(writer, &[batch])?; + } + send_preimages(writer, &input.preimages)?; + send_user_wasms(writer, &input.user_wasms)?; + finish_sending(writer) +} + +pub fn send_successful_response( + writer: &mut impl Write, + new_state: &GoGlobalState, + memory_used: u64, +) -> IOResult<()> { + write_u8(writer, markers::SUCCESS)?; + send_global_state(writer, new_state)?; + write_u64(writer, memory_used) +} + +pub fn send_failure_response(writer: &mut impl Write, error_message: &str) -> IOResult<()> { + write_u8(writer, markers::FAILURE)?; + write_bytes(writer, error_message.as_bytes()) +} + +fn send_global_state(writer: &mut impl Write, start_state: &GoGlobalState) -> IOResult<()> { + write_u64(writer, start_state.batch)?; + write_u64(writer, start_state.pos_in_batch)?; + write_bytes32(writer, &start_state.block_hash)?; + write_bytes32(writer, &start_state.send_root) +} + +fn send_batches(writer: &mut impl Write, batch_info: &[BatchInfo]) -> IOResult<()> { + for batch in batch_info { + write_u8(writer, markers::ANOTHER)?; + write_u64(writer, batch.number)?; + write_bytes(writer, &batch.data)?; + } + write_u8(writer, markers::SUCCESS) +} + +fn send_preimages(writer: &mut impl Write, preimages: &PreimageMap) -> IOResult<()> { + write_u32(writer, preimages.len() as u32)?; + for (preimage_type, preimage_map) in preimages { + write_u8(writer, *preimage_type as u8)?; + write_u32(writer, preimage_map.len() as u32)?; + for (hash, preimage) in preimage_map { + write_bytes32(writer, hash)?; + write_bytes(writer, preimage)?; + } + } + Ok(()) +} + +fn send_user_wasms( + writer: &mut impl Write, + user_wasms: &HashMap>, +) -> IOResult<()> { + let local_target = local_target(); + let local_target_user_wasms = user_wasms.get(local_target); + + if local_target_user_wasms.is_none_or(|m| m.is_empty()) { + for (arch, wasms) in user_wasms { + if !wasms.is_empty() { + return Err(Error::new( + InvalidData, + format!("bad stylus arch. got {arch}, expected {local_target}"), + )); + } + } + } + + let Some(local_target_user_wasms) = local_target_user_wasms else { + return Ok(()); + }; + + write_u32(writer, local_target_user_wasms.len() as u32)?; + for (hash, wasm) in local_target_user_wasms { + write_bytes32(writer, hash)?; + write_bytes(writer, wasm.as_ref())?; + } + Ok(()) +} + +fn finish_sending(writer: &mut impl Write) -> IOResult<()> { + write_u8(writer, markers::READY) +} diff --git a/crates/validation/src/transfer/tests.rs b/crates/validation/src/transfer/tests.rs new file mode 100644 index 0000000000..1059ba7c30 --- /dev/null +++ b/crates/validation/src/transfer/tests.rs @@ -0,0 +1,119 @@ +use crate::transfer::{ + receive_response, receive_validation_input, send_failure_response, send_successful_response, + send_validation_input, +}; +use crate::{local_target, BatchInfo, GoGlobalState, UserWasm, ValidationInput}; +use arbutil::{Bytes32, PreimageType}; +use std::collections::HashMap; +use std::io::pipe; + +#[test] +fn transfer_successful_response() -> Result<(), Box> { + let new_state = GoGlobalState { + block_hash: Bytes32::from([1u8; 32]), + send_root: Bytes32::from([2u8; 32]), + batch: 42, + pos_in_batch: 7, + }; + let memory_used = 123456u64; + + let (mut reader, mut writer) = pipe()?; + + send_successful_response(&mut writer, &new_state, memory_used)?; + let (received_state, received_memory) = receive_response(&mut reader)??; + + assert_eq!(received_state, new_state); + assert_eq!(received_memory, memory_used); + Ok(()) +} + +#[test] +fn transfer_failure_response() -> Result<(), Box> { + let error_message = "Validation failed due to some error."; + + let (mut reader, mut writer) = pipe()?; + + send_failure_response(&mut writer, error_message)?; + let result = receive_response(&mut reader)?; + + match result { + Err(err_msg) => assert_eq!(err_msg, error_message), + Ok(_) => panic!("Expected failure response, but got success."), + } + Ok(()) +} + +#[test] +fn transfer_input() -> Result<(), Box> { + let input = ValidationInput { + start_state: Default::default(), + + batch_info: vec![ + BatchInfo { + number: 10, + data: vec![1, 2, 3], + }, + BatchInfo { + number: 11, + data: vec![4, 5, 6], + }, + BatchInfo { + number: 12, + data: vec![7, 8], + }, + ], + + has_delayed_msg: true, + delayed_msg_nr: 1, + delayed_msg: vec![0xAA, 0xBB, 0xCC], + + preimages: HashMap::from([ + ( + PreimageType::Keccak256, + HashMap::from([ + (Bytes32::from([0u8; 32]), vec![0xDE, 0xAD, 0xBE, 0xEF]), + (Bytes32::from([1u8; 32]), vec![0xBA, 0xAD, 0xF0, 0x0D]), + ]), + ), + ( + PreimageType::DACertificate, + HashMap::from([(Bytes32::from([2u8; 32]), vec![0xFE, 0xED, 0xFA, 0xCE])]), + ), + ]), + user_wasms: HashMap::from([( + local_target().to_string(), + HashMap::from([ + (Bytes32::from([3u8; 32]), UserWasm(vec![20, 21, 22])), + (Bytes32::from([4u8; 32]), UserWasm(vec![30, 31, 32])), + ]), + )]), + + ..Default::default() + }; + + let (mut reader, mut writer) = pipe()?; + + send_validation_input(&mut writer, &input)?; + let received_input = receive_validation_input(&mut reader)?; + + assert_eq!(received_input, input); + + Ok(()) +} + +#[test] +fn local_stylus_target_must_be_present_if_some_target_is_present() { + let input = ValidationInput { + user_wasms: HashMap::from([( + "some-other-target".to_string(), + HashMap::from([(Bytes32::from([0u8; 32]), UserWasm(vec![1, 2, 3]))]), + )]), + ..Default::default() + }; + + let (_, mut writer) = pipe().unwrap(); + + let result = send_validation_input(&mut writer, &input); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("bad stylus arch")); +} diff --git a/crates/validator/src/engine/machine.rs b/crates/validator/src/engine/machine.rs index a517fe1f78..e7f7597621 100644 --- a/crates/validator/src/engine/machine.rs +++ b/crates/validator/src/engine/machine.rs @@ -19,64 +19,24 @@ //! This TCP stream is then used for data transfer of the `ValidationRequest` and //! the resulting `GlobalState`. +use crate::engine::config::JitMachineConfig; use anyhow::{anyhow, Context, Result}; use arbutil::Bytes32; +use std::net::TcpListener; use std::{ - collections::HashMap, env::{self}, path::{Path, PathBuf}, process::Stdio, }; +use tokio::io::AsyncWriteExt; use tokio::{ - io::{AsyncRead, AsyncReadExt, AsyncWriteExt}, - net::{TcpListener, TcpStream}, process::{Child, ChildStdin, Command}, sync::Mutex, }; use tracing::{error, warn}; +use validation::transfer::{receive_response, send_validation_input}; use validation::{GoGlobalState, ValidationInput}; -use crate::{engine::config::JitMachineConfig, spawner_endpoints::local_target}; - -const SUCCESS_BYTE: u8 = 0x0; -const FAILURE_BYTE: u8 = 0x1; -const ANOTHER_BYTE: u8 = 0x3; -const READY_BYTE: u8 = 0x4; - -async fn write_exact(conn: &mut TcpStream, data: &[u8]) -> Result<()> { - conn.write_all(data).await.map_err(|e| anyhow!(e)) -} - -async fn write_u8(conn: &mut TcpStream, data: u8) -> Result<()> { - write_exact(conn, &[data]).await -} - -async fn write_u32(conn: &mut TcpStream, data: u32) -> Result<()> { - write_exact(conn, &data.to_be_bytes()).await -} - -async fn write_u64(conn: &mut TcpStream, data: u64) -> Result<()> { - write_exact(conn, &data.to_be_bytes()).await -} - -async fn write_bytes(conn: &mut TcpStream, data: &[u8]) -> Result<()> { - write_u64(conn, data.len() as u64).await?; - write_exact(conn, data).await -} - -async fn read_bytes32(reader: &mut R) -> Result<[u8; 32]> { - let mut buf = [0u8; 32]; - reader.read_exact(&mut buf).await?; - Ok(buf) -} - -async fn read_bytes_with_len(reader: &mut R) -> Result> { - let len = reader.read_u64().await?; - let mut buf = vec![0u8; len as usize]; - reader.read_exact(&mut buf).await?; - Ok(buf) -} - #[derive(Debug)] pub struct JitMachine { /// Handler to jit binary stdin. Instead of using Mutex<> for the entire @@ -150,11 +110,7 @@ impl JitMachine { pub async fn feed_machine(&self, request: &ValidationInput) -> Result { // 1. Create new TCP connection // Binding with a port number of 0 will request that the OS assigns a port to this listener. - let listener = TcpListener::bind("127.0.0.1:0") - .await - .context("failed to create TCP listener")?; - - let mut state = GoGlobalState::default(); + let listener = TcpListener::bind("127.0.0.1:0").context("failed to create TCP listener")?; let addr = listener.local_addr().context("failed to get local addr")?; @@ -177,101 +133,26 @@ impl JitMachine { // 4. Wait for the child to call us back let (mut conn, _) = listener .accept() - .await .context("failed to open listener connection")?; - // 5. Send Global State - // TODO: add timeout for reads and writes - write_u64(&mut conn, request.start_state.batch).await?; - write_u64(&mut conn, request.start_state.pos_in_batch).await?; - write_exact(&mut conn, &request.start_state.block_hash.0).await?; - write_exact(&mut conn, &request.start_state.send_root.0).await?; - - // 6. Send batch info - for batch in request.batch_info.iter() { - write_u8(&mut conn, ANOTHER_BYTE).await?; - write_u64(&mut conn, batch.number).await?; - write_bytes(&mut conn, &batch.data).await?; - } - write_u8(&mut conn, SUCCESS_BYTE).await?; - - // 7. Send Delayed Inbox - if request.has_delayed_msg { - write_u8(&mut conn, ANOTHER_BYTE).await?; - write_u64(&mut conn, request.delayed_msg_nr).await?; - write_bytes(&mut conn, &request.delayed_msg).await?; - } - write_u8(&mut conn, SUCCESS_BYTE).await?; - - // 8. Send Known Preimages - write_u32(&mut conn, request.preimages.len() as u32).await?; - - for (ty, preimages) in request.preimages.iter() { - write_u8(&mut conn, *ty as u8).await?; - write_u32(&mut conn, preimages.len() as u32).await?; - for (hash, preimage) in preimages { - write_exact(&mut conn, &hash.0).await?; - write_bytes(&mut conn, preimage).await?; - } - } - - // 9. Send User Wasms - let local_target = local_target(); - let local_user_wasm = request.user_wasms.get(local_target); - - // if there are user wasms, but only for wrong architecture - error - if local_user_wasm.is_none_or(|m| m.is_empty()) { - for (arch, wasms) in &request.user_wasms { - if !wasms.is_empty() { - return Err(anyhow!( - "bad stylus arch. got {arch}, expected {local_target}", - )); - } - } - } - - let empty_map = HashMap::new(); - let local_user_wasm = local_user_wasm.unwrap_or(&empty_map); - write_u32(&mut conn, local_user_wasm.len() as u32).await?; - for (module_hash, program) in local_user_wasm { - write_exact(&mut conn, &module_hash.0).await?; - write_bytes(&mut conn, &program.as_vec()).await?; - } - - // 10. Signal that we are done sending global state - write_u8(&mut conn, READY_BYTE).await?; + // 5. Send data + send_validation_input(&mut conn, request)?; - // 11. Read Response and return new state - let mut kind_buf = [0u8; 1]; - conn.read_exact(&mut kind_buf).await?; - - match kind_buf[0] { - FAILURE_BYTE => { - let msg_bytes = read_bytes_with_len(&mut conn).await?; - let msg = String::from_utf8_lossy(&msg_bytes); - error!("Jit Machine Failure message: {msg}"); - Err(anyhow!("Jit Machine Failure: {msg}")) - } - SUCCESS_BYTE => { - // We write the values to socket in BigEndian so we can use - // read_u64() directly from AsyncReadExt which handles - // BigEndian by default - state.batch = conn.read_u64().await?; - state.pos_in_batch = conn.read_u64().await?; - state.block_hash.0 = read_bytes32(&mut conn).await?; - state.send_root.0 = read_bytes32(&mut conn).await?; - - let memory_used = conn.read_u64().await?; + // 6. Read Response and return new state + match receive_response(&mut conn)? { + Ok((new_state, memory_used)) => { if memory_used > self.wasm_memory_usage_limit { warn!( "WARN: memory used {} exceeds limit {}", memory_used, self.wasm_memory_usage_limit ); } - - Ok(state) + Ok(new_state) + } + Err(err) => { + error!("Jit Machine Failure message: {err}"); + Err(anyhow!("Jit Machine Failure: {err}")) } - _ => Err(anyhow!("inter-process communication failure: unknown byte")), } }