From 1cbb7ff80f09bced3df2e5e033df54aaa992dae3 Mon Sep 17 00:00:00 2001 From: Marius Eriksen Date: Mon, 24 Nov 2025 14:03:56 -0800 Subject: [PATCH] [struct_diff_patch] diffing / patching for rust structs The crate struct_diff_patch provides "diff" and "patch" operations for rust data types. It defines two generic traits: 1) `Diff` which implements diffing two values to produce a patch; and 2) `Patch` which is implemented on patches in order to apply them to values. Additionally, the crate provides a "watch", which lets subscribers track changes to a value as a value snapshot + stream of patches. The intention is to use this in the monarch resource model to observe resources efficiently, without having to implement special operations for each access pattern. In the future, we can also implement patch _merge_ operations in order to allow patches to be used in accumulations as well. Differential Revision: [D87822791](https://our.internmc.facebook.com/intern/diff/D87822791/) [ghstack-poisoned] --- Cargo.toml | 2 + struct_diff_patch/Cargo.toml | 18 ++ struct_diff_patch/src/lib.rs | 348 ++++++++++++++++++++++++++++ struct_diff_patch/src/watch.rs | 158 +++++++++++++ struct_diff_patch_macros/Cargo.toml | 21 ++ struct_diff_patch_macros/src/lib.rs | 286 +++++++++++++++++++++++ 6 files changed, 833 insertions(+) create mode 100644 struct_diff_patch/Cargo.toml create mode 100644 struct_diff_patch/src/lib.rs create mode 100644 struct_diff_patch/src/watch.rs create mode 100644 struct_diff_patch_macros/Cargo.toml create mode 100644 struct_diff_patch_macros/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index 1b755b781..a363d2ec9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,8 @@ members = [ "preempt_rwlock", "rdmaxcel-sys", "serde_multipart", + "struct_diff_patch", + "struct_diff_patch_macros", "timed_test", "torch-sys", "torch-sys-cuda", diff --git a/struct_diff_patch/Cargo.toml b/struct_diff_patch/Cargo.toml new file mode 100644 index 000000000..b667cfebb --- /dev/null 
+++ b/struct_diff_patch/Cargo.toml @@ -0,0 +1,18 @@ +# @generated by autocargo from //monarch/struct_diff_patch:struct_diff_patch + +[package] +name = "struct_diff_patch" +version = "0.0.0" +authors = ["Facebook "] +edition = "2021" +description = "diff/patch for Rust structs" +repository = "https://github.com/meta-pytorch/monarch/" +license = "BSD-3-Clause" + +[lib] +edition = "2024" + +[dependencies] +paste = "1.0.14" +struct_diff_patch_macros = { version = "0.0.0", path = "../struct_diff_patch_macros" } +tokio = { version = "1.47.1", features = ["full", "test-util", "tracing"] } diff --git a/struct_diff_patch/src/lib.rs b/struct_diff_patch/src/lib.rs new file mode 100644 index 000000000..23d0a502e --- /dev/null +++ b/struct_diff_patch/src/lib.rs @@ -0,0 +1,348 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +//! This crate defines traits for diffing and patching Rust structs, +//! implements these traits for common types, and provides macros for +//! deriving them on structs. + +pub mod watch; + +use std::collections::HashMap; +use std::collections::hash_map; +use std::hash::Hash; + +pub use struct_diff_patch_macros::Diff; +pub use struct_diff_patch_macros::Patch; + +/// Represents a patch operation targeting values of type `T`. +pub trait Patch { + /// Apply this patch to the provided value, consuming the patch. + fn apply(self, value: &mut T); +} + +/// Implements the "diff" operation, which produces a patch given +/// two instances of the same type. +pub trait Diff: Sized { + /// The type of patch produced by this diff operation. + type Patch: Patch; + + /// Implements the "diff" operation, which produces a patch given + /// two instances of the same type. Specifically, when the returned + /// patch is applied to the original value, it should produce the + /// second value. 
+ fn diff(&self, other: &Self) -> Self::Patch; +} + +impl Patch for Option { + fn apply(self, value: &mut T) { + if let Some(new) = self { + *value = new; + } + } +} + +impl Patch> for Vec

+where + P: Patch, + T: Default, +{ + fn apply(self, value: &mut Vec) { + value.truncate(self.len()); + for (idx, patch) in self.into_iter().enumerate() { + if idx < value.len() { + patch.apply(&mut value[idx]); + } else { + value.push(T::default()); + patch.apply(&mut value[idx]); + } + } + } +} + +impl Diff for Vec +where + T::Patch: From, +{ + type Patch = Vec; + + fn diff(&self, other: &Self) -> Self::Patch { + // Don't try to be clever here (e.g., using some kind of edit algorithm); + // rather optimize for in-place edits, or just replace. + // + // Possibly we should also include prepend/append operations. + let mut patch = Vec::with_capacity(other.len()); + for (idx, value) in other.iter().enumerate() { + if idx < self.len() { + patch.push(self[idx].diff(value)); + } else { + patch.push(value.clone().into()); + } + } + patch + } +} + +/// Vector of key edits. `None` denotes a key to be removed. +pub type HashMapPatch = Vec<(K, Option

)>; + +impl Patch> for HashMapPatch +where + K: Eq + Hash, + V: Default, + P: Patch, +{ + fn apply(self, value: &mut HashMap) { + for (key, patch) in self { + match patch { + Some(patch) => match value.entry(key) { + hash_map::Entry::Occupied(mut entry) => { + patch.apply(entry.get_mut()); + } + hash_map::Entry::Vacant(entry) => { + let mut v = V::default(); + patch.apply(&mut v); + entry.insert(v); + } + }, + None => { + value.remove(&key); + } + } + } + } +} + +impl Diff for HashMap +where + K: Eq + Hash + Clone, + V: Diff + Clone + Default, + V::Patch: From, +{ + type Patch = HashMapPatch; + + fn diff(&self, other: &Self) -> Self::Patch { + let mut changes = Vec::new(); + + for (key, new_value) in other.iter() { + match self.get(key) { + Some(value) => { + changes.push((key.clone(), Some(value.diff(new_value)))); + } + None => { + changes.push((key.clone(), Some(new_value.clone().into()))); + } + } + } + + for key in self.keys() { + if !other.contains_key(key) { + changes.push((key.clone(), None)); + } + } + + changes + } +} + +#[macro_export] +macro_rules! impl_simple_diff { + ($($ty:ty),+ $(,)?) => { + $( + impl $crate::Diff for $ty { + type Patch = Option<$ty>; + + fn diff(&self, other: &Self) -> Self::Patch { + if self == other { + None + } else { + Some(other.clone()) + } + } + } + )+ + }; +} + +impl_simple_diff!( + (), + bool, + char, + i8, + i16, + i32, + i64, + i128, + isize, + u8, + u16, + u32, + u64, + u128, + usize, + f32, + f64, + String +); + +#[macro_export] +macro_rules! impl_tuple_diff_patch { + ($($idx:tt),+ $(,)?) => { + ::paste::paste! 
{ + impl<$( [], [] ),+> $crate::Patch<($( [], )+)> for ($( [], )+) + where + $( []: $crate::Patch<[]>, )+ + { + fn apply(self, value: &mut ($( [], )+)) { + #[allow(non_snake_case)] + let ($( [], )+) = self; + $( + [].apply(&mut value.$idx); + )+ + } + } + + impl<$( []: $crate::Diff ),+> $crate::Diff for ($( [], )+) { + type Patch = ($( <[] as $crate::Diff>::Patch, )+); + + fn diff(&self, other: &Self) -> Self::Patch { + ( + $( self.$idx.diff(&other.$idx), )+ + ) + } + } + } + }; +} + +impl_tuple_diff_patch!(0); +impl_tuple_diff_patch!(0, 1); +impl_tuple_diff_patch!(0, 1, 2); +impl_tuple_diff_patch!(0, 1, 2, 3); +impl_tuple_diff_patch!(0, 1, 2, 3, 4); +impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5); +impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5, 6); +impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5, 6, 7); +impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5, 6, 7, 8); +impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); +impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10); +impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11); + +#[cfg(test)] +mod tests { + use super::*; + use crate as struct_diff_patch; // for macros + + #[derive(Debug, Clone, PartialEq, Diff, Patch)] + struct DerivedStruct { + name: String, + count: u32, + } + + #[derive(Debug, Clone, PartialEq, Diff, Patch)] + struct DerivedTuple(String, bool); + + #[derive(Debug, Clone, PartialEq, Diff, Patch)] + struct DerivedUnit; + + #[test] + fn bool_diff_is_none_when_equal() { + assert_eq!(false.diff(&false), None); + } + + #[test] + fn string_diff_and_apply_replace_value() { + let patch = String::from("bar").diff(&String::from("baz")); + let mut value = String::from("bar"); + patch.apply(&mut value); + assert_eq!(value, "baz"); + } + + #[test] + fn tuple_diff_tracks_each_field() { + let original = (false, String::from("foo")); + let target = (true, String::from("bar")); + let patch = original.diff(&target); + + let mut working = original; + patch.apply(&mut working); + assert_eq!(working, target); + } + + #[test] + fn 
vec_patch() { + let mut orig = vec![1, 2, 3, 4, 5]; + let target = vec![1, 20, 3, 40, 5]; + + let patch = orig.diff(&target); + assert_eq!(patch, vec![None, Some(20), None, Some(40), None]); + + patch.apply(&mut orig); + assert_eq!(orig, target); + } + + #[test] + fn hashmap_diff_patch_handles_insert_update_and_remove() { + use std::collections::HashMap; + + let mut original = HashMap::new(); + original.insert("keep".to_string(), 1_u32); + original.insert("remove".to_string(), 2_u32); + + let mut target = HashMap::new(); + target.insert("keep".to_string(), 10); + target.insert("insert".to_string(), 3); + + let patch = original.diff(&target); + + let mut saw_insert = false; + let mut saw_update = false; + let mut saw_remove = false; + + for (key, change) in patch.iter() { + match (key.as_str(), change) { + ("insert", Some(Some(3))) => saw_insert = true, + ("keep", Some(Some(10))) => saw_update = true, + ("remove", None) => saw_remove = true, + _ => {} + } + } + + assert!(saw_insert); + assert!(saw_update); + assert!(saw_remove); + + let mut working = original; + patch.apply(&mut working); + assert_eq!(working, target); + } + + #[test] + fn derive_macro_generates_struct_and_patch_impls() { + let mut original = DerivedStruct { + name: "foo".into(), + count: 1, + }; + let target = DerivedStruct { + name: "bar".into(), + count: 2, + }; + + let patch = original.diff(&target); + patch.apply(&mut original); + assert_eq!(original, target); + + let tuple_patch = DerivedTuple("foo".into(), true).diff(&DerivedTuple("baz".into(), false)); + let mut tuple_value = DerivedTuple("foo".into(), true); + tuple_patch.apply(&mut tuple_value); + assert_eq!(tuple_value, DerivedTuple("baz".into(), false)); + + let mut unit = DerivedUnit; + let unit_patch = unit.diff(&DerivedUnit); + unit_patch.apply(&mut unit); + } +} diff --git a/struct_diff_patch/src/watch.rs b/struct_diff_patch/src/watch.rs new file mode 100644 index 000000000..22d4e28ea --- /dev/null +++ 
b/struct_diff_patch/src/watch.rs @@ -0,0 +1,158 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +//! Provides cells that can provide snapshot and update streams. + +use std::ops::Deref; +use std::ops::DerefMut; + +use tokio::sync::broadcast; + +use crate::Diff; + +/// A watch provides stateful value updates, sending incremental patches +/// to a set of subscribers. +pub struct Watch +where + T::Patch: Clone, +{ + value: T, + sender: broadcast::Sender, +} + +impl Watch +where + T::Patch: Clone, +{ + /// Create a new watch holding the provided initial value. + pub fn new(value: T) -> Self { + let (sender, _) = broadcast::channel(1024); + Watch { value, sender } + } + + /// The current value of the watch. + pub fn value(&self) -> &T { + &self.value + } + + /// Subscribe to the watch's value updates. The broadcast receiver has a maximum + /// buffer size of 1024. If a subscriber is lagging (see [`tokio::sync::broadcast::Receiver`] + /// for details), it is the subscriber's responsibility to re-establish the subscription. + /// + /// Typically, subscribers will read the current value, clone it, and then subscribe to the watch + /// for further updates. + pub fn subscribe(&self) -> broadcast::Receiver { + self.sender.subscribe() + } +} + +impl Watch +where + T::Patch: Clone, +{ + /// Returns a guard that allows mutating the value and publishes the diff when dropped. + /// We could define a 'Mutator' type to avoid the extra clone here. + /// + /// Currently, watch updates require cloning the value in order to compute the update + /// (comparing old and new values). We could expand this crate's repertoire by providing + /// a "Mutator" trait (and derive macro) that tracks mutations directly into a patch, + /// avoiding the clone. 
+ pub fn update(&mut self) -> WatchUpdate<'_, T> { + WatchUpdate { + original: self.value.clone(), + value: &mut self.value, + sender: &self.sender, + } + } +} + +/// An update guard (like a "Ref") used to borrow the watch while updating its value. +pub struct WatchUpdate<'a, T: Diff + Clone> +where + T::Patch: Clone, +{ + original: T, + value: &'a mut T, + sender: &'a broadcast::Sender, +} + +impl<'a, T: Diff + Clone> Deref for WatchUpdate<'a, T> +where + T::Patch: Clone, +{ + type Target = T; + + fn deref(&self) -> &Self::Target { + self.value + } +} + +impl<'a, T: Diff + Clone> DerefMut for WatchUpdate<'a, T> +where + T::Patch: Clone, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + self.value + } +} + +impl<'a, T: Diff + Clone> Drop for WatchUpdate<'a, T> +where + T::Patch: Clone, +{ + fn drop(&mut self) { + let patch = self.original.diff(self.value); + let _ = self.sender.send(patch); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate as struct_diff_patch; // for macro expansion + use crate::Patch; + + #[derive(Debug, Clone, PartialEq, Diff, Patch)] + struct TestStruct { + name: String, + count: u32, + } + + #[tokio::test] + async fn update_sends_patch_on_drop() { + let mut watch = Watch::new(TestStruct { + name: "start".into(), + count: 1, + }); + let mut rx = watch.subscribe(); + + { + let mut u = watch.update(); + u.name = "end".into(); + u.count = 2; + } + + let patch = rx.recv().await.unwrap(); + + let mut value = TestStruct { + name: "start".into(), + count: 1, + }; + patch.apply(&mut value); + + assert_eq!( + value, + TestStruct { + name: "end".into(), + count: 2 + } + ); + assert_eq!(watch.value().name, "end"); + assert_eq!(watch.value().count, 2); + } +} diff --git a/struct_diff_patch_macros/Cargo.toml b/struct_diff_patch_macros/Cargo.toml new file mode 100644 index 000000000..886fa18b9 --- /dev/null +++ b/struct_diff_patch_macros/Cargo.toml @@ -0,0 +1,21 @@ +# @generated by autocargo from 
//monarch/struct_diff_patch_macros:struct_diff_patch_macros + +[package] +name = "struct_diff_patch_macros" +version = "0.0.0" +authors = ["Facebook "] +edition = "2021" +description = "derive macros for struct_diff_patch" +repository = "https://github.com/meta-pytorch/monarch/" +license = "BSD-3-Clause" + +[lib] +test = false +doctest = false +proc-macro = true +edition = "2024" + +[dependencies] +proc-macro2 = { version = "1.0.70", features = ["span-locations"] } +quote = "1.0.29" +syn = { version = "2.0.110", features = ["extra-traits", "fold", "full", "visit", "visit-mut"] } diff --git a/struct_diff_patch_macros/src/lib.rs b/struct_diff_patch_macros/src/lib.rs new file mode 100644 index 000000000..643051cf9 --- /dev/null +++ b/struct_diff_patch_macros/src/lib.rs @@ -0,0 +1,286 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +use proc_macro::TokenStream; +use proc_macro2::Span; +use proc_macro2::TokenStream as TokenStream2; +use quote::format_ident; +use quote::quote; +use syn::Data; +use syn::DeriveInput; +use syn::Fields; +use syn::Index; +use syn::Type; +use syn::parse_macro_input; +use syn::parse_quote; + +/// Generates a [`struct_diff_patch::Diff`] implementation for a struct. +/// The patch type will match the one expected by the struct's [`struct_diff_patch::Patch`] implementation, +/// as derived by the [`Patch`] derive macro. 
+/// +/// For example, +/// +/// ```ignore +/// #[derive(Diff)] +/// struct MyStruct { +/// name: String, +/// count: u32, +/// } +/// ``` +/// +/// will generate the following [`struct_diff_patch::Diff`] implementation: +/// +/// ```ignore +/// impl struct_diff_patch::Diff for MyStruct +/// where +/// String: struct_diff_patch::Diff, +/// u32: struct_diff_patch::Diff, +/// { +/// type Patch = ( +/// ::Patch, +/// ::Patch, +/// ); +/// fn diff(&self, other: &Self) -> Self::Patch { +/// (self.name.diff(&other.name), self.count.diff(&other.count)) +/// } +/// } +/// ``` +#[proc_macro_derive(Diff)] +pub fn derive_diff(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + expand_diff(input) + .unwrap_or_else(|err| err.to_compile_error()) + .into() +} + +/// Derives a [`struct_diff_patch::Patch`] implementation for a struct. +/// The patch type will match the one expected by the struct's [`struct_diff_patch::Diff`] implementation, +/// as derived by the [`Diff`] derive macro. +/// +/// For example, +/// +/// ```ignore +/// #[derive(Patch)] +/// struct MyStruct { +/// name: String, +/// count: u32, +/// } +/// ``` +/// +/// will generate the following: +/// +/// ```ignore +/// impl struct_diff_patch::Patch +/// for (::Patch, ::Patch) +/// where +/// String: struct_diff_patch::Diff, +/// u32: struct_diff_patch::Diff, +/// { +/// fn apply(self, value: &mut MyStruct) { +/// let (field_patch_0, field_patch_1) = self; +/// field_patch_0.apply(&mut value.name); +/// field_patch_1.apply(&mut value.count); +/// } +/// } +/// ``` +#[proc_macro_derive(Patch)] +pub fn derive_patch(input: TokenStream) -> TokenStream { + let _ = parse_macro_input!(input as DeriveInput); + TokenStream::new() +} + +fn expand_diff(input: DeriveInput) -> syn::Result { + let DeriveInput { + ident, + generics, + data, + .. 
+ } = input; + let crate_path = crate_path(); + + let (_, ty_generics, _) = generics.split_for_impl(); + let mut impl_generics = generics.clone(); + + let struct_tokens = match data { + Data::Struct(data_struct) => data_struct, + _ => { + return Err(syn::Error::new( + Span::call_site(), + "Diff derive currently supports only structs", + )); + } + }; + + let (patch_type, diff_expr, apply_body, field_types) = + build_struct_parts(&struct_tokens.fields, &crate_path); + + if !field_types.is_empty() { + let where_clause_impl = impl_generics.make_where_clause(); + for field_ty in field_types { + where_clause_impl + .predicates + .push(parse_quote! { #field_ty: #crate_path::Diff }); + } + } + + let (impl_generics_tokens, _, where_clause_tokens) = impl_generics.split_for_impl(); + + Ok(quote! { + impl #impl_generics_tokens #crate_path::Diff for #ident #ty_generics #where_clause_tokens { + type Patch = #patch_type; + + fn diff(&self, other: &Self) -> Self::Patch { + #diff_expr + } + } + + impl #impl_generics_tokens #crate_path::Patch<#ident #ty_generics> for #patch_type #where_clause_tokens { + fn apply(self, value: &mut #ident #ty_generics) { + #apply_body + } + } + }) +} + +fn build_struct_parts( + fields: &Fields, + crate_path: &syn::Path, +) -> (TokenStream2, TokenStream2, TokenStream2, Vec) { + match fields { + Fields::Named(named) => { + let names: Vec<_> = named + .named + .iter() + .map(|field| field.ident.clone().expect("named field")) + .collect(); + let types: Vec<_> = named.named.iter().map(|field| field.ty.clone()).collect(); + + let patch_types: Vec<_> = types + .iter() + .map(|ty| quote! { <#ty as #crate_path::Diff>::Patch }) + .collect(); + + let diff_fields: Vec<_> = names + .iter() + .map(|name| quote! 
{ self.#name.diff(&other.#name) }) + .collect(); + + let binding_names: Vec<_> = names + .iter() + .enumerate() + .map(|(pos, _)| format_ident!("field_patch_{pos}")) + .collect(); + + let apply_steps = binding_names + .iter() + .zip(names.iter()) + .map(|(binding, name)| quote! { #binding.apply(&mut value.#name); }); + + let patch_type = if patch_types.len() > 0 { + quote! { ( #( #patch_types ),* , ) } + } else { + quote! { () } + }; + + let diff_expr = if diff_fields.len() > 0 { + quote! { ( #( #diff_fields ),* , ) } + } else { + quote! { () } + }; + + let apply_body = if binding_names.is_empty() { + quote! { + let _ = self; + let _ = value; + } + } else { + quote! { + let ( #( #binding_names ),* , ) = self; + #( #apply_steps )* + } + }; + + (patch_type, diff_expr, apply_body, types) + } + Fields::Unnamed(unnamed) => { + let types: Vec<_> = unnamed + .unnamed + .iter() + .map(|field| field.ty.clone()) + .collect(); + let indices: Vec<_> = (0..types.len()).map(Index::from).collect(); + let patch_fields: Vec<_> = types + .iter() + .map(|ty| quote! { <#ty as #crate_path::Diff>::Patch }) + .collect(); + + let diff_fields: Vec<_> = indices + .iter() + .map(|idx| quote! { self.#idx.diff(&other.#idx) }) + .collect(); + + let binding_names: Vec<_> = indices + .iter() + .enumerate() + .map(|(pos, _)| format_ident!("field_patch_{pos}")) + .collect(); + + let apply_steps = binding_names + .iter() + .zip(indices.iter()) + .map(|(binding, idx)| quote! { #binding.apply(&mut value.#idx); }); + + let patch_type = if patch_fields.len() > 0 { + quote! { ( #( #patch_fields ),* , ) } + } else { + quote! { () } + }; + let diff_expr = if diff_fields.len() > 0 { + quote! { ( #( #diff_fields ),* , ) } + } else { + quote! { () } + }; + + let apply_body = if binding_names.is_empty() { + quote! { + let _ = self; + let _ = value; + } + } else { + quote! 
{ + let ( #( #binding_names ),* , ) = self; + #( #apply_steps )* + } + }; + + (patch_type, diff_expr, apply_body, types) + } + Fields::Unit => { + let patch_type = quote! { () }; + let diff_expr = quote! { () }; + let apply_body = quote! { + let _ = self; + let _ = value; + }; + (patch_type, diff_expr, apply_body, Vec::new()) + } + } +} + +fn crate_path() -> syn::Path { + syn::parse_quote!(struct_diff_patch) + // TODO: get proc-macro-crate into third-party + // match crate_name("struct_diff_patch") { + // Ok(FoundCrate::Itself) => syn::parse_quote!(crate), + // Ok(FoundCrate::Name(name)) => { + // let ident = Ident::new(&name, Span::call_site()); + // syn::parse_quote!(#ident) + // } + // Err(_) => syn::parse_quote!(struct_diff_patch), + // } +}