diff --git a/Cargo.toml b/Cargo.toml index 1b755b781..a363d2ec9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,8 @@ members = [ "preempt_rwlock", "rdmaxcel-sys", "serde_multipart", + "struct_diff_patch", + "struct_diff_patch_macros", "timed_test", "torch-sys", "torch-sys-cuda", diff --git a/struct_diff_patch/Cargo.toml b/struct_diff_patch/Cargo.toml new file mode 100644 index 000000000..b667cfebb --- /dev/null +++ b/struct_diff_patch/Cargo.toml @@ -0,0 +1,18 @@ +# @generated by autocargo from //monarch/struct_diff_patch:struct_diff_patch + +[package] +name = "struct_diff_patch" +version = "0.0.0" +authors = ["Facebook "] +edition = "2021" +description = "diff/patch for Rust structs" +repository = "https://github.com/meta-pytorch/monarch/" +license = "BSD-3-Clause" + +[lib] +edition = "2024" + +[dependencies] +paste = "1.0.14" +struct_diff_patch_macros = { version = "0.0.0", path = "../struct_diff_patch_macros" } +tokio = { version = "1.47.1", features = ["full", "test-util", "tracing"] } diff --git a/struct_diff_patch/src/lib.rs b/struct_diff_patch/src/lib.rs new file mode 100644 index 000000000..23d0a502e --- /dev/null +++ b/struct_diff_patch/src/lib.rs @@ -0,0 +1,348 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +//! This crate defines traits for diffing and patching Rust structs, +//! implements these traits for common types, and provides macros for +//! deriving them on structs. + +pub mod watch; + +use std::collections::HashMap; +use std::collections::hash_map; +use std::hash::Hash; + +pub use struct_diff_patch_macros::Diff; +pub use struct_diff_patch_macros::Patch; + +/// Represents a patch operating targeting values of type `T`. +pub trait Patch { + /// Apply this patch to the provided value, consuming the patch. + fn apply(self, value: &mut T); +} + +/// Implements the "diff" operation, which produces a patch given +/// two instances of the same type. +pub trait Diff: Sized { + /// The type of patch produced by this diff operation. + type Patch: Patch; + + /// Implements the "diff" operation, which produces a patch given + /// two instances of the same type. Specifically, when the returned + /// patch is applied to the original value, it should produce the + /// second value. + fn diff(&self, other: &Self) -> Self::Patch; +} + +impl Patch for Option { + fn apply(self, value: &mut T) { + if let Some(new) = self { + *value = new; + } + } +} + +impl Patch> for Vec

+where + P: Patch, + T: Default, +{ + fn apply(self, value: &mut Vec) { + value.truncate(self.len()); + for (idx, patch) in self.into_iter().enumerate() { + if idx < value.len() { + patch.apply(&mut value[idx]); + } else { + value.push(T::default()); + patch.apply(&mut value[idx]); + } + } + } +} + +impl Diff for Vec +where + T::Patch: From, +{ + type Patch = Vec; + + fn diff(&self, other: &Self) -> Self::Patch { + // Don't try to be clever here (e.g., using some kind of edit algorithm); + // rather optimize for in-place edits, or just replace. + // + // Possibly we should also include prepend/append operations. + let mut patch = Vec::with_capacity(other.len()); + for (idx, value) in other.iter().enumerate() { + if idx < self.len() { + patch.push(self[idx].diff(value)); + } else { + patch.push(value.clone().into()); + } + } + patch + } +} + +/// Vector of key edits. `None` denotes a key to be removed. +pub type HashMapPatch = Vec<(K, Option

)>; + +impl Patch> for HashMapPatch +where + K: Eq + Hash, + V: Default, + P: Patch, +{ + fn apply(self, value: &mut HashMap) { + for (key, patch) in self { + match patch { + Some(patch) => match value.entry(key) { + hash_map::Entry::Occupied(mut entry) => { + patch.apply(entry.get_mut()); + } + hash_map::Entry::Vacant(entry) => { + let mut v = V::default(); + patch.apply(&mut v); + entry.insert(v); + } + }, + None => { + value.remove(&key); + } + } + } + } +} + +impl Diff for HashMap +where + K: Eq + Hash + Clone, + V: Diff + Clone + Default, + V::Patch: From, +{ + type Patch = HashMapPatch; + + fn diff(&self, other: &Self) -> Self::Patch { + let mut changes = Vec::new(); + + for (key, new_value) in other.iter() { + match self.get(key) { + Some(value) => { + changes.push((key.clone(), Some(value.diff(new_value)))); + } + None => { + changes.push((key.clone(), Some(new_value.clone().into()))); + } + } + } + + for key in self.keys() { + if !other.contains_key(key) { + changes.push((key.clone(), None)); + } + } + + changes + } +} + +#[macro_export] +macro_rules! impl_simple_diff { + ($($ty:ty),+ $(,)?) => { + $( + impl $crate::Diff for $ty { + type Patch = Option<$ty>; + + fn diff(&self, other: &Self) -> Self::Patch { + if self == other { + None + } else { + Some(other.clone()) + } + } + } + )+ + }; +} + +impl_simple_diff!( + (), + bool, + char, + i8, + i16, + i32, + i64, + i128, + isize, + u8, + u16, + u32, + u64, + u128, + usize, + f32, + f64, + String +); + +#[macro_export] +macro_rules! impl_tuple_diff_patch { + ($($idx:tt),+ $(,)?) => { + ::paste::paste! { + impl<$( [], [] ),+> $crate::Patch<($( [], )+)> for ($( [], )+) + where + $( []: $crate::Patch<[]>, )+ + { + fn apply(self, value: &mut ($( [], )+)) { + #[allow(non_snake_case)] + let ($( [], )+) = self; + $( + [].apply(&mut value.$idx); + )+ + } + } + + impl<$( []: $crate::Diff ),+> $crate::Diff for ($( [], )+) { + type Patch = ($( <[] as $crate::Diff>::Patch, )+); + + fn diff(&self, other: &Self) -> Self::Patch { + ( + $( self.$idx.diff(&other.$idx), )+ + ) + } + } + } + }; +} + +impl_tuple_diff_patch!(0); +impl_tuple_diff_patch!(0, 1); +impl_tuple_diff_patch!(0, 1, 2); +impl_tuple_diff_patch!(0, 1, 2, 3); +impl_tuple_diff_patch!(0, 1, 2, 3, 4); +impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5); +impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5, 6); +impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5, 6, 7); +impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5, 6, 7, 8); +impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); +impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10); +impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11); + +#[cfg(test)] +mod tests { + use super::*; + use crate as struct_diff_patch; // for macros + + #[derive(Debug, Clone, PartialEq, Diff, Patch)] + struct DerivedStruct { + name: String, + count: u32, + } + + #[derive(Debug, Clone, PartialEq, Diff, Patch)] + struct DerivedTuple(String, bool); + + #[derive(Debug, Clone, PartialEq, Diff, Patch)] + struct DerivedUnit; + + #[test] + fn bool_diff_is_none_when_equal() { + assert_eq!(false.diff(&false), None); + } + + #[test] + fn string_diff_and_apply_replace_value() { + let patch = String::from("bar").diff(&String::from("baz")); + let mut value = String::from("bar"); + patch.apply(&mut value); + assert_eq!(value, "baz"); + } + + #[test] + fn tuple_diff_tracks_each_field() { + let original = (false, String::from("foo")); + let target = (true, String::from("bar")); + let patch = original.diff(&target); + + let mut working = original; + patch.apply(&mut working); + assert_eq!(working, target); + } + + #[test] + fn vec_patch() { + let mut orig = vec![1, 2, 3, 4, 5]; + let target = vec![1, 20, 3, 40, 5]; + + let patch = orig.diff(&target); + assert_eq!(patch, vec![None, Some(20), None, Some(40), None]); + + patch.apply(&mut orig); + assert_eq!(orig, target); + } + + #[test] + fn hashmap_diff_patch_handles_insert_update_and_remove() { + use std::collections::HashMap; + + let mut original = HashMap::new(); + original.insert("keep".to_string(), 1_u32); + original.insert("remove".to_string(), 2_u32); + + let mut target = HashMap::new(); + target.insert("keep".to_string(), 10); + target.insert("insert".to_string(), 3); + + let patch = original.diff(&target); + + let mut saw_insert = false; + let mut saw_update = false; + let mut saw_remove = false; + + for (key, change) in patch.iter() { + match (key.as_str(), change) { + ("insert", Some(Some(3))) => saw_insert = true, + ("keep", Some(Some(10))) => saw_update = true, + ("remove", None) => saw_remove = true, + _ => {} + } + } + + assert!(saw_insert); + assert!(saw_update); + assert!(saw_remove); + + let mut working = original; + patch.apply(&mut working); + assert_eq!(working, target); + } + + #[test] + fn derive_macro_generates_struct_and_patch_impls() { + let mut original = DerivedStruct { + name: "foo".into(), + count: 1, + }; + let target = DerivedStruct { + name: "bar".into(), + count: 2, + }; + + let patch = original.diff(&target); + patch.apply(&mut original); + assert_eq!(original, target); + + let tuple_patch = DerivedTuple("foo".into(), true).diff(&DerivedTuple("baz".into(), false)); + let mut tuple_value = DerivedTuple("foo".into(), true); + tuple_patch.apply(&mut tuple_value); + assert_eq!(tuple_value, DerivedTuple("baz".into(), false)); + + let mut unit = DerivedUnit; + let unit_patch = unit.diff(&DerivedUnit); + unit_patch.apply(&mut unit); + } +} diff --git a/struct_diff_patch/src/watch.rs b/struct_diff_patch/src/watch.rs new file mode 100644 index 000000000..22d4e28ea --- /dev/null +++ b/struct_diff_patch/src/watch.rs @@ -0,0 +1,158 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +//! Provides cells that can provide snapshot and update streams. + +use std::ops::Deref; +use std::ops::DerefMut; + +use tokio::sync::broadcast; + +use crate::Diff; + +/// A watch provides stateful value updates, sending incremental patches +/// to a set of subscribers. +pub struct Watch +where + T::Patch: Clone, +{ + value: T, + sender: broadcast::Sender, +} + +impl Watch +where + T::Patch: Clone, +{ + /// Create a new watch holding the provided initial value. + pub fn new(value: T) -> Self { + let (sender, _) = broadcast::channel(1024); + Watch { value, sender } + } + + /// The current value of the watch. + pub fn value(&self) -> &T { + &self.value + } + + /// Subscribe to the watch's value updates. The broadcast receiver has a maximum + /// buffer size of 1024. If a subscriber is lagging (see [`tokio::sync::broadcast::Receiver`] + /// for details), is the subscriber's responsibility to re-establish the subscription. + /// + /// Typically, subscribers will read the current value, clone it, and then subscribe to the watch + /// for further updates. + pub fn subscribe(&self) -> broadcast::Receiver { + self.sender.subscribe() + } +} + +impl Watch +where + T::Patch: Clone, +{ + /// Returns a guard that allows mutating the value and publishes the diff when dropped. + /// We could define a 'Mutator' type to avoid the extra clone here. + /// + /// Currently, watch updates require cloning the value in order to compute the update + /// (comparing old and new values). We could expand this crate's repertoir by providing + /// a "Mutator" trait (and derive macro) that tracks mutations directly into a patch, + /// avoiding the clone. + pub fn update(&mut self) -> WatchUpdate<'_, T> { + WatchUpdate { + original: self.value.clone(), + value: &mut self.value, + sender: &self.sender, + } + } +} + +/// An update guard (like a "Ref") used to borrow the watch while updating its value. +pub struct WatchUpdate<'a, T: Diff + Clone> +where + T::Patch: Clone, +{ + original: T, + value: &'a mut T, + sender: &'a broadcast::Sender, +} + +impl<'a, T: Diff + Clone> Deref for WatchUpdate<'a, T> +where + T::Patch: Clone, +{ + type Target = T; + + fn deref(&self) -> &Self::Target { + self.value + } +} + +impl<'a, T: Diff + Clone> DerefMut for WatchUpdate<'a, T> +where + T::Patch: Clone, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + self.value + } +} + +impl<'a, T: Diff + Clone> Drop for WatchUpdate<'a, T> +where + T::Patch: Clone, +{ + fn drop(&mut self) { + let patch = self.original.diff(self.value); + let _ = self.sender.send(patch); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate as struct_diff_patch; // for macro expansion + use crate::Patch; + + #[derive(Debug, Clone, PartialEq, Diff, Patch)] + struct TestStruct { + name: String, + count: u32, + } + + #[tokio::test] + async fn update_sends_patch_on_drop() { + let mut watch = Watch::new(TestStruct { + name: "start".into(), + count: 1, + }); + let mut rx = watch.subscribe(); + + { + let mut u = watch.update(); + u.name = "end".into(); + u.count = 2; + } + + let patch = rx.recv().await.unwrap(); + + let mut value = TestStruct { + name: "start".into(), + count: 1, + }; + patch.apply(&mut value); + + assert_eq!( + value, + TestStruct { + name: "end".into(), + count: 2 + } + ); + assert_eq!(watch.value().name, "end"); + assert_eq!(watch.value().count, 2); + } +} diff --git a/struct_diff_patch_macros/Cargo.toml b/struct_diff_patch_macros/Cargo.toml new file mode 100644 index 000000000..886fa18b9 --- /dev/null +++ b/struct_diff_patch_macros/Cargo.toml @@ -0,0 +1,21 @@ +# @generated by autocargo from //monarch/struct_diff_patch_macros:struct_diff_patch_macros + +[package] +name = "struct_diff_patch_macros" +version = "0.0.0" +authors = ["Facebook "] +edition = "2021" +description = "derive macros for struct_diff_patch" +repository = "https://github.com/meta-pytorch/monarch/" +license = "BSD-3-Clause" + +[lib] +test = false +doctest = false +proc-macro = true +edition = "2024" + +[dependencies] +proc-macro2 = { version = "1.0.70", features = ["span-locations"] } +quote = "1.0.29" +syn = { version = "2.0.110", features = ["extra-traits", "fold", "full", "visit", "visit-mut"] } diff --git a/struct_diff_patch_macros/src/lib.rs b/struct_diff_patch_macros/src/lib.rs new file mode 100644 index 000000000..643051cf9 --- /dev/null +++ b/struct_diff_patch_macros/src/lib.rs @@ -0,0 +1,286 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +use proc_macro::TokenStream; +use proc_macro2::Span; +use proc_macro2::TokenStream as TokenStream2; +use quote::format_ident; +use quote::quote; +use syn::Data; +use syn::DeriveInput; +use syn::Fields; +use syn::Index; +use syn::Type; +use syn::parse_macro_input; +use syn::parse_quote; + +/// Generates a [`struct_diff_patch::Diff`] implementation for a struct. +/// The patch type will match the one expected by the struct's [`struct_diff_patch::Patch`] implementation, +/// as derived by the [`Patch`] derive macro. +/// +/// For example, +/// +/// ```ignore +/// #[derive(Diff)] +/// struct MyStruct { +/// name: String, +/// count: u32, +/// } +/// ``` +/// +/// will generate the following [`struct_diff_patch::Diff`] implementation: +/// +/// ```ignore +/// impl struct_diff_patch::Diff for MyStruct +/// where +/// String: struct_diff_patch::Diff, +/// u32: struct_diff_patch::Diff, +/// { +/// type Patch = ( +/// ::Patch, +/// ::Patch, +/// ); +/// fn diff(&self, other: &Self) -> Self::Patch { +/// (self.name.diff(&other.name), self.count.diff(&other.count)) +/// } +/// } +/// ``` +#[proc_macro_derive(Diff)] +pub fn derive_diff(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + expand_diff(input) + .unwrap_or_else(|err| err.to_compile_error()) + .into() +} + +/// Derives a [`struct_diff_patch::Patch`] implementation for a struct. +/// The patch type will match the one expected by the struct's [`struct_diff_patch::Diff`] implementation, +/// as derived by the [`Diff`] derive macro. +/// +/// For example, +/// +/// ```ignore +/// #[derive(Patch)] +/// struct MyStruct { +/// name: String, +/// count: u32, +/// } +/// ``` +/// +/// will generate the following: +/// +/// ```ignore +/// impl struct_diff_patch::Patch +/// for (::Patch, ::Patch) +/// where +/// String: struct_diff_patch::Diff, +/// u32: struct_diff_patch::Diff, +/// { +/// fn apply(self, value: &mut MyStruct) { +/// let (field_patch_0, field_patch_1) = self; +/// field_patch_0.apply(&mut value.name); +/// field_patch_1.apply(&mut value.count); +/// } +/// } +/// ``` +#[proc_macro_derive(Patch)] +pub fn derive_patch(input: TokenStream) -> TokenStream { + let _ = parse_macro_input!(input as DeriveInput); + TokenStream::new() +} + +fn expand_diff(input: DeriveInput) -> syn::Result { + let DeriveInput { + ident, + generics, + data, + .. + } = input; + let crate_path = crate_path(); + + let (_, ty_generics, _) = generics.split_for_impl(); + let mut impl_generics = generics.clone(); + + let struct_tokens = match data { + Data::Struct(data_struct) => data_struct, + _ => { + return Err(syn::Error::new( + Span::call_site(), + "Diff derive currently supports only structs", + )); + } + }; + + let (patch_type, diff_expr, apply_body, field_types) = + build_struct_parts(&struct_tokens.fields, &crate_path); + + if !field_types.is_empty() { + let where_clause_impl = impl_generics.make_where_clause(); + for field_ty in field_types { + where_clause_impl + .predicates + .push(parse_quote! { #field_ty: #crate_path::Diff }); + } + } + + let (impl_generics_tokens, _, where_clause_tokens) = impl_generics.split_for_impl(); + + Ok(quote! { + impl #impl_generics_tokens #crate_path::Diff for #ident #ty_generics #where_clause_tokens { + type Patch = #patch_type; + + fn diff(&self, other: &Self) -> Self::Patch { + #diff_expr + } + } + + impl #impl_generics_tokens #crate_path::Patch<#ident #ty_generics> for #patch_type #where_clause_tokens { + fn apply(self, value: &mut #ident #ty_generics) { + #apply_body + } + } + }) +} + +fn build_struct_parts( + fields: &Fields, + crate_path: &syn::Path, +) -> (TokenStream2, TokenStream2, TokenStream2, Vec) { + match fields { + Fields::Named(named) => { + let names: Vec<_> = named + .named + .iter() + .map(|field| field.ident.clone().expect("named field")) + .collect(); + let types: Vec<_> = named.named.iter().map(|field| field.ty.clone()).collect(); + + let patch_types: Vec<_> = types + .iter() + .map(|ty| quote! { <#ty as #crate_path::Diff>::Patch }) + .collect(); + + let diff_fields: Vec<_> = names + .iter() + .map(|name| quote! { self.#name.diff(&other.#name) }) + .collect(); + + let binding_names: Vec<_> = names + .iter() + .enumerate() + .map(|(pos, _)| format_ident!("field_patch_{pos}")) + .collect(); + + let apply_steps = binding_names + .iter() + .zip(names.iter()) + .map(|(binding, name)| quote! { #binding.apply(&mut value.#name); }); + + let patch_type = if patch_types.len() > 0 { + quote! { ( #( #patch_types ),* , ) } + } else { + quote! { () } + }; + + let diff_expr = if diff_fields.len() > 0 { + quote! { ( #( #diff_fields ),* , ) } + } else { + quote! { () } + }; + + let apply_body = if binding_names.is_empty() { + quote! { + let _ = self; + let _ = value; + } + } else { + quote! { + let ( #( #binding_names ),* , ) = self; + #( #apply_steps )* + } + }; + + (patch_type, diff_expr, apply_body, types) + } + Fields::Unnamed(unnamed) => { + let types: Vec<_> = unnamed + .unnamed + .iter() + .map(|field| field.ty.clone()) + .collect(); + let indices: Vec<_> = (0..types.len()).map(Index::from).collect(); + let patch_fields: Vec<_> = types + .iter() + .map(|ty| quote! { <#ty as #crate_path::Diff>::Patch }) + .collect(); + + let diff_fields: Vec<_> = indices + .iter() + .map(|idx| quote! { self.#idx.diff(&other.#idx) }) + .collect(); + + let binding_names: Vec<_> = indices + .iter() + .enumerate() + .map(|(pos, _)| format_ident!("field_patch_{pos}")) + .collect(); + + let apply_steps = binding_names + .iter() + .zip(indices.iter()) + .map(|(binding, idx)| quote! { #binding.apply(&mut value.#idx); }); + + let patch_type = if patch_fields.len() > 0 { + quote! { ( #( #patch_fields ),* , ) } + } else { + quote! { () } + }; + let diff_expr = if diff_fields.len() > 0 { + quote! { ( #( #diff_fields ),* , ) } + } else { + quote! { () } + }; + + let apply_body = if binding_names.is_empty() { + quote! { + let _ = self; + let _ = value; + } + } else { + quote! { + let ( #( #binding_names ),* , ) = self; + #( #apply_steps )* + } + }; + + (patch_type, diff_expr, apply_body, types) + } + Fields::Unit => { + let patch_type = quote! { () }; + let diff_expr = quote! { () }; + let apply_body = quote! { + let _ = self; + let _ = value; + }; + (patch_type, diff_expr, apply_body, Vec::new()) + } + } +} + +fn crate_path() -> syn::Path { + syn::parse_quote!(struct_diff_patch) + // TODO: get proc-macro-crate into third-party + // match crate_name("struct_diff_patch") { + // Ok(FoundCrate::Itself) => syn::parse_quote!(crate), + // Ok(FoundCrate::Name(name)) => { + // let ident = Ident::new(&name, Span::call_site()); + // syn::parse_quote!(#ident) + // } + // Err(_) => syn::parse_quote!(struct_diff_patch), + // } +}