Skip to content

Commit 4b1d20c

Browse files
committed
[struct_diff_patch] diffing / patching for rust structs
The crate struct_diff_patch provides "diff" and "patch" operations for Rust data types. It defines two generic traits: 1) `Diff`, which diffs two values to produce a patch; and 2) `Patch`, which is implemented by patches so that they can be applied to values. Additionally, the crate provides a "watch", which lets subscribers track changes to a value as a snapshot of the value plus a stream of patches. The intention is to use this in the monarch resource model to observe resources efficiently, without having to implement special operations for each access pattern. In the future, we can also implement patch _merge_ operations in order to allow patches to be used in accumulations as well. Differential Revision: [D87822791](https://our.internmc.facebook.com/intern/diff/D87822791/) ghstack-source-id: 325471593 Pull Request resolved: #1989
1 parent c7dae4c commit 4b1d20c

File tree

6 files changed

+833
-0
lines changed

6 files changed

+833
-0
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ members = [
2424
"preempt_rwlock",
2525
"rdmaxcel-sys",
2626
"serde_multipart",
27+
"struct_diff_patch",
28+
"struct_diff_patch_macros",
2729
"timed_test",
2830
"torch-sys",
2931
"torch-sys-cuda",

struct_diff_patch/Cargo.toml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# @generated by autocargo from //monarch/struct_diff_patch:struct_diff_patch
2+
3+
[package]
4+
name = "struct_diff_patch"
5+
version = "0.0.0"
6+
authors = ["Facebook <[email protected]>"]
7+
edition = "2021"
8+
description = "diff/patch for Rust structs"
9+
repository = "https://github.com/meta-pytorch/monarch/"
10+
license = "BSD-3-Clause"
11+
12+
[lib]
13+
edition = "2024"
14+
15+
[dependencies]
16+
paste = "1.0.14"
17+
struct_diff_patch_macros = { version = "0.0.0", path = "../struct_diff_patch_macros" }
18+
tokio = { version = "1.47.1", features = ["full", "test-util", "tracing"] }

struct_diff_patch/src/lib.rs

Lines changed: 348 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,348 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
//! This crate defines traits for diffing and patching Rust structs,
10+
//! implements these traits for common types, and provides macros for
11+
//! deriving them on structs.
12+
13+
pub mod watch;
14+
15+
use std::collections::HashMap;
16+
use std::collections::hash_map;
17+
use std::hash::Hash;
18+
19+
pub use struct_diff_patch_macros::Diff;
20+
pub use struct_diff_patch_macros::Patch;
21+
22+
/// Represents a patch that operates on values of type `T`.
pub trait Patch<T> {
    /// Apply this patch to the provided value, consuming the patch.
    ///
    /// The value is modified in place.
    fn apply(self, value: &mut T);
}
27+
28+
/// Implements the "diff" operation, which produces a patch given
29+
/// two instances of the same type.
30+
pub trait Diff: Sized {
31+
/// The type of patch produced by this diff operation.
32+
type Patch: Patch<Self>;
33+
34+
/// Implements the "diff" operation, which produces a patch given
35+
/// two instances of the same type. Specifically, when the returned
36+
/// patch is applied to the original value, it should produce the
37+
/// second value.
38+
fn diff(&self, other: &Self) -> Self::Patch;
39+
}
40+
41+
impl<T> Patch<T> for Option<T> {
42+
fn apply(self, value: &mut T) {
43+
if let Some(new) = self {
44+
*value = new;
45+
}
46+
}
47+
}
48+
49+
impl<P, T> Patch<Vec<T>> for Vec<P>
50+
where
51+
P: Patch<T>,
52+
T: Default,
53+
{
54+
fn apply(self, value: &mut Vec<T>) {
55+
value.truncate(self.len());
56+
for (idx, patch) in self.into_iter().enumerate() {
57+
if idx < value.len() {
58+
patch.apply(&mut value[idx]);
59+
} else {
60+
value.push(T::default());
61+
patch.apply(&mut value[idx]);
62+
}
63+
}
64+
}
65+
}
66+
67+
impl<T: Diff + Clone + Default> Diff for Vec<T>
68+
where
69+
T::Patch: From<T>,
70+
{
71+
type Patch = Vec<T::Patch>;
72+
73+
fn diff(&self, other: &Self) -> Self::Patch {
74+
// Don't try to be clever here (e.g., using some kind of edit algorithm);
75+
// rather optimize for in-place edits, or just replace.
76+
//
77+
// Possibly we should also include prepend/append operations.
78+
let mut patch = Vec::with_capacity(other.len());
79+
for (idx, value) in other.iter().enumerate() {
80+
if idx < self.len() {
81+
patch.push(self[idx].diff(value));
82+
} else {
83+
patch.push(value.clone().into());
84+
}
85+
}
86+
patch
87+
}
88+
}
89+
90+
/// Vector of key edits. `Some(patch)` patches the value stored under the
/// key; `None` denotes a key to be removed.
pub type HashMapPatch<K, P> = Vec<(K, Option<P>)>;
92+
93+
impl<K, V, P> Patch<HashMap<K, V>> for HashMapPatch<K, P>
94+
where
95+
K: Eq + Hash,
96+
V: Default,
97+
P: Patch<V>,
98+
{
99+
fn apply(self, value: &mut HashMap<K, V>) {
100+
for (key, patch) in self {
101+
match patch {
102+
Some(patch) => match value.entry(key) {
103+
hash_map::Entry::Occupied(mut entry) => {
104+
patch.apply(entry.get_mut());
105+
}
106+
hash_map::Entry::Vacant(entry) => {
107+
let mut v = V::default();
108+
patch.apply(&mut v);
109+
entry.insert(v);
110+
}
111+
},
112+
None => {
113+
value.remove(&key);
114+
}
115+
}
116+
}
117+
}
118+
}
119+
120+
impl<K, V> Diff for HashMap<K, V>
121+
where
122+
K: Eq + Hash + Clone,
123+
V: Diff + Clone + Default,
124+
V::Patch: From<V>,
125+
{
126+
type Patch = HashMapPatch<K, V::Patch>;
127+
128+
fn diff(&self, other: &Self) -> Self::Patch {
129+
let mut changes = Vec::new();
130+
131+
for (key, new_value) in other.iter() {
132+
match self.get(key) {
133+
Some(value) => {
134+
changes.push((key.clone(), Some(value.diff(new_value))));
135+
}
136+
None => {
137+
changes.push((key.clone(), Some(new_value.clone().into())));
138+
}
139+
}
140+
}
141+
142+
for key in self.keys() {
143+
if !other.contains_key(key) {
144+
changes.push((key.clone(), None));
145+
}
146+
}
147+
148+
changes
149+
}
150+
}
151+
152+
/// Implements `Diff` for each listed type using whole-value replacement.
///
/// The generated patch type is `Option<$ty>`: `None` when the two values
/// compare equal, otherwise `Some` holding a clone of the new value. Each
/// type must therefore support `==` and `clone()`.
#[macro_export]
macro_rules! impl_simple_diff {
    ($($ty:ty),+ $(,)?) => {
        $(
            impl $crate::Diff for $ty {
                type Patch = Option<$ty>;

                fn diff(&self, other: &Self) -> Self::Patch {
                    (self != other).then(|| other.clone())
                }
            }
        )+
    };
}
170+
171+
// Scalars and strings are diffed by equality and patched by replacement.
// Floats are compared with `==`, so a NaN never compares equal and always
// yields a replacement patch.
impl_simple_diff!(
    (),
    bool,
    char,
    i8,
    i16,
    i32,
    i64,
    i128,
    isize,
    u8,
    u16,
    u32,
    u64,
    u128,
    usize,
    f32,
    f64,
    String
);
191+
192+
/// Implements `Patch` and `Diff` for the tuple type whose fields are the
/// given indices; e.g. `impl_tuple_diff_patch!(0, 1)` covers 2-tuples.
///
/// Tuples are handled field-wise: the patch for a tuple is a tuple of the
/// patches of its fields, each applied to the field at the same index.
///
/// The expansion refers to `::paste::paste!`, so the `paste` crate must be
/// resolvable from the crate that invokes this macro.
#[macro_export]
macro_rules! impl_tuple_diff_patch {
    ($($idx:tt),+ $(,)?) => {
        ::paste::paste! {
            impl<$( [<P$idx>], [<V$idx>] ),+> $crate::Patch<($( [<V$idx>], )+)> for ($( [<P$idx>], )+)
            where
                $( [<P$idx>]: $crate::Patch<[<V$idx>]>, )+
            {
                fn apply(self, value: &mut ($( [<V$idx>], )+)) {
                    // Split the patch tuple into one binding per field.
                    let ($( [<patch_$idx>], )+) = self;
                    $(
                        [<patch_$idx>].apply(&mut value.$idx);
                    )+
                }
            }

            impl<$( [<T$idx>]: $crate::Diff ),+> $crate::Diff for ($( [<T$idx>], )+) {
                type Patch = ($( <[<T$idx>] as $crate::Diff>::Patch, )+);

                fn diff(&self, other: &Self) -> Self::Patch {
                    (
                        $( self.$idx.diff(&other.$idx), )+
                    )
                }
            }
        }
    };
}
221+
222+
impl_tuple_diff_patch!(0);
223+
impl_tuple_diff_patch!(0, 1);
224+
impl_tuple_diff_patch!(0, 1, 2);
225+
impl_tuple_diff_patch!(0, 1, 2, 3);
226+
impl_tuple_diff_patch!(0, 1, 2, 3, 4);
227+
impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5);
228+
impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5, 6);
229+
impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5, 6, 7);
230+
impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5, 6, 7, 8);
231+
impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
232+
impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
233+
impl_tuple_diff_patch!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
234+
235+
#[cfg(test)]
mod tests {
    use super::*;
    use crate as struct_diff_patch; // for macros

    #[derive(Debug, Clone, PartialEq, Diff, Patch)]
    struct DerivedStruct {
        name: String,
        count: u32,
    }

    #[derive(Debug, Clone, PartialEq, Diff, Patch)]
    struct DerivedTuple(String, bool);

    #[derive(Debug, Clone, PartialEq, Diff, Patch)]
    struct DerivedUnit;

    #[test]
    fn bool_diff_is_none_when_equal() {
        assert_eq!(false.diff(&false), None);
    }

    #[test]
    fn string_diff_and_apply_replace_value() {
        let patch = String::from("bar").diff(&String::from("baz"));
        let mut value = String::from("bar");
        patch.apply(&mut value);
        assert_eq!(value, "baz");
    }

    #[test]
    fn tuple_diff_tracks_each_field() {
        let original = (false, String::from("foo"));
        let target = (true, String::from("bar"));
        let patch = original.diff(&target);

        let mut working = original;
        patch.apply(&mut working);
        assert_eq!(working, target);
    }

    #[test]
    fn vec_patch() {
        let mut orig = vec![1, 2, 3, 4, 5];
        let target = vec![1, 20, 3, 40, 5];

        let patch = orig.diff(&target);
        assert_eq!(patch, vec![None, Some(20), None, Some(40), None]);

        patch.apply(&mut orig);
        assert_eq!(orig, target);
    }

    #[test]
    fn vec_patch_grows_and_shrinks() {
        // Target longer than the original: trailing elements are appended.
        let mut short: Vec<i32> = vec![1];
        let long: Vec<i32> = vec![1, 2, 3];
        let grow = short.diff(&long);
        assert_eq!(grow, vec![None, Some(2), Some(3)]);
        grow.apply(&mut short);
        assert_eq!(short, long);

        // Target shorter than the original: surplus elements are dropped.
        let mut original: Vec<i32> = vec![1, 2, 3];
        let target: Vec<i32> = vec![9];
        let shrink = original.diff(&target);
        assert_eq!(shrink, vec![Some(9)]);
        shrink.apply(&mut original);
        assert_eq!(original, target);
    }

    #[test]
    fn hashmap_diff_patch_handles_insert_update_and_remove() {
        use std::collections::HashMap;

        let mut original = HashMap::new();
        original.insert("keep".to_string(), 1_u32);
        original.insert("remove".to_string(), 2_u32);

        let mut target = HashMap::new();
        target.insert("keep".to_string(), 10);
        target.insert("insert".to_string(), 3);

        let patch = original.diff(&target);

        let mut saw_insert = false;
        let mut saw_update = false;
        let mut saw_remove = false;

        for (key, change) in patch.iter() {
            match (key.as_str(), change) {
                ("insert", Some(Some(3))) => saw_insert = true,
                ("keep", Some(Some(10))) => saw_update = true,
                ("remove", None) => saw_remove = true,
                _ => {}
            }
        }

        assert!(saw_insert);
        assert!(saw_update);
        assert!(saw_remove);

        let mut working = original;
        patch.apply(&mut working);
        assert_eq!(working, target);
    }

    #[test]
    fn derive_macro_generates_struct_and_patch_impls() {
        let mut original = DerivedStruct {
            name: "foo".into(),
            count: 1,
        };
        let target = DerivedStruct {
            name: "bar".into(),
            count: 2,
        };

        let patch = original.diff(&target);
        patch.apply(&mut original);
        assert_eq!(original, target);

        let tuple_patch =
            DerivedTuple("foo".into(), true).diff(&DerivedTuple("baz".into(), false));
        let mut tuple_value = DerivedTuple("foo".into(), true);
        tuple_patch.apply(&mut tuple_value);
        assert_eq!(tuple_value, DerivedTuple("baz".into(), false));

        let mut unit = DerivedUnit;
        let unit_patch = unit.diff(&DerivedUnit);
        unit_patch.apply(&mut unit);
        assert_eq!(unit, DerivedUnit);
    }
}

0 commit comments

Comments
 (0)