use std::time::Duration;

use tokio::time::Instant;

/// A latency predictor using a numerically stable, exponentially decayed linear regression.
///
/// We fit a model of the form:
///     duration_secs ≈ base_time_secs + size_bytes * inv_throughput
/// which is equivalent to:
///     duration_secs ≈ intercept + slope * size_bytes
///
/// Internally, we use a stable, online update method based on weighted means and covariances:
/// - mean_x, mean_y: weighted means of size and duration
/// - s_xx, s_xy: exponentially decayed sums of (x - mean_x)^2 and (x - mean_x)(y - mean_y)
///
/// We apply decay on each update using exp2(-elapsed / half_life).
///
/// This avoids numerical instability from large sums and is robust to shifting distributions.
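///
/// # Example
///
/// A minimal usage sketch (the import path is illustrative; adapt it to wherever this
/// type lives in your crate):
///
/// ```ignore
/// use mycrate::LatencyPredictor; // hypothetical path
/// use std::time::Duration;
///
/// let mut predictor = LatencyPredictor::new(Duration::from_secs(60));
///
/// // Feed observed transfers: size in bytes, duration, average concurrency.
/// predictor.update(1_000_000, Duration::from_millis(500), 1.0);
/// predictor.update(2_000_000, Duration::from_millis(900), 1.0);
///
/// // Estimate how long a 5 MB transfer would take with 2 concurrent connections.
/// let eta = predictor.predicted_latency(5_000_000, 2.0);
/// ```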
pub struct LatencyPredictor {
    sum_w: f64,
    mean_x: f64,
    mean_y: f64,
    s_xx: f64,
    s_xy: f64,

    base_time_secs: f64,
    inv_throughput: f64,
    decay_half_life_secs: f64,
    last_update: Instant,
}

impl LatencyPredictor {
    pub fn new(decay_half_life: Duration) -> Self {
        Self {
            sum_w: 0.0,
            mean_x: 0.0,
            mean_y: 0.0,
            s_xx: 0.0,
            s_xy: 0.0,
            base_time_secs: 120.0, // 2 minutes as a prior; overwritten on the first update.
            inv_throughput: 0.0,
            decay_half_life_secs: decay_half_life.as_secs_f64(),
            last_update: Instant::now(),
        }
    }

    /// Updates the latency model with a new observation.
    ///
    /// Applies exponential decay to prior statistics and incorporates the new sample
    /// using a numerically stable linear regression formula.
    ///
    /// - `size_bytes`: the size of the completed transmission, in bytes.
    /// - `duration`: the time taken to complete the transmission.
    /// - `avg_concurrent`: the average number of concurrent connections during the transfer.
    pub fn update(&mut self, size_bytes: usize, duration: Duration, avg_concurrent: f64) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_update).as_secs_f64();
        let decay = (-elapsed / self.decay_half_life_secs).exp2();

        // Feature x: the effective number of bytes contending for the link during this
        // transfer. Scaling by the average concurrency treats concurrent connections as
        // if their combined bytes were sent over a single shared link.
        let x = (size_bytes as f64) * avg_concurrent.max(1.);

        // Target y: the observed duration in seconds, clamped away from zero.
        let y = duration.as_secs_f64().max(1e-6);
        // Decay previous statistics
        self.sum_w *= decay;
        self.s_xx *= decay;
        self.s_xy *= decay;
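        // e.g. when `elapsed` equals the half-life, decay = 2^(-1) = 0.5, so every prior
        // observation loses exactly half its weight.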

        // Update means with a numerically stable method
        let weight = 1.0;
        let new_sum_w = self.sum_w + weight;
        let delta_x = x - self.mean_x;
        let delta_y = y - self.mean_y;

        let mean_x_new = self.mean_x + (weight * delta_x) / new_sum_w;
        let mean_y_new = self.mean_y + (weight * delta_y) / new_sum_w;

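        // Welford-style cross-term updates: pairing the pre-update delta (delta_x) with
        // the post-update residual (x - mean_x_new) keeps the decayed sums exact as
        // samples stream in.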
        self.s_xx += weight * delta_x * (x - mean_x_new);
        self.s_xy += weight * delta_x * (y - mean_y_new);

        self.mean_x = mean_x_new;
        self.mean_y = mean_y_new;
        self.sum_w = new_sum_w;

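        // With enough variance in x, recover the weighted least-squares fit; otherwise
        // fall back to predicting the mean observed duration.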
        if self.s_xx > 1e-8 {
            let slope = self.s_xy / self.s_xx;
            let intercept = self.mean_y - slope * self.mean_x;

            self.base_time_secs = intercept;
            self.inv_throughput = slope;
        } else {
            self.base_time_secs = self.mean_y;
            self.inv_throughput = 0.0;
        }

        self.last_update = now;
    }

    /// Predicts the expected completion time for a given transfer size and concurrency level.
    ///
    /// First predicts the latency the transfer would have on its own, then scales that
    /// estimate by the concurrency level to reflect that concurrent connections share the
    /// same throughput.
    ///
    /// For example, with a fitted base time of 0.1 s and inverse throughput of 1e-8 s/byte,
    /// a 10 MB transfer at concurrency 2 predicts (0.1 + 1e7 * 1e-8) * 2 = 0.4 s.
    ///
    /// - `size_bytes`: the size of the transfer, in bytes.
    /// - `avg_concurrent`: the average number of concurrent connections.
    pub fn predicted_latency(&self, size_bytes: u64, avg_concurrent: f64) -> Duration {
        let predicted_secs_without_concurrency =
            self.base_time_secs + size_bytes as f64 * self.inv_throughput;
        let predicted_secs = predicted_secs_without_concurrency * avg_concurrent.max(1.);
        // Clamp at zero: a fitted negative intercept could otherwise make this negative,
        // and Duration::from_secs_f64 panics on negative input.
        Duration::from_secs_f64(predicted_secs.max(0.0))
    }

    /// Estimates the currently achievable bandwidth, in bytes per second, by asking the
    /// model how long a nominal 10 MiB transfer would take on a single connection.
    pub fn predicted_bandwidth(&self) -> f64 {
        let query_bytes: u64 = 10 * 1024 * 1024;

        // How long the model predicts this would take at full (single-connection) bandwidth.
        let min_latency = self.predicted_latency(query_bytes, 1.);

        // Report bytes per second under this model.
        query_bytes as f64 / min_latency.as_secs_f64().max(1e-6)
    }
}

#[cfg(test)]
mod tests {
    use tokio::time::{self, Duration as TokioDuration};

    use super::*;

    #[test]
    fn test_estimator_update() {
        let mut estimator = LatencyPredictor::new(Duration::from_secs_f64(10.0));
        estimator.update(1_000_000, Duration::from_millis(500), 1.);
        let predicted = estimator.predicted_latency(1_000_000, 1.);
        assert!(predicted.as_secs_f64() > 0.0);
    }

    #[test]
    fn test_converges_to_constant_observation() {
        let mut predictor = LatencyPredictor::new(Duration::from_secs_f64(10.0));
        for _ in 0..10 {
            predictor.update(1000, Duration::from_secs_f64(1.0), 1.);
        }
        let prediction = predictor.predicted_latency(1000, 1.);
        assert!((prediction.as_secs_f64() - 1.0).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_decay_weighting_effect() {
        time::pause();
        let mut predictor = LatencyPredictor::new(Duration::from_secs_f64(2.0));
        predictor.update(1000, Duration::from_secs_f64(2.0), 1.);
        time::advance(TokioDuration::from_secs(2)).await;
        predictor.update(1000, Duration::from_secs_f64(1.0), 1.);
        // After one half-life the first sample's weight decays to 0.5, so the weighted
        // mean duration is (0.5 * 2.0 + 1.0 * 1.0) / 1.5 ≈ 1.33 s.
        let predicted = predictor.predicted_latency(1000, 1.).as_secs_f64();
        assert!(predicted > 1.0 && predicted < 1.6);
    }

    #[test]
    fn test_scaling_with_concurrency() {
        let mut predictor = LatencyPredictor::new(Duration::from_secs_f64(10.0));
        for _ in 0..10 {
            predictor.update(1000, Duration::from_secs_f64(1.0), 1.);
        }
        let predicted_1 = predictor.predicted_latency(1000, 1.).as_secs_f64();
        let predicted_2 = predictor.predicted_latency(1000, 2.).as_secs_f64();
        let predicted_4 = predictor.predicted_latency(1000, 4.).as_secs_f64();
        assert!(predicted_2 > predicted_1);
        assert!(predicted_4 > predicted_2);
    }
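
    #[test]
    fn test_predicted_bandwidth_matches_observations() {
        // A sanity-check sketch (not part of the original tests): after observing
        // transfers at roughly 1 MB/s, the predicted bandwidth should be of that order.
        let mut predictor = LatencyPredictor::new(Duration::from_secs_f64(10.0));
        predictor.update(1_000_000, Duration::from_secs_f64(1.0), 1.);
        predictor.update(2_000_000, Duration::from_secs_f64(2.0), 1.);
        let bw = predictor.predicted_bandwidth();
        assert!(bw > 100_000.0 && bw < 100_000_000.0);
    }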
}