diff --git a/hyperactor/src/data.rs b/hyperactor/src/data.rs index a5342ccbc..c5c4df6ff 100644 --- a/hyperactor/src/data.rs +++ b/hyperactor/src/data.rs @@ -132,6 +132,7 @@ use std::fmt; use std::io::Cursor; use std::sync::LazyLock; +use enum_as_inner::EnumAsInner; pub use intern_typename; use serde::Deserialize; use serde::Serialize; @@ -291,10 +292,46 @@ macro_rules! register_type { } /// The encoding used for a serialized value. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -enum SerializedEncoding { - Bincode, - Json, +#[derive(Clone, Serialize, Deserialize, PartialEq, EnumAsInner)] +enum Encoded { + Bincode(serde_bytes::ByteBuf), + Json(serde_bytes::ByteBuf), + // todo: multipart +} + +impl Encoded { + /// The length of the underlying serialized message + pub fn len(&self) -> usize { + match &self { + Encoded::Bincode(data) => data.len(), + Encoded::Json(data) => data.len(), + } + } + + /// Is the message empty. This should always return false. + pub fn is_empty(&self) -> bool { + match &self { + Encoded::Bincode(data) => data.is_empty(), + Encoded::Json(data) => data.is_empty(), + } + } + + /// Computes the 32bit crc of the encoded data + pub fn crc(&self) -> u32 { + match &self { + Encoded::Bincode(data) => crc32fast::hash(data), + Encoded::Json(data) => crc32fast::hash(data), + } + } +} + +impl std::fmt::Debug for Encoded { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Encoded::Bincode(data) => write!(f, "Encoded::Bincode({})", HexFmt(data.as_slice())), + Encoded::Json(data) => write!(f, "Encoded::Json({})", HexFmt(data.as_slice())), + } + } } /// Represents a serialized value, wrapping the underlying serialization @@ -303,24 +340,15 @@ enum SerializedEncoding { /// /// Currently, Serialized passes through to bincode, but in the future we may include /// content-encoding information to allow for other codecs as well. -#[derive(Clone, Serialize, Deserialize, PartialEq)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] pub struct Serialized { - encoding: SerializedEncoding, - - /// The encoded data for the serialized value. - #[serde(with = "serde_bytes")] - data: Vec, + /// The encoded data + encoded: Encoded, /// The typehash of the serialized value, if available. This is used to provide /// typed introspection of the value. typehash: Option, } -impl std::fmt::Debug for Serialized { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Serialized({})", HexFmt(self.data.as_slice()),) - } -} - impl std::fmt::Display for Serialized { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.dump() { @@ -331,7 +359,7 @@ impl std::fmt::Display for Serialized { let basename = typename.split("::").last().unwrap_or(typename); write!(f, "{}{}", basename, JsonFmt(&value)) } - Err(_) => write!(f, "{}", HexFmt(self.data.as_slice())), + Err(_) => write!(f, "{:?}", self.encoded), } } } @@ -340,8 +368,7 @@ impl Serialized { /// Construct a new serialized value by serializing the provided T-typed value. pub fn serialize(value: &T) -> Result { Ok(Self { - encoding: SerializedEncoding::Bincode, - data: bincode::serialize(value)?, + encoded: Encoded::Bincode(bincode::serialize(value)?.into()), typehash: Some(T::typehash()), }) } @@ -349,29 +376,24 @@ impl Serialized { /// Construct a new anonymous (unnamed) serialized value by serializing the provided T-typed value. pub fn serialize_anon(value: &T) -> Result { Ok(Self { - encoding: SerializedEncoding::Bincode, - data: bincode::serialize(value)?, + encoded: Encoded::Bincode(bincode::serialize(value)?.into()), typehash: None, }) } /// Deserialize a value to the provided type T. pub fn deserialized(&self) -> Result { - match self.encoding { - SerializedEncoding::Bincode => { - bincode::deserialize(&self.data).map_err(anyhow::Error::from) - } - SerializedEncoding::Json => { - serde_json::from_slice(&self.data).map_err(anyhow::Error::from) - } + match &self.encoded { + Encoded::Bincode(data) => bincode::deserialize(data).map_err(anyhow::Error::from), + Encoded::Json(data) => serde_json::from_slice(data).map_err(anyhow::Error::from), } } /// Transcode the serialized value to JSON. This operation will succeed if the type hash /// is embedded in the value, and the corresponding type is available in this binary. pub fn transcode_to_json(self) -> Result { - match self.encoding { - SerializedEncoding::Bincode => { + match self.encoded { + Encoded::Bincode(_) => { let json_value = match self.dump() { Ok(json_value) => json_value, Err(_) => return Err(self), @@ -381,20 +403,19 @@ impl Serialized { Err(_) => return Err(self), }; Ok(Self { - encoding: SerializedEncoding::Json, - data: json_data, + encoded: Encoded::Json(json_data.into()), typehash: self.typehash, }) } - SerializedEncoding::Json => Ok(self), + Encoded::Json(_) => Ok(self), } } /// Dump the Serialized message into a JSON value. This will succeed if: 1) the typehash is embedded /// in the serialized value; 2) the named type is linked into the binary. pub fn dump(&self) -> Result { - match self.encoding { - SerializedEncoding::Bincode => { + match &self.encoded { + Encoded::Bincode(_) => { let Some(typehash) = self.typehash() else { anyhow::bail!("serialized value does not contain a typehash"); }; @@ -403,9 +424,7 @@ impl Serialized { }; typeinfo.dump(self.clone()) } - SerializedEncoding::Json => { - serde_json::from_slice(&self.data).map_err(anyhow::Error::from) - } + Encoded::Json(data) => serde_json::from_slice(data).map_err(anyhow::Error::from), } } @@ -425,11 +444,10 @@ impl Serialized { // TODO: we should support this by formalizing the notion of a 'prefix' // serialization, and generalize it to other codecs as well. pub fn prefix(&self) -> Result { - anyhow::ensure!( - self.encoding == SerializedEncoding::Bincode, - "only bincode supports prefix emplacement" - ); - bincode::deserialize(&self.data).map_err(anyhow::Error::from) + match &self.encoded { + Encoded::Bincode(data) => bincode::deserialize(data).map_err(anyhow::Error::from), + _ => anyhow::bail!("only bincode supports prefix emplacement"), + } } /// Emplace a new prefix to this value. This is currently only supported @@ -438,38 +456,39 @@ impl Serialized { &mut self, prefix: T, ) -> Result<(), anyhow::Error> { - anyhow::ensure!( - self.encoding == SerializedEncoding::Bincode, - "only bincode supports prefix emplacement" - ); + let data = match &self.encoded { + Encoded::Bincode(data) => data, + _ => anyhow::bail!("only bincode supports prefix emplacement"), + }; // This is a bit ugly, but: we first deserialize out the old prefix, // then serialize the new prefix, then splice the two together. // This is safe because we know that the prefix is the first thing // in the serialized value, and that the serialization format is stable. - let mut cursor = Cursor::new(self.data.clone()); + let mut cursor = Cursor::new(data.clone()); let _prefix: T = bincode::deserialize_from(&mut cursor).unwrap(); let position = cursor.position() as usize; let suffix = &cursor.into_inner()[position..]; - self.data = bincode::serialize(&prefix)?; - self.data.extend_from_slice(suffix); + let mut data = bincode::serialize(&prefix)?; + data.extend_from_slice(suffix); + self.encoded = Encoded::Bincode(data.into()); Ok(()) } /// The length of the underlying serialized message pub fn len(&self) -> usize { - self.data.len() + self.encoded.len() } /// Is the message empty. This should always return false. pub fn is_empty(&self) -> bool { - self.len() == 0 + self.encoded.is_empty() } /// Returns the 32bit crc of the serialized data pub fn crc(&self) -> u32 { - crc32fast::hash(&self.data) + self.encoded.crc() } } @@ -615,10 +634,11 @@ mod tests { let serialized = Serialized::serialize(&data).unwrap(); let serialized_json = serialized.clone().transcode_to_json().unwrap(); - assert_eq!(serialized.encoding, SerializedEncoding::Bincode); - assert_eq!(serialized_json.encoding, SerializedEncoding::Json); + assert!(serialized.encoded.is_bincode()); + assert!(serialized_json.encoded.is_json()); - let json_string = String::from_utf8(serialized_json.data.clone()).unwrap(); + let json_string = + String::from_utf8(serialized_json.encoded.as_json().unwrap().to_vec().clone()).unwrap(); // The serialized data for JSON is just the (compact) JSON string. assert_eq!(json_string, "{\"a\":\"hello\",\"b\":1234,\"c\":5678}");