Skip to content

Commit c35f7d2

Browse files
committed
Updates.
1 parent 9dd1b3b commit c35f7d2

File tree

2 files changed

+60
-21
lines changed

2 files changed

+60
-21
lines changed

cas_client/src/adaptive_concurrency/controller.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,19 +150,35 @@ impl AdaptiveConcurrencyController {
150150
.as_secs_f64()
151151
.max(1e-4);
152152

153-
(t_actual / t_pred).min(*CONCURRENCY_CONTROL_FAILURE_DEVIANCE_PENALTY)
153+
let dev_ratio = (t_actual / t_pred).min(*CONCURRENCY_CONTROL_FAILURE_DEVIANCE_PENALTY);
154+
155+
eprintln!(
156+
"success = {is_succes}; t_pred = {t_pred}; t_actual = {t_actual}; dev_ratio = {dev_ratio}"
157+
);
158+
159+
state_lg
160+
.latency_predictor
161+
.update(n_bytes, actual_completion_time, avg_concurrency);
162+
163+
dev_ratio
154164
} else {
165+
eprintln!("failure, bytes known.");
166+
155167
// If it's not a success, then update the deviance with the penalty factor.
156168
*CONCURRENCY_CONTROL_FAILURE_DEVIANCE_PENALTY
157169
}
158170
} else {
159171
// This would be a failure case, so update the
160172
debug_assert!(!is_succes);
161173

174+
eprintln!("failure, bytes unknown.");
175+
162176
*CONCURRENCY_CONTROL_FAILURE_DEVIANCE_PENALTY
163177
}
164178
};
165179

180+
eprintln!("dev_ratio = {}; ln = {}", deviance_ratio, deviance_ratio.ln());
181+
166182
// Update the deviance with this value; we're tracking the log of the ratio due
167183
// to the additive averaging.
168184
state_lg.prediction_deviance.update(deviance_ratio.ln());
@@ -202,7 +218,7 @@ impl AdaptiveConcurrencyController {
202218
self.logging_tag,
203219
self.concurrency_semaphore.total_permits(),
204220
state_lg.latency_predictor.predicted_bandwidth(),
205-
state_lg.prediction_deviance.value()
221+
state_lg.prediction_deviance.value().exp()
206222
);
207223
}
208224
}

cas_client/src/adaptive_concurrency/latency_prediction.rs

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ pub struct LatencyPredictor {
2929
last_update: Instant,
3030
}
3131

32+
// Use MB for the scale of the size; this is more numerically stable.
33+
const BASE_SIZE_UNIT: f64 = 1_000_000.;
34+
3235
impl LatencyPredictor {
3336
pub fn new(decay_half_life: Duration) -> Self {
3437
Self {
@@ -37,7 +40,7 @@ impl LatencyPredictor {
3740
mean_y: 0.0,
3841
s_xx: 0.0,
3942
s_xy: 0.0,
40-
base_time_secs: 120.0, // 2 minutes, but no real weight on this.
43+
base_time_secs: 0.,
4144
inv_throughput: 0.0,
4245
decay_half_life_secs: decay_half_life.as_secs_f64(),
4346
last_update: Instant::now(),
@@ -59,42 +62,44 @@ impl LatencyPredictor {
5962
/// - `size_bytes`: the size of the completed transmission.
6063
/// - `duration`: the time taken to complete the transmission.
6164
/// - `n_concurrent`: the number of concurrent connections at the time.
62-
pub fn update(&mut self, size_bytes: usize, duration: Duration, avg_concurrent: f64) {
65+
pub fn update(&mut self, size_bytes: u64, duration: Duration, avg_concurrent: f64) {
6366
let now = Instant::now();
6467
let elapsed = now.duration_since(self.last_update).as_secs_f64();
6568
let decay = (-elapsed / self.decay_half_life_secs).exp2();
6669

6770
// Feature x: number of bytes transferred in this time, assuming that multiple similar
6871
// connections are active. This is just a way to treat the link as shared among them.
69-
let x = (size_bytes as f64) * avg_concurrent.max(1.);
72+
let x = (size_bytes as f64) / BASE_SIZE_UNIT;
7073

7174
// Target y: the time it would take to transfer x bytes, i.e. secs / byte.
72-
let y = duration.as_secs_f64().max(1e-6);
75+
let y = duration.as_secs_f64().max(1e-6) / avg_concurrent.max(1.);
7376

7477
// Decay previous statistics
7578
self.sum_w *= decay;
7679
self.s_xx *= decay;
7780
self.s_xy *= decay;
7881

7982
// Update means with numerically stable method
80-
let weight = 1.0;
81-
let new_sum_w = self.sum_w + weight;
83+
let obs_weight = 1.0;
84+
let new_sum_w = self.sum_w + obs_weight;
8285
let delta_x = x - self.mean_x;
8386
let delta_y = y - self.mean_y;
8487

85-
let mean_x_new = self.mean_x + (weight * delta_x) / new_sum_w;
86-
let mean_y_new = self.mean_y + (weight * delta_y) / new_sum_w;
88+
let mean_x_new = self.mean_x + (obs_weight * delta_x) / new_sum_w;
89+
let mean_y_new = self.mean_y + (obs_weight * delta_y) / new_sum_w;
8790

88-
self.s_xx += weight * delta_x * (x - mean_x_new);
89-
self.s_xy += weight * delta_x * (y - mean_y_new);
91+
self.s_xx += obs_weight * delta_x * (x - mean_x_new);
92+
self.s_xy += obs_weight * delta_x * (y - mean_y_new);
9093

9194
self.mean_x = mean_x_new;
9295
self.mean_y = mean_y_new;
9396
self.sum_w = new_sum_w;
9497

9598
if self.s_xx > 1e-8 {
96-
let slope = self.s_xy / self.s_xx;
97-
let intercept = self.mean_y - slope * self.mean_x;
99+
// Negative slopes or intercept here isn't meaningful and can cause negative predicted durations,
100+
// so clamp the slope to 0.
101+
let slope = (self.s_xy / self.s_xx).max(0.);
102+
let intercept = (self.mean_y - slope * self.mean_x).max(0.);
98103

99104
self.base_time_secs = intercept;
100105
self.inv_throughput = slope;
@@ -103,6 +108,15 @@ impl LatencyPredictor {
103108
self.inv_throughput = 0.0;
104109
}
105110

111+
eprintln!(
112+
"x = {x}; y = {y}; mean_x = {}, mean_y = {}, intercept = {}, slope = {}, bw = {}",
113+
self.mean_x,
114+
self.mean_y,
115+
self.base_time_secs,
116+
self.inv_throughput,
117+
self.predicted_bandwidth()
118+
);
119+
106120
self.last_update = now;
107121
}
108122

@@ -116,8 +130,15 @@ impl LatencyPredictor {
116130
/// - `size_bytes`: the size of the transfer.
117131
/// - `n_concurrent`: the number of concurrent connections.
118132
pub fn predicted_latency(&self, size_bytes: u64, avg_concurrent: f64) -> Duration {
119-
let predicted_secs_without_concurrency = self.base_time_secs + size_bytes as f64 * self.inv_throughput;
120-
let predicted_secs = predicted_secs_without_concurrency * avg_concurrent.max(1.);
133+
// Feature x: number of bytes transferred in this time, assuming that multiple similar
134+
// connections are active. This is just a way to treat the link as shared among them.
135+
let x = (size_bytes as f64) / BASE_SIZE_UNIT;
136+
137+
let y_pred = self.base_time_secs + self.inv_throughput * x;
138+
139+
debug_assert!(y_pred > 0.);
140+
141+
let predicted_secs = (y_pred * avg_concurrent.max(1.)).max(0.);
121142
Duration::from_secs_f64(predicted_secs)
122143
}
123144

@@ -148,12 +169,14 @@ mod tests {
148169

149170
#[test]
150171
fn test_converges_to_constant_observation() {
151-
let mut predictor = LatencyPredictor::new(Duration::from_secs_f64(10.0));
152-
for _ in 0..10 {
153-
predictor.update(1000, Duration::from_secs_f64(1.0), 1.);
172+
for concurrency in [1., 5., 100.] {
173+
let mut predictor = LatencyPredictor::new(Duration::from_secs_f64(10.0));
174+
for _ in 0..10 {
175+
predictor.update(1000, Duration::from_secs_f64(1.0), concurrency);
176+
}
177+
let prediction = predictor.predicted_latency(1000, concurrency);
178+
assert!((prediction.as_secs_f64() - 1.0).abs() < 0.01);
154179
}
155-
let prediction = predictor.predicted_latency(1000, 1.);
156-
assert!((prediction.as_secs_f64() - 1.0).abs() < 0.01);
157180
}
158181

159182
#[tokio::test]

0 commit comments

Comments
 (0)