Skip to content

Commit 73244e9

Browse files
committed
migrate tests for counter & tdigest
1 parent 4795d13 commit 73244e9

File tree

4 files changed

+316
-0
lines changed

4 files changed

+316
-0
lines changed

optd-cost-model/Cargo.lock

Lines changed: 43 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

optd-cost-model/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,6 @@ chrono = "0.4"
1616
itertools = "0.13"
1717
lazy_static = "1.5"
1818

19+
[dev-dependencies]
20+
crossbeam = "0.8"
21+
rand = "0.8"

optd-cost-model/src/stats/counter.rs

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,128 @@ where
6969
self.counts.contains_key(key)
7070
}
7171
}
72+
73+
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    use crossbeam::thread;
    use rand::rngs::StdRng;
    use rand::seq::SliceRandom;
    use rand::SeedableRng;

    use super::Counter;

    /// Returns a fixed key -> occurrence-count table together with a
    /// deterministically shuffled stream of values realizing exactly
    /// those counts.
    fn generate_frequencies() -> (HashMap<i32, i32>, Vec<i32>) {
        let frequencies: HashMap<i32, i32> = [(0, 2), (1, 4), (2, 9), (3, 8), (4, 50), (5, 6)]
            .iter()
            .copied()
            .collect();

        // Expand every (key, count) pair into `count` copies of the key.
        let mut flattened: Vec<i32> = frequencies
            .iter()
            .flat_map(|(&key, &count)| std::iter::repeat(key).take(count as usize))
            .collect();

        // Fixed seed keeps the test reproducible run to run.
        let mut rng = StdRng::seed_from_u64(0);
        flattened.shuffle(&mut rng);

        (frequencies, flattened)
    }

    /// Aggregating a shuffled stream must yield the relative frequency of
    /// each tracked key — and only the tracked keys.
    #[test]
    fn aggregate() {
        let to_track = vec![0, 1, 2, 3];
        let mut mcv = Counter::<i32>::new(&to_track);

        let (frequencies, flattened) = generate_frequencies();
        mcv.aggregate(&flattened);

        let mcv_freq = mcv.frequencies();
        assert_eq!(mcv_freq.len(), to_track.len());

        let total = flattened.len() as f64;
        for item in &to_track {
            assert!(mcv_freq.contains_key(item));
            let expected = frequencies.get(item).map(|&count| count as f64 / total);
            assert_eq!(mcv_freq.get(item).copied(), expected);
        }
    }

    /// Aggregating on worker threads and merging the per-thread counters must
    /// agree with the ground-truth totals accumulated under a shared lock.
    #[test]
    fn merge() {
        let to_track = vec![0, 1, 2, 3];
        let n_jobs = 16;

        let total_frequencies = Arc::new(Mutex::new(HashMap::<i32, i32>::new()));
        let total_count = Arc::new(Mutex::new(0));
        let result_mcv = Arc::new(Mutex::new(Counter::<i32>::new(&to_track)));

        thread::scope(|s| {
            for _ in 0..n_jobs {
                s.spawn(|_| {
                    let mut local_mcv = Counter::<i32>::new(&to_track);
                    let (local_frequencies, flattened) = generate_frequencies();

                    // Fold this worker's ground truth into the shared totals.
                    let mut shared_freq = total_frequencies.lock().unwrap();
                    let mut shared_count = total_count.lock().unwrap();
                    for (&key, &count) in &local_frequencies {
                        *shared_freq.entry(key).or_insert(0) += count;
                        *shared_count += count;
                    }

                    local_mcv.aggregate(&flattened);

                    // Each local counter must already be self-consistent
                    // before it is merged into the shared result.
                    let local_freq = local_mcv.frequencies();
                    assert_eq!(local_freq.len(), to_track.len());
                    let total = flattened.len() as f64;
                    for item in &to_track {
                        assert!(local_freq.contains_key(item));
                        let expected =
                            local_frequencies.get(item).map(|&c| c as f64 / total);
                        assert_eq!(local_freq.get(item).copied(), expected);
                    }

                    result_mcv.lock().unwrap().merge(&local_mcv);
                });
            }
        })
        .unwrap();

        let mcv = result_mcv.lock().unwrap();
        let total_count = total_count.lock().unwrap();
        let mcv_freq = mcv.frequencies();

        assert_eq!(*total_count, mcv.total_count);
        for item in &to_track {
            assert!(mcv_freq.contains_key(item));
            let expected = total_frequencies
                .lock()
                .unwrap()
                .get(item)
                .map(|&c| c as f64 / *total_count as f64);
            assert_eq!(mcv_freq.get(item).copied(), expected);
        }
    }
}

optd-cost-model/src/stats/tdigest.rs

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,148 @@ where
248248
fn lerp(a: f64, b: f64, f: f64) -> f64 {
249249
(a * (1.0 - f)) + (b * f)
250250
}
251+
252+
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use crossbeam::thread;
    use ordered_float::OrderedFloat;
    use rand::distributions::{Distribution, Uniform, WeightedIndex};
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    use super::{IntoFloat, TDigest};

    impl IntoFloat for OrderedFloat<f64> {
        fn to_float(&self) -> f64 {
            self.0
        }
    }

    /// True iff `obtained` lies strictly between `expected - error` and
    /// `expected + error`.
    fn is_close(obtained: f64, expected: f64, error: f64) -> bool {
        obtained > expected - error && obtained < expected + error
    }

    /// Asserts that `tdigest` approximates Uniform(min, max): at evenly
    /// spaced cdf targets, both the digest's cdf and its quantile must land
    /// within `error` of the analytic values.
    fn check_tdigest_uniform(
        tdigest: &TDigest<OrderedFloat<f64>>,
        buckets: i32,
        max: f64,
        min: f64,
        error: f64,
    ) {
        for k in 0..buckets {
            let target_cdf = f64::from(k) / f64::from(buckets);
            let target_quantile = min + (max - min) * target_cdf;

            let digest_cdf = tdigest.cdf(&OrderedFloat(target_quantile));
            let digest_quantile = tdigest.quantile(target_cdf);

            assert!(is_close(digest_cdf, target_cdf, error));
            // The quantile tolerance scales with the width of the support.
            assert!(is_close(
                digest_quantile,
                target_quantile,
                (max - min) * error,
            ));
        }
    }

    /// Folding successive batches of uniform samples into one digest should
    /// reproduce the uniform distribution within tolerance.
    #[test]
    fn uniform_merge_sequential() {
        let buckets = 200;
        let error = 0.03; // 3% absolute error on each quantile; error gets worse near the median.
        let mut tdigest = TDigest::new(buckets as f64);

        let (min, max) = (-1000.0, 1000.0);
        let uniform_distr = Uniform::new(min, max);
        let mut rng = StdRng::seed_from_u64(0);

        let batch_size = 1024;
        let batch_numbers = 64;

        for _ in 0..batch_numbers {
            let batch: Vec<_> = (0..batch_size)
                .map(|_| OrderedFloat(uniform_distr.sample(&mut rng)))
                .collect();
            tdigest.merge_values(&batch);
        }

        check_tdigest_uniform(&tdigest, buckets, max, min, error);
    }

    /// Building digests on worker threads and merging them into one shared
    /// digest should still approximate the uniform distribution.
    #[test]
    fn uniform_merge_parallel() {
        let buckets = 200;
        let error = 0.03; // 3% absolute error on each quantile; error gets worse near the median.

        let (min, max) = (-1000.0, 1000.0);

        let batch_size = 65536;
        let batch_numbers = 64;

        let result_tdigest = Arc::new(Mutex::new(TDigest::new(buckets as f64)));
        thread::scope(|s| {
            for _ in 0..batch_numbers {
                s.spawn(|_| {
                    let mut local_tdigest = TDigest::new(buckets as f64);

                    // NOTE(review): every worker reuses seed 0, so all
                    // batches are identical — the merge path is exercised,
                    // but not with independent data. Confirm intent.
                    let uniform_distr = Uniform::new(min, max);
                    let mut rng = StdRng::seed_from_u64(0);

                    let batch: Vec<_> = (0..batch_size)
                        .map(|_| OrderedFloat(uniform_distr.sample(&mut rng)))
                        .collect();
                    local_tdigest.merge_values(&batch);

                    result_tdigest.lock().unwrap().merge(&local_tdigest);
                });
            }
        })
        .unwrap();

        let tdigest = result_tdigest.lock().unwrap();
        check_tdigest_uniform(&tdigest, buckets, max, min, error);
    }

    /// The digest's cdf at each support point of a discrete weighted
    /// distribution should match the cumulative weight ratio.
    #[test]
    fn weighted_merge() {
        let buckets = 200;
        let error = 0.05; // 5% absolute error on each quantile, note error is worse near the median.

        let mut tdigest = TDigest::new(buckets as f64);

        let choices = [9.0, 900.0, 990.0, 9990.0, 190000.0, 990000.0];
        let weights = [1, 2, 1, 3, 4, 5]; // Total of 16.
        let total_weight: i32 = weights.iter().sum();

        let weighted_distr = WeightedIndex::new(weights).unwrap();
        let mut rng = StdRng::seed_from_u64(0);

        let batch_size = 128;
        let batch_numbers = 16;

        for _ in 0..batch_numbers {
            let batch: Vec<_> = (0..batch_size)
                .map(|_| OrderedFloat(choices[weighted_distr.sample(&mut rng)]))
                .collect();
            tdigest.merge_values(&batch);
        }

        let mut cumulative = 0;
        for (&choice, &weight) in choices.iter().zip(weights.iter()) {
            cumulative += weight;
            let digest_cdf = tdigest.cdf(&OrderedFloat(choice));
            let true_cdf = f64::from(cumulative) / f64::from(total_weight);
            assert!(is_close(true_cdf, digest_cdf, error));
        }
    }
}

0 commit comments

Comments
 (0)