diff --git a/Cargo.toml b/Cargo.toml index e82f6e8..8e425ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,8 @@ path = "src/lib.rs" [dependencies] rand = { version = "0.9.0", features = [ "small_rng" ] } -ahash = "0.8.11" +ahash = { git = "https://github.com/alexeysofin/aHash.git", rev = "b917333a498311d193a8c2084e88dd18561b264f" } +serde = { version = "1.0.117", optional = true } thiserror = "2.0.11" [profile.release] @@ -33,7 +34,11 @@ criterion = "0.7.0" mockall = "0.13" clap = { version = "4.5.19", features = ["derive"] } memmap2 = "0.9.5" +bincode = "1.3.3" [[bench]] name = "topk_add" harness = false + +[features] +serde = ["ahash/serde", "serde/derive"] diff --git a/src/heavykeeper.rs b/src/heavykeeper.rs index e961e15..e2ba35c 100644 --- a/src/heavykeeper.rs +++ b/src/heavykeeper.rs @@ -9,15 +9,20 @@ use thiserror::Error; use crate::priority_queue::TopKQueue; use crate::hash_composition::HashComposer; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + const DECAY_LOOKUP_SIZE: usize = 1024; -#[derive(Default, Clone, Debug)] +#[derive(Default, Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] struct Bucket { fingerprint: u64, count: u64, } #[derive(Clone, PartialEq, Eq, Debug)] +#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] pub struct Node { pub item: T, pub count: u64, @@ -69,6 +74,7 @@ pub enum BuilderError { MissingField { field: String }, } +#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] pub struct TopK { top_items: usize, width: usize, @@ -78,9 +84,15 @@ pub struct TopK { buckets: Vec>, priority_queue: TopKQueue, hasher: RandomState, + #[cfg_attr(feature = "serde", serde(skip, default = "default_rng"))] random: Box, } +#[cfg(feature = "serde")] +fn default_rng() -> Box { + Box::new(SmallRng::seed_from_u64(0)) +} + pub struct Builder { k: Option, width: Option, @@ -294,6 +306,10 @@ impl TopK { nodes } + pub fn is_top(&self, key: &T) -> bool { + self.priority_queue.contains_key(key) + } + pub fn debug(&self) { println!("width: {}", self.width); println!("depth: {}", self.depth); @@ -523,6 +539,27 @@ mod tests { assert!(!topk.query(&absent), "Absent item should not be found"); } + /// Tests is_stop functionality + #[test] + fn test_is_top() { + let mut topk: TopK> = TopK::new(2, 100, 5, 0.9); + + let k1 = b"1".to_vec(); + let k2 = b"2".to_vec(); + let k3 = b"3".to_vec(); + + // Add the present item + topk.add(&k1, 10); + topk.add(&k2, 10); + topk.add(&k3, 1); + + // Verify query behavior + assert!(topk.is_top(&k1), "value is not in top list"); + assert!(topk.is_top(&k2), "value is not in top list"); + assert!(topk.is_top(&k2), "value is in top list"); + } + + /// Tests count functionality for items with varying frequencies #[test] fn test_count() { @@ -1251,5 +1288,30 @@ mod tests { assert!(topk.query(item)); assert_eq!(topk.count(item), 1); } -} + + #[test] + fn test_serialize() { + let mut topk: TopK = TopK::new(10, 100, 5, 0.9); + + for i in 0..10 { + topk.add(&format!("test{}", i), 10); + } + + let serialized = bincode::serialize(&topk).expect("failed to serialize"); + let deserialized: TopK = + bincode::deserialize(&serialized).expect("failed to deserialize"); + + assert_eq!(topk.top_items, deserialized.top_items); + assert_eq!(topk.width, deserialized.width); + assert_eq!(topk.decay, deserialized.decay); + assert_eq!(topk.decay_thresholds, deserialized.decay_thresholds); + + assert_eq!(topk.buckets, deserialized.buckets); + assert_eq!(topk.priority_queue, deserialized.priority_queue); + + // check merges work + + topk.merge(&deserialized).expect("merge failed"); + } +} \ No newline at end of file diff --git a/src/priority_queue.rs b/src/priority_queue.rs index c48170f..fa8c9d5 100644 --- a/src/priority_queue.rs +++ b/src/priority_queue.rs @@ -3,8 +3,16 @@ use std::collections::HashMap; use std::hash::Hash; use ahash::RandomState; +#[cfg(feature = "serde")] +use serde::{Serialize, Deserialize}; + /// A specialized priority queue for HeavyKeeper that maintains top-k items by count -pub(crate) struct TopKQueue { +#[cfg_attr( + feature = "serde", + derive(Deserialize, Serialize) +)] +#[derive(PartialEq, Debug)] +pub(crate) struct TopKQueue { items: HashMap, // item -> (count, heap_index) heap: Vec<(u64, usize, usize)>, // (count, sequence, item_index) item_store: Vec, // Store actual items here @@ -42,6 +50,14 @@ impl TopKQueue { self.items.get(item).map(|(count, _)| *count) } + pub(crate) fn contains_key(&self, item: &Q) -> bool + where + T: Borrow, + Q: Hash + Eq + ToOwned + ?Sized, + { + self.items.contains_key(item) + } + pub(crate) fn min_count(&self) -> u64 { // If heap is empty, return 0 // Otherwise return count from root node (index 0) @@ -197,6 +213,8 @@ mod tests { let items: Vec<_> = queue.iter().collect(); assert_eq!(items, vec![(&"b", 2), (&"a", 1)]); + + assert!(queue.contains_key(&"a")); } #[test] @@ -327,4 +345,4 @@ mod tests { items[i].1, items[i+1].1); } } -} +} \ No newline at end of file