Commit b0d32ef

Enforce limit on prefill padding tokens, delay between prefills
To limit computation wasted on padding and mitigate the impact of workloads with frequent large-input, tiny-output requests.
1 parent 83b66c5 commit b0d32ef
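
For context on what the new default limit of 0.2 means in practice, here is a small standalone Rust sketch (an illustration only, not code from this repository): with padded batching every request is padded up to the longest input in the batch, and the commit refuses to build prefill batches where more than 20% of those padded token slots would be padding.

fn main() {
    let input_lengths = [200usize, 120, 60];               // example request input lengths
    let max_len = *input_lengths.iter().max().unwrap();
    let padded_slots = max_len * input_lengths.len();      // 600 slots after padding to 200
    let real_tokens: usize = input_lengths.iter().sum();   // 380 real tokens
    let padding_share = (padded_slots - real_tokens) as f32 / padded_slots as f32;

    let max_prefill_padding = 0.2_f32; // default introduced by this commit
    println!(
        "padding share {:.2}, over the {} limit: {}",
        padding_share,
        max_prefill_padding,
        padding_share > max_prefill_padding // ~0.37 > 0.2, so this batch would not be built
    );
}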

File tree

7 files changed: +141 -78 lines

README.md

Lines changed: 39 additions & 38 deletions
Large diffs are not rendered by default.

launcher/src/main.rs

Lines changed: 4 additions & 0 deletions
@@ -49,6 +49,8 @@ struct Args {
     max_batch_weight: Option<usize>,
     #[clap(default_value = None, long, env)]
     max_prefill_weight: Option<usize>,
+    #[clap(default_value = "0.2", long, env)]
+    max_prefill_padding: f32,
     #[clap(default_value = "24", long, env)]
     max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
@@ -221,6 +223,8 @@ fn main() -> ExitCode {
             args.max_new_tokens.to_string(),
             "--max-batch-size".to_string(),
             args.max_batch_size.to_string(),
+            "--max-prefill-padding".to_string(),
+            args.max_prefill_padding.to_string(),
             "--max-waiting-tokens".to_string(),
             args.max_waiting_tokens.to_string(),
             "--port".to_string(),

router/src/batch_types.rs

Lines changed: 27 additions & 13 deletions
@@ -15,6 +15,8 @@ pub(crate) trait BatchType: Send + Sync + Clone + 'static {
     fn batch_initial_weight(stats: &Self::Stats, batch_size: usize) -> usize;
     /// Calculate prefill batch weight given prefill batch statistics
     fn prefill_weight(prefill_stats: &Self::Stats, batch_size: usize) -> usize;
+    /// Percentage of batch tokens that are padding
+    fn percent_padding(prefill_stats: &Self::Stats, batch_size: usize) -> f32;
     /// Indicate whether a hypothetical batch will exceed the combined weight limit
     fn exceeds_weight(
         tree: &BTreeSet<(usize, usize, usize)>, max_total_weight: usize, current_output_len: usize
@@ -61,16 +63,18 @@ impl BatchType for FlashBatch {
         total_in_tokens + total_out_tokens
     }

-    fn batch_initial_weight(total_tokens: &Self::Stats, _batch_size: usize) -> usize {
-        let (total_in_tokens, _) = total_tokens;
+    fn batch_initial_weight((total_in_tokens, _): &Self::Stats, _batch_size: usize) -> usize {
         *total_in_tokens
     }

-    fn prefill_weight(total_tokens: &Self::Stats, _batch_size: usize) -> usize {
-        let (total_in_tokens, _) = total_tokens;
+    fn prefill_weight((total_in_tokens, _): &Self::Stats, _batch_size: usize) -> usize {
         *total_in_tokens
     }

+    fn percent_padding(_: &Self::Stats, _batch_size: usize) -> f32 {
+        0.0
+    }
+
     fn exceeds_weight(
         tree: &BTreeSet<(usize, usize, usize)>, max_total_weight: usize, current_output_len: usize
     ) -> bool {
@@ -106,34 +110,44 @@ impl BatchType for FlashBatch {
 pub(crate) struct PaddedBatch {}

 impl BatchType for PaddedBatch {
-    /// Keep track of maximum input length, maximum output length
-    type Stats = (usize, usize);
+    /// Keep track of maximum input length, maximum output length, input token count
+    type Stats = (usize, usize, usize);

     fn update_stats(
         max_in_out_lengths: &Self::Stats, input_length: usize, output_length: usize
     ) -> Self::Stats {
-        let (max_input_length, max_output_length) = max_in_out_lengths;
-        (max(*max_input_length, input_length), max(*max_output_length, output_length))
+        let (max_input_length, max_output_length, total_in_tokens) = max_in_out_lengths;
+        (
+            max(*max_input_length, input_length),
+            max(*max_output_length, output_length),
+            total_in_tokens + input_length
+        )
     }

     fn batch_max_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
-        let (max_input_length, max_output_length) = max_in_out_lengths;
+        let (max_input_length, max_output_length, _) = max_in_out_lengths;
         let max_seq_len = max_input_length + max_output_length;
         // Memory requirement roughly proportional to batch_size * seq_len^2
         batch_size * max_seq_len.pow(2)
     }

-    fn batch_initial_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
-        let (max_input_length, _) = max_in_out_lengths;
+    fn batch_initial_weight((max_input_length, _, _): &Self::Stats, batch_size: usize) -> usize {
         batch_size * max_input_length.pow(2)
     }

-    fn prefill_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
+    fn prefill_weight((max_input_length, _, _): &Self::Stats, batch_size: usize) -> usize {
         // Empirically, prefill latency is proportional to batch_size * seq_len^(3/2)
-        let (max_input_length, _) = max_in_out_lengths;
         batch_size * max_input_length.pow(3).sqrt()
     }

+    fn percent_padding((max_input_length, _, total_in_tokens): &Self::Stats, batch_size: usize) -> f32 {
+        let total_toks = max_input_length * batch_size;
+        match total_toks {
+            0 => 0.0,
+            total_toks => (total_toks - total_in_tokens) as f32 / total_toks as f32,
+        }
+    }
+
     fn exceeds_weight(
         tree: &BTreeSet<(usize, usize, usize)>, max_total_weight: usize, current_output_len: usize
     ) -> bool {
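
The queue changes further down use this new percent_padding method when deciding whether one more queued request still fits into the next prefill batch. Below is a minimal standalone sketch of that decision, with a plain tuple standing in for PaddedBatch::Stats; update_stats and percent_padding mirror the trait methods above, but the surrounding selection loop is illustrative only, not the router's Queue code.

// Stats: (max_input_len, max_output_len, total_input_tokens)
type Stats = (usize, usize, usize);

fn update_stats((max_in, max_out, total_in): Stats, input_len: usize, output_len: usize) -> Stats {
    (max_in.max(input_len), max_out.max(output_len), total_in + input_len)
}

fn percent_padding((max_in, _, total_in): Stats, batch_size: usize) -> f32 {
    let total_slots = max_in * batch_size;
    match total_slots {
        0 => 0.0,
        slots => (slots - total_in) as f32 / slots as f32,
    }
}

fn main() {
    let max_prefill_padding = 0.2; // default from this commit
    let mut stats: Stats = (0, 0, 0);
    let mut batch_size = 0;

    // Similar-length inputs batch fine; the short 60-token request would push the
    // padding share above 20% and is skipped for this prefill.
    for input_len in [200, 180, 60] {
        let next_stats = update_stats(stats, input_len, 0);
        let padding = percent_padding(next_stats, batch_size + 1);
        if padding > max_prefill_padding {
            println!("skip request of length {input_len}: padding {padding:.2}");
        } else {
            stats = next_stats;
            batch_size += 1;
        }
    }
}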

router/src/batcher.rs

Lines changed: 31 additions & 21 deletions
@@ -6,14 +6,19 @@ use axum::http::StatusCode;
 use axum::Json;
 use std::future::Future;
 use std::mem::take;
+use std::ops::Add;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::task::{Context, Poll};
+use std::time::Duration;
 use futures::{FutureExt, pin_mut, TryFutureExt};
 use futures::future::Map;
 use nohash_hasher::IntMap;
-use text_generation_client::{ClientError, Token, ShardedClient, CachedBatch, RequestsStatus, InputTokens, GenerateError, Batch, GenerateTokenResponse};
+use text_generation_client::{
+    ClientError, Token, ShardedClient, CachedBatch, RequestsStatus,
+    InputTokens, GenerateError, Batch, GenerateTokenResponse
+};
 use thiserror::Error;
 use tokio::select;

@@ -355,11 +360,12 @@ async fn batching_task<B: BatchType>(
         }
         log_new_batch(batch.id, processor.entries());

-        let mut cached_batch = processor.prefill(
+        let (mut cached_batch, _) = processor.prefill(
             &mut client, batch, vec![], None, &mut queue,
         ).await;
         let mut waiting_tokens = 1;
         let mut batch_max_remaining_tokens = None;
+        let mut next_prefill_after = None;

         // We loop until we do not receive any cached batch from the inference server (== until
         // all requests have met their stopping criteria)
@@ -385,7 +391,8 @@ async fn batching_task<B: BatchType>(
             metrics::gauge!("tgi_batch_max_remaining_tokens", batch_max_remaining_tokens.unwrap() as f64);

             // Don't interfere with current batch if it's about to complete
-            if batch_max_remaining_tokens.unwrap() >= 2 {
+            if batch_max_remaining_tokens.unwrap() >= 2 &&
+                next_prefill_after.map_or(true, |t| Instant::now() > t) {
                 // Determine min num of requests for add-on batch based on current batch size and
                 // tokens since last prefill
                 let min_size = if batch_size <= 1 || waiting_tokens >= max_waiting_tokens {
@@ -411,7 +418,7 @@ async fn batching_task<B: BatchType>(
                     // Generate one token for this new batch to have the attention past in cache
                     let first_new_id = new_batch.requests.first()
                         .expect("Batch can't be empty here").id;
-                    let new_cached_batch = processor.prefill(
+                    let (new_cached_batch, prefill_time) = processor.prefill(
                         &mut client, new_batch, to_prune, Some(first_new_id), &mut queue
                     ).await;

@@ -424,6 +431,9 @@ async fn batching_task<B: BatchType>(
                     // Reset waiting counter and batch_remaining_tokens
                     waiting_tokens = 1;
                     batch_max_remaining_tokens = None;
+                    // Ensure we wait at least half as long as the last prefill took
+                    // before we do another prefill (unless the entire batch completes by then)
+                    next_prefill_after = Some(Instant::now().add(prefill_time / 2));
                     // Extend current batch with the new batch
                     if let Some(new_batch) = new_cached_batch {
                         let new_batch_id = new_batch.batch_id;
@@ -452,10 +462,12 @@ async fn batching_task<B: BatchType>(
                     // All batches completed or failed, fetch a new one
                     break
                 }
+            } else {
+                next_prefill_after = None;
             }
         }

-        cached_batch = processor.next_token(&mut client, batches, &mut queue).await;
+        (cached_batch, _) = processor.next_token(&mut client, batches, &mut queue).await;
         waiting_tokens += 1;
         // Reset batch_remaining_tokens if any requests in the batch completed
         if batch_max_remaining_tokens.is_some() && some_completed(&cached_batch) {
@@ -520,29 +532,24 @@ impl<'a> TokenProcessor<'a> {
         // First request id in this batch if it doesn't comprise all current entries
         start_id: Option<u64>,
         queue: &mut Queue<B>,
-    ) -> Option<CachedBatch> {
+    ) -> (Option<CachedBatch>, Duration) {
         let batch_size = batch.requests.len();
         let batch_tokens = batch.total_tokens;
         let start_time = Instant::now();
         metrics::histogram!("tgi_batch_next_tokens", batch_tokens as f64);
         metrics::histogram!(
             "tgi_batch_inference_batch_size", batch_size as f64, "method" => "prefill"
         );
-        self._wrap_future(
-            client.prefill(batch, to_prune).map(|r| {
-                info!(
-                    "Prefill took {:?} for {batch_size} inputs, {batch_tokens} total tokens",
-                    start_time.elapsed(),
-                );
-                r
-            }),
-            "prefill", start_time, start_id, queue
-        ).await
+        let (result, prefill_time) = self._wrap_future(
+            client.prefill(batch, to_prune), "prefill", start_time, start_id, queue
+        ).await;
+        info!("Prefill took {prefill_time:?} for {batch_size} inputs, {batch_tokens} total tokens");
+        (result, prefill_time)
     }

     async fn next_token<B: BatchType>(
         &mut self, client: &mut ShardedClient, batches: Vec<CachedBatch>, queue: &mut Queue<B>,
-    ) -> Option<CachedBatch> {
+    ) -> (Option<CachedBatch>, Duration) {
         metrics::histogram!(
             "tgi_batch_inference_batch_size", self.entries.len() as f64, "method" => "next_token"
         );
@@ -561,7 +568,7 @@ impl<'a> TokenProcessor<'a> {
         // First request id in this batch if it doesn't comprise all current entries
         start_id: Option<u64>,
         queue: &mut Queue<B>,
-    ) -> Option<CachedBatch> {
+    ) -> (Option<CachedBatch>, Duration) {
         metrics::increment_counter!("tgi_batch_inference_count", "method" => method);

         // We process the shared queue while waiting for the response from the python shard(s)
@@ -574,7 +581,8 @@ impl<'a> TokenProcessor<'a> {
             }
         };

-        match result {
+        let elapsed = start_time.elapsed();
+        let result = match result {
             Ok(
                 Some((generated_tokens, input_tokens, errors, next_batch_id, forward_duration))
             ) => {
@@ -587,7 +595,7 @@ impl<'a> TokenProcessor<'a> {
                 self.generation_health.store(true, Ordering::SeqCst);
                 metrics::histogram!(
                     "tgi_batch_inference_duration",
-                    start_time.elapsed().as_secs_f64(),
+                    elapsed.as_secs_f64(),
                     "method" => method,
                     "makeup" => "single_only", // later will possibly be beam_only or mixed
                 );
@@ -626,7 +634,9 @@ impl<'a> TokenProcessor<'a> {
                 self.send_errors(err, start_id);
                 None
             },
-        }
+        };
+
+        (result, elapsed)
     }

     /// Send errors to the Batcher for all `request_ids`
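
The batcher now times each prefill and declines to start another one until at least half of that duration has elapsed (unless the running batch completes first). Here is a minimal standalone sketch of that bookkeeping; do_prefill and timed_prefill are synchronous stand-ins for the real client call, not repository APIs.

use std::time::{Duration, Instant};

fn do_prefill(batch_tokens: usize) -> usize {
    // Pretend work: just echo a token count.
    batch_tokens
}

// Return both the result and how long the call took, like the modified prefill() above.
fn timed_prefill(batch_tokens: usize) -> (usize, Duration) {
    let start = Instant::now();
    let result = do_prefill(batch_tokens);
    (result, start.elapsed())
}

fn main() {
    let mut next_prefill_after: Option<Instant> = None;

    // Only prefill once any back-off from the previous prefill has expired.
    if next_prefill_after.map_or(true, |t| Instant::now() > t) {
        let (_result, prefill_time) = timed_prefill(4096);
        // Wait at least half as long as the last prefill took before the next one.
        next_prefill_after = Some(Instant::now() + prefill_time / 2);
    }

    println!("next prefill allowed after: {:?}", next_prefill_after);
}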

router/src/main.rs

Lines changed: 3 additions & 0 deletions
@@ -23,6 +23,8 @@ struct Args {
     max_batch_weight: Option<usize>,
     #[clap(default_value = None, long, env)]
     max_prefill_weight: Option<usize>,
+    #[clap(default_value = "0.2", long, env)]
+    max_prefill_padding: f32,
     #[clap(default_value = "24", long, env)]
     max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
@@ -129,6 +131,7 @@ fn main() -> Result<(), std::io::Error> {
         max_batch_size: args.max_batch_size,
         max_batch_weight: args.max_batch_weight,
         max_prefill_weight: args.max_prefill_weight,
+        max_prefill_padding: args.max_prefill_padding,
         max_waiting_tokens: args.max_waiting_tokens,
         client: sharded_client,
         tokenizer,

router/src/queue.rs

Lines changed: 29 additions & 6 deletions
@@ -105,6 +105,8 @@ pub(crate) struct BatchingConfig {
     pub(crate) weight_limit: usize,
     /// Maximum weight of individual prefill batches
     pub(crate) prefill_weight_limit: usize,
+    /// Maximum percentage of pad tokens in prefill batches. In range [0, 1]
+    pub(crate) prefill_padding_limit: f32,
 }

 /// Request Queue
@@ -249,9 +251,15 @@ impl<B: BatchType> Queue<B> {
                 let pct_space_free = 1.0 - (
                     current_batch_weight as f64 / self.config.weight_limit as f64
                 );
-                (pct_space_free * prefill_limit as f64) as usize
+                let limit = (pct_space_free * prefill_limit as f64) as usize;
+                if limit == 0 {
+                    return None
+                }
+                limit
             },
         };
+        let max_prefill_padding = self.config.prefill_padding_limit;
+
         // We first do a read-only pass over the queue to allow skipping over large entries
         // that don't fit in the current batch to reach smaller entries that do
         for (index, entry) in self.buffer.iter().enumerate() {
@@ -316,14 +324,29 @@ impl<B: BatchType> Queue<B> {
             }

             // Also check whether adding this request will breach the prefill weight limit
-            if effective_prefill_weight_limit > 0 {
+            if effective_prefill_weight_limit > 0 || max_prefill_padding < 1.0 {
                 let next_prefill_stats = <B>::update_stats(
                     &prefill_stats, input_len, 0
                 );
-                let prefill_weight = <B>::prefill_weight(
-                    &next_prefill_stats, chosen_indices.len() + 1
-                );
-                if prefill_weight > effective_prefill_weight_limit {
+                let batch_size = chosen_indices.len() + 1;
+                let mut skip = false;
+                if effective_prefill_weight_limit > 0 {
+                    let prefill_weight = <B>::prefill_weight(&next_prefill_stats, batch_size);
+                    if prefill_weight > effective_prefill_weight_limit {
+                        skip = true;
+                        metrics::increment_counter!("tgi_prefill_weight_limit_exceeded");
+                    }
+                }
+                if !skip && max_prefill_padding < 1.0 {
+                    let percentage_padding = <B>::percent_padding(&next_prefill_stats, batch_size);
+                    if percentage_padding > max_prefill_padding {
+                        skip = true;
+                        // TODO if we skip due to padding and added other requests from queue,
+                        // we could consider doing another pass since the padding proportion may have decreased
+                        metrics::increment_counter!("tgi_prefill_padding_limit_exceeded");
+                    }
+                }
+                if skip {
                     if let Some(tree) = btree.as_mut() {
                         // Remove our tuple from the set
                         tree.remove(&(output_len, input_len, tree.len() - 1));

router/src/server.rs

Lines changed: 8 additions & 0 deletions
@@ -245,6 +245,7 @@ pub struct ServerRunArgs {
     pub max_batch_size: usize,
     pub max_batch_weight: Option<usize>,
     pub max_prefill_weight: Option<usize>,
+    pub max_prefill_padding: f32,
     pub max_waiting_tokens: usize,
     pub client: ShardedClient,
     pub tokenizer: Tokenizer,
@@ -273,6 +274,7 @@ pub async fn run(mut args: ServerRunArgs) {
     if use_padding {
         do_run(args, seq2seq, eos_token_id, PaddedBatch{}).await
     } else {
+        args.max_prefill_padding = 1.0; // There's no padding so disable checking for this
         do_run(args, seq2seq, eos_token_id, FlashBatch{}).await
     }
 }
@@ -294,6 +296,11 @@ async fn do_run<B: BatchType>(
         args.max_prefill_weight,
     );

+    let max_prefill_padding = args.max_prefill_padding;
+    if max_prefill_padding < 0.0 || max_prefill_padding > 1.0 {
+        panic!("max_prefill_padding ({}) must be a percentage in the range [0.0, 1.0]", max_prefill_padding)
+    }
+
     let tokenizers = AsyncTokenizer::new(
         &args.tokenizer, args.tokenization_workers
     );
@@ -312,6 +319,7 @@ async fn do_run<B: BatchType>(
             size_limit: args.max_batch_size,
             weight_limit: max_batch_weight,
             prefill_weight_limit: max_prefill_weight,
+            prefill_padding_limit: max_prefill_padding,
         },
         args.max_waiting_tokens,
         args.max_concurrent_requests,
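
The server wiring above validates the new limit and disables it on the flash-attention path by forcing it to 1.0 (with no padding, the check can never trigger). A small standalone sketch of that convention follows; validate_prefill_padding_limit is a hypothetical helper used only for illustration, not a function in this repository.

fn validate_prefill_padding_limit(use_padding: bool, configured: f32) -> f32 {
    if !use_padding {
        // No padding tokens exist, so never skip requests on padding grounds.
        return 1.0;
    }
    if !(0.0..=1.0).contains(&configured) {
        panic!("max_prefill_padding ({configured}) must be a percentage in the range [0.0, 1.0]");
    }
    configured
}

fn main() {
    assert_eq!(validate_prefill_padding_limit(false, 0.2), 1.0); // flash path: limit disabled
    assert_eq!(validate_prefill_padding_limit(true, 0.2), 0.2);  // padded path: limit kept
    println!("ok");
}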

0 commit comments
