diff --git a/hyperactor/src/data.rs b/hyperactor/src/data.rs index d8c9ab32f..8fb806f83 100644 --- a/hyperactor/src/data.rs +++ b/hyperactor/src/data.rs @@ -302,7 +302,7 @@ macro_rules! register_type { enum Encoded { Bincode(serde_bytes::ByteBuf), Json(serde_bytes::ByteBuf), - // todo: multipart + Multipart(serde_multipart::Message), } impl Encoded { @@ -311,6 +311,7 @@ impl Encoded { match &self { Encoded::Bincode(data) => data.len(), Encoded::Json(data) => data.len(), + Encoded::Multipart(message) => message.len(), } } @@ -319,6 +320,7 @@ impl Encoded { match &self { Encoded::Bincode(data) => data.is_empty(), Encoded::Json(data) => data.is_empty(), + Encoded::Multipart(message) => message.is_empty(), } } @@ -327,6 +329,14 @@ impl Encoded { match &self { Encoded::Bincode(data) => crc32fast::hash(data), Encoded::Json(data) => crc32fast::hash(data), + Encoded::Multipart(message) => { + let mut hasher = crc32fast::Hasher::new(); + hasher.update(message.body().as_ref()); + for part in message.parts() { + hasher.update(part.as_ref()); + } + hasher.finalize() + } } } } @@ -336,6 +346,7 @@ impl std::fmt::Debug for Encoded { match self { Encoded::Bincode(data) => write!(f, "Encoded::Bincode({})", HexFmt(data.as_slice())), Encoded::Json(data) => write!(f, "Encoded::Json({})", HexFmt(data.as_slice())), + Encoded::Multipart(message) => todo!(), //write!(f, "Encoded::Multipart({})", HexFmt(data.as_slice())), } } } @@ -392,6 +403,9 @@ impl Serialized { match &self.encoded { Encoded::Bincode(data) => bincode::deserialize(data).map_err(anyhow::Error::from), Encoded::Json(data) => serde_json::from_slice(data).map_err(anyhow::Error::from), + Encoded::Multipart(message) => { + serde_multipart::deserialize_bincode(message.clone()).map_err(anyhow::Error::from) + } } } @@ -399,7 +413,7 @@ impl Serialized { /// is embedded in the value, and the corresponding type is available in this binary. pub fn transcode_to_json(self) -> Result { match self.encoded { - Encoded::Bincode(_) => { + Encoded::Bincode(_) | Encoded::Multipart(_) => { let json_value = match self.dump() { Ok(json_value) => json_value, Err(_) => return Err(self), @@ -431,6 +445,10 @@ impl Serialized { typeinfo.dump(self.clone()) } Encoded::Json(data) => serde_json::from_slice(data).map_err(anyhow::Error::from), + Encoded::Multipart(_) => { + // TODO: implement typeinfo.dump_multipart + anyhow::bail!("dumping multipart-encoded values is not yet supported") + } } }