|
| 1 | +//! Aggregation Framework: generic group-by + fold utilities for reducer |
| 2 | +//! |
| 3 | +//! Goals |
| 4 | +//! - Provide a minimal, generic API to aggregate arbitrary items by a key `Range` and fold values into `T`. |
| 5 | +//! - Use closures (no traits) for projection and reduction to keep call sites simple and flexible. |
| 6 | +//! - Support both a simple one-shot API that returns a map and an advanced API that accumulates into an existing map |
| 7 | +//! with a custom initializer. |
| 8 | +//! |
| 9 | +//! Core idea |
| 10 | +//! - Inputs are arbitrary iterator items `E`. |
| 11 | +//! - A projection function maps an item to a key: `project(&E) -> Range`. |
| 12 | +//! - An add function folds each item into `T`: `add(&mut T, &E)`. |
| 13 | +//! - Aggregation iterates once, grouping by `Range` and reducing into a `HashMap<Range, T>`. |
| 14 | +//! |
| 15 | +//! Ownership & lifetimes |
| 16 | +//! - `project(&E) -> R` should produce an owned `R` (e.g., `String`, `u64`, tuples). Avoid returning references into |
| 17 | +//! ephemeral data inside `E` that won't live past the iteration. |
| 18 | +//! - If keys are large or expensive to clone, consider interning or using `Arc<str>`/`Arc<[u8]>` inside `project`. |
| 19 | +//! |
| 20 | +//! Complexity |
| 21 | +//! - Time: expected O(n) with the default hasher. |
| 22 | +//! - Memory: proportional to the number of distinct `Range` keys. |
| 23 | +//! - Iteration order of the returned map is unspecified. |
| 24 | +//! |
| 25 | +//! Example |
| 26 | +//! ``` |
| 27 | +//! use std::collections::HashMap; |
| 28 | +//! use reducer::aggregation_framework::aggregate; |
| 29 | +//! |
| 30 | +//! // Count by parity of u8 |
| 31 | +//! let data: Vec<(u8, ())> = vec![(1,()), (2,()), (4,()), (5,())]; |
| 32 | +//! let project = |(d, _s): &(u8, ())| d % 2; // key: even/odd |
| 33 | +//! let add = |t: &mut u64, _item: &(u8, ())| *t += 1; // count occurrences |
| 34 | +//! let by_parity: HashMap<u8, u64> = aggregate(data, project, add); |
| 35 | +//! assert_eq!(by_parity.get(&0), Some(&2)); // 2 and 4 |
| 36 | +//! assert_eq!(by_parity.get(&1), Some(&2)); // 1 and 5 |
| 37 | +//! ``` |
| 38 | +
|
| 39 | +use std::collections::HashMap; |
| 40 | +use std::hash::{BuildHasher, Hash}; |
| 41 | + |
| 42 | +/// Advanced aggregation: accumulate into an existing map with a custom initializer. |
| 43 | +/// |
| 44 | +/// - `dest`: destination map to accumulate into. |
| 45 | +/// - `iter`: items to aggregate. Consumed exactly once. |
| 46 | +/// - `project`: maps each item `E` to an owned key `R`. |
| 47 | +/// - `init`: constructs a fresh `T` when a new key is encountered. |
| 48 | +/// - `add`: folds each item into the corresponding `T`. |
| 49 | +pub fn aggregate_with<I, E, R, T, P, F, A, H>( |
| 50 | + dest: &mut HashMap<R, T, H>, |
| 51 | + iter: I, |
| 52 | + mut project: P, |
| 53 | + mut init: F, |
| 54 | + mut add: A, |
| 55 | +) where |
| 56 | + I: IntoIterator<Item = E>, |
| 57 | + P: FnMut(&E) -> R, |
| 58 | + F: FnMut() -> T, |
| 59 | + A: FnMut(&mut T, &E), |
| 60 | + R: Eq + Hash, |
| 61 | + H: BuildHasher, |
| 62 | +{ |
| 63 | + let it = iter.into_iter(); |
| 64 | + let (lower, _) = it.size_hint(); |
| 65 | + dest.reserve(lower); |
| 66 | + for item in it { |
| 67 | + let r = project(&item); |
| 68 | + let entry = dest.entry(r).or_insert_with(&mut init); |
| 69 | + add(entry, &item); |
| 70 | + } |
| 71 | +} |
| 72 | + |
| 73 | +/// Simple aggregation: one-shot group-by + fold that returns a fresh `HashMap`. |
| 74 | +/// |
| 75 | +/// This is sugar over `aggregate_with`, using `T: Default` for initialization and the default hasher. |
| 76 | +pub fn aggregate<I, E, R, T, P, A>(iter: I, project: P, add: A) -> HashMap<R, T> |
| 77 | +where |
| 78 | + I: IntoIterator<Item = E>, |
| 79 | + P: FnMut(&E) -> R, |
| 80 | + A: FnMut(&mut T, &E), |
| 81 | + R: Eq + Hash, |
| 82 | + T: Default, |
| 83 | +{ |
| 84 | + let inner = iter.into_iter(); |
| 85 | + let (lower, _) = inner.size_hint(); |
| 86 | + let mut map: HashMap<R, T> = HashMap::with_capacity(lower); |
| 87 | + aggregate_with(&mut map, inner, project, T::default, add); |
| 88 | + map |
| 89 | +} |
| 90 | + |
| 91 | +#[cfg(test)] |
| 92 | +mod tests { |
| 93 | + use super::*; |
| 94 | + use proptest::collection::vec; |
| 95 | + use proptest::prelude::*; |
| 96 | + |
| 97 | + #[test] |
| 98 | + fn empty_iterator_returns_empty_map() { |
| 99 | + let data: Vec<(u32, u32)> = vec![]; |
| 100 | + let m: HashMap<u32, u64> = aggregate( |
| 101 | + data, |
| 102 | + |(d, _s): &(u32, u32)| *d, |
| 103 | + |t: &mut u64, (_d, s): &(u32, u32)| *t += *s as u64, |
| 104 | + ); |
| 105 | + assert!(m.is_empty()); |
| 106 | + } |
| 107 | + |
| 108 | + #[test] |
| 109 | + fn single_key_all_elements_aggregated() { |
| 110 | + let data: Vec<(u8, u16)> = (0..10).map(|i| (i, 1u16)).collect(); |
| 111 | + let m: HashMap<u8, u64> = |
| 112 | + aggregate(data, |_item: &(u8, u16)| 0u8, |t, (_d, s)| *t += *s as u64); |
| 113 | + assert_eq!(m.len(), 1); |
| 114 | + assert_eq!(m.get(&0u8), Some(&10u64)); |
| 115 | + } |
| 116 | + |
| 117 | + #[test] |
| 118 | + fn distinct_keys_each_once() { |
| 119 | + let data: Vec<(u32, u16)> = (0..100).map(|i| (i, 1u16)).collect(); |
| 120 | + let m: HashMap<u32, u64> = aggregate( |
| 121 | + data, |
| 122 | + |(d, _s): &(u32, u16)| *d, |
| 123 | + |t, (_d, s)| *t += *s as u64, |
| 124 | + ); |
| 125 | + assert_eq!(m.len(), 100); |
| 126 | + for i in 0u32..100 { |
| 127 | + assert_eq!(m.get(&i), Some(&1u64)); |
| 128 | + } |
| 129 | + } |
| 130 | + |
| 131 | + #[test] |
| 132 | + fn aggregate_with_custom_init_non_default_type() { |
| 133 | + #[derive(Debug, PartialEq, Eq)] |
| 134 | + struct NonDefault { |
| 135 | + sum: u64, |
| 136 | + flag: bool, |
| 137 | + } |
| 138 | + |
| 139 | + let data = vec![(1u8, 5u16), (1u8, 7u16), (2u8, 10u16)]; |
| 140 | + let mut dest: HashMap<u8, NonDefault> = HashMap::new(); |
| 141 | + aggregate_with( |
| 142 | + &mut dest, |
| 143 | + data, |
| 144 | + |(d, _s): &(u8, u16)| *d, |
| 145 | + || NonDefault { sum: 1, flag: true }, // custom init |
| 146 | + |t: &mut NonDefault, (_d, s): &(u8, u16)| t.sum += *s as u64, |
| 147 | + ); |
| 148 | + |
| 149 | + assert_eq!(dest.len(), 2); |
| 150 | + assert_eq!( |
| 151 | + dest.get(&1u8), |
| 152 | + Some(&NonDefault { |
| 153 | + sum: 1 + 5 + 7, |
| 154 | + flag: true |
| 155 | + }) |
| 156 | + ); |
| 157 | + assert_eq!( |
| 158 | + dest.get(&2u8), |
| 159 | + Some(&NonDefault { |
| 160 | + sum: 1 + 10, |
| 161 | + flag: true |
| 162 | + }) |
| 163 | + ); |
| 164 | + } |
| 165 | + |
| 166 | + proptest! { |
| 167 | + #[test] |
| 168 | + fn proptest_equivalence_to_naive_fold(values in vec((any::<u8>(), any::<u16>()), 0..200), k in 1u8..=8u8) { |
| 169 | + // naive fold |
| 170 | + let mut naive: HashMap<u8, u64> = HashMap::new(); |
| 171 | + for (d, s) in &values { |
| 172 | + let key = d % k; |
| 173 | + *naive.entry(key).or_insert(0) += *s as u64; |
| 174 | + } |
| 175 | + |
| 176 | + // aggregate helper |
| 177 | + let project = |(d, _s): &(u8, u16)| d % k; |
| 178 | + let add = |t: &mut u64, (_d, s): &(u8, u16)| *t += *s as u64; |
| 179 | + let got: HashMap<u8, u64> = aggregate(values.clone(), project, add); |
| 180 | + |
| 181 | + prop_assert_eq!(got, naive); |
| 182 | + } |
| 183 | + |
| 184 | + #[test] |
| 185 | + fn proptest_chunked_aggregation_matches_one_shot(values in vec((any::<u8>(), any::<u16>()), 0..200), k in 1u8..=8u8) { |
| 186 | + let mid = values.len().saturating_div(2); |
| 187 | + let (left, right) = values.split_at(mid); |
| 188 | + |
| 189 | + let mut dest: HashMap<u8, u64> = HashMap::new(); |
| 190 | + let project = |(d, _s): &(u8, u16)| d % k; |
| 191 | + let add = |t: &mut u64, (_d, s): &(u8, u16)| *t += *s as u64; |
| 192 | + |
| 193 | + aggregate_with(&mut dest, left.to_vec(), project, || 0u64, add); |
| 194 | + aggregate_with(&mut dest, right.to_vec(), |(d, _s): &(u8, u16)| d % k, || 0u64, |t, (_d, s): &(u8, u16)| *t += *s as u64); |
| 195 | + |
| 196 | + // one-shot |
| 197 | + let one_shot: HashMap<u8, u64> = aggregate(values.clone(), |(d, _s): &(u8, u16)| d % k, |t, (_d, s): &(u8, u16)| *t += *s as u64); |
| 198 | + |
| 199 | + prop_assert_eq!(dest, one_shot); |
| 200 | + } |
| 201 | + } |
| 202 | +} |
0 commit comments