Skip to content

[hyperactor] Serialized: permit multiple encoding representations #832

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: gh/mariusae/35/base
Choose a base branch
from
132 changes: 76 additions & 56 deletions hyperactor/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ use std::fmt;
use std::io::Cursor;
use std::sync::LazyLock;

use enum_as_inner::EnumAsInner;
pub use intern_typename;
use serde::Deserialize;
use serde::Serialize;
Expand Down Expand Up @@ -291,10 +292,46 @@ macro_rules! register_type {
}

/// The encoding used for a serialized value.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
enum SerializedEncoding {
Bincode,
Json,
#[derive(Clone, Serialize, Deserialize, PartialEq, EnumAsInner)]
enum Encoded {
Bincode(serde_bytes::ByteBuf),
Json(serde_bytes::ByteBuf),
// TODO: add a multipart encoding variant.
}

impl Encoded {
    /// Returns the number of bytes in the encoded payload, regardless of
    /// which encoding produced it.
    pub fn len(&self) -> usize {
        match self {
            Self::Bincode(bytes) | Self::Json(bytes) => bytes.len(),
        }
    }

    /// Returns `true` when the encoded payload holds no bytes. This is
    /// expected to always be `false` for a real message.
    pub fn is_empty(&self) -> bool {
        match self {
            Self::Bincode(bytes) | Self::Json(bytes) => bytes.is_empty(),
        }
    }

    /// Returns the 32-bit CRC of the encoded payload bytes, computed with
    /// `crc32fast::hash`.
    pub fn crc(&self) -> u32 {
        match self {
            Self::Bincode(bytes) | Self::Json(bytes) => crc32fast::hash(bytes),
        }
    }
}

impl std::fmt::Debug for Encoded {
    /// Writes the variant name followed by the payload bytes rendered
    /// through `HexFmt`, instead of dumping the raw byte buffer.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Bincode(bytes) => {
                let payload = HexFmt(bytes.as_slice());
                write!(f, "Encoded::Bincode({})", payload)
            }
            Self::Json(bytes) => {
                let payload = HexFmt(bytes.as_slice());
                write!(f, "Encoded::Json({})", payload)
            }
        }
    }
}

/// Represents a serialized value, wrapping the underlying serialization
Expand All @@ -303,24 +340,15 @@ enum SerializedEncoding {
///
/// Serialized records which encoding its data uses. Values are produced with bincode
/// and may be transcoded to JSON; other codecs may be added in the future.
#[derive(Clone, Serialize, Deserialize, PartialEq)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Serialized {
encoding: SerializedEncoding,

/// The encoded data for the serialized value.
#[serde(with = "serde_bytes")]
data: Vec<u8>,
/// The encoded data
encoded: Encoded,
/// The typehash of the serialized value, if available. This is used to provide
/// typed introspection of the value.
typehash: Option<u64>,
}

impl std::fmt::Debug for Serialized {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Serialized({})", HexFmt(self.data.as_slice()),)
}
}

impl std::fmt::Display for Serialized {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.dump() {
Expand All @@ -331,7 +359,7 @@ impl std::fmt::Display for Serialized {
let basename = typename.split("::").last().unwrap_or(typename);
write!(f, "{}{}", basename, JsonFmt(&value))
}
Err(_) => write!(f, "{}", HexFmt(self.data.as_slice())),
Err(_) => write!(f, "{:?}", self.encoded),
}
}
}
Expand All @@ -340,38 +368,32 @@ impl Serialized {
/// Construct a new serialized value by serializing the provided T-typed value.
pub fn serialize<T: Serialize + Named>(value: &T) -> Result<Self, bincode::Error> {
Ok(Self {
encoding: SerializedEncoding::Bincode,
data: bincode::serialize(value)?,
encoded: Encoded::Bincode(bincode::serialize(value)?.into()),
typehash: Some(T::typehash()),
})
}

/// Construct a new anonymous (unnamed) serialized value by serializing the provided T-typed value.
pub fn serialize_anon<T: Serialize>(value: &T) -> Result<Self, bincode::Error> {
Ok(Self {
encoding: SerializedEncoding::Bincode,
data: bincode::serialize(value)?,
encoded: Encoded::Bincode(bincode::serialize(value)?.into()),
typehash: None,
})
}

/// Deserialize a value to the provided type T.
pub fn deserialized<T: DeserializeOwned>(&self) -> Result<T, anyhow::Error> {
match self.encoding {
SerializedEncoding::Bincode => {
bincode::deserialize(&self.data).map_err(anyhow::Error::from)
}
SerializedEncoding::Json => {
serde_json::from_slice(&self.data).map_err(anyhow::Error::from)
}
match &self.encoded {
Encoded::Bincode(data) => bincode::deserialize(data).map_err(anyhow::Error::from),
Encoded::Json(data) => serde_json::from_slice(data).map_err(anyhow::Error::from),
}
}

/// Transcode the serialized value to JSON. This operation will succeed if the type hash
/// is embedded in the value, and the corresponding type is available in this binary.
pub fn transcode_to_json(self) -> Result<Self, Self> {
match self.encoding {
SerializedEncoding::Bincode => {
match self.encoded {
Encoded::Bincode(_) => {
let json_value = match self.dump() {
Ok(json_value) => json_value,
Err(_) => return Err(self),
Expand All @@ -381,20 +403,19 @@ impl Serialized {
Err(_) => return Err(self),
};
Ok(Self {
encoding: SerializedEncoding::Json,
data: json_data,
encoded: Encoded::Json(json_data.into()),
typehash: self.typehash,
})
}
SerializedEncoding::Json => Ok(self),
Encoded::Json(_) => Ok(self),
}
}

/// Dump the Serialized message into a JSON value. This will succeed if: 1) the typehash is embedded
/// in the serialized value; 2) the named type is linked into the binary.
pub fn dump(&self) -> Result<serde_json::Value, anyhow::Error> {
match self.encoding {
SerializedEncoding::Bincode => {
match &self.encoded {
Encoded::Bincode(_) => {
let Some(typehash) = self.typehash() else {
anyhow::bail!("serialized value does not contain a typehash");
};
Expand All @@ -403,9 +424,7 @@ impl Serialized {
};
typeinfo.dump(self.clone())
}
SerializedEncoding::Json => {
serde_json::from_slice(&self.data).map_err(anyhow::Error::from)
}
Encoded::Json(data) => serde_json::from_slice(data).map_err(anyhow::Error::from),
}
}

Expand All @@ -425,11 +444,10 @@ impl Serialized {
// TODO: we should support this by formalizing the notion of a 'prefix'
// serialization, and generalize it to other codecs as well.
pub fn prefix<T: DeserializeOwned>(&self) -> Result<T, anyhow::Error> {
anyhow::ensure!(
self.encoding == SerializedEncoding::Bincode,
"only bincode supports prefix emplacement"
);
bincode::deserialize(&self.data).map_err(anyhow::Error::from)
match &self.encoded {
Encoded::Bincode(data) => bincode::deserialize(data).map_err(anyhow::Error::from),
_ => anyhow::bail!("only bincode supports prefix emplacement"),
}
}

/// Emplace a new prefix to this value. This is currently only supported
Expand All @@ -438,38 +456,39 @@ impl Serialized {
&mut self,
prefix: T,
) -> Result<(), anyhow::Error> {
anyhow::ensure!(
self.encoding == SerializedEncoding::Bincode,
"only bincode supports prefix emplacement"
);
let data = match &self.encoded {
Encoded::Bincode(data) => data,
_ => anyhow::bail!("only bincode supports prefix emplacement"),
};

// This is a bit ugly, but: we first deserialize out the old prefix,
// then serialize the new prefix, then splice the two together.
// This is safe because we know that the prefix is the first thing
// in the serialized value, and that the serialization format is stable.
let mut cursor = Cursor::new(self.data.clone());
let mut cursor = Cursor::new(data.clone());
let _prefix: T = bincode::deserialize_from(&mut cursor).unwrap();
let position = cursor.position() as usize;
let suffix = &cursor.into_inner()[position..];
self.data = bincode::serialize(&prefix)?;
self.data.extend_from_slice(suffix);
let mut data = bincode::serialize(&prefix)?;
data.extend_from_slice(suffix);
self.encoded = Encoded::Bincode(data.into());

Ok(())
}

/// The length of the underlying serialized message
pub fn len(&self) -> usize {
self.data.len()
self.encoded.len()
}

/// Returns whether the message is empty. This should always return false.
pub fn is_empty(&self) -> bool {
self.len() == 0
self.encoded.is_empty()
}

/// Returns the 32-bit CRC of the serialized data.
pub fn crc(&self) -> u32 {
crc32fast::hash(&self.data)
self.encoded.crc()
}
}

Expand Down Expand Up @@ -615,10 +634,11 @@ mod tests {
let serialized = Serialized::serialize(&data).unwrap();
let serialized_json = serialized.clone().transcode_to_json().unwrap();

assert_eq!(serialized.encoding, SerializedEncoding::Bincode);
assert_eq!(serialized_json.encoding, SerializedEncoding::Json);
assert!(serialized.encoded.is_bincode());
assert!(serialized_json.encoded.is_json());

let json_string = String::from_utf8(serialized_json.data.clone()).unwrap();
let json_string =
String::from_utf8(serialized_json.encoded.as_json().unwrap().to_vec().clone()).unwrap();
// The serialized data for JSON is just the (compact) JSON string.
assert_eq!(json_string, "{\"a\":\"hello\",\"b\":1234,\"c\":5678}");

Expand Down