Skip to content

Commit 34c44ff

Browse files
committed
fix: Enforce max_concurrent_requests > 0
We had allowed 0 to mean unlimited, but the value is used to size a bounded queue. A bounded queue is preferable to an unbounded one for performance reasons.
1 parent e7af119 commit 34c44ff

File tree

3 files changed

+44
-30
lines changed

3 files changed

+44
-30
lines changed

launcher/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ struct Args {
4141
quantize: Option<String>,
4242
#[clap(long, env)]
4343
num_shard: Option<usize>,
44-
#[clap(default_value = "96", long, env)]
44+
#[clap(default_value = "512", long, env)]
4545
max_concurrent_requests: usize,
4646
#[clap(default_value = None, long, env)]
4747
max_sequence_length: Option<usize>,

router/src/main.rs

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use text_generation_router::server::ServerRunArgs;
1414
#[derive(Parser, Debug)]
1515
#[clap(author, version, about, long_about = None)]
1616
struct Args {
17-
#[clap(default_value = "96", long, env)]
17+
#[clap(default_value = "512", long, env)]
1818
max_concurrent_requests: usize,
1919
#[clap(default_value = "2048", long, env)]
2020
max_sequence_length: usize,
@@ -73,17 +73,8 @@ fn main() -> Result<(), std::io::Error> {
7373
tracing_subscriber::fmt().compact().init();
7474
}
7575

76-
if args.tokenization_workers == Some(0) {
77-
panic!("tokenization_workers must be > 0");
78-
}
79-
80-
if args.tls_key_path.is_some() != args.tls_cert_path.is_some() {
81-
panic!("tls: must provide both cert and key")
82-
}
83-
84-
if args.tls_client_ca_cert_path.is_some() && args.tls_cert_path.is_none() {
85-
panic!("tls: cannot provide client ca cert without keypair")
86-
}
76+
// Validate args
77+
validate_args(&args);
8778

8879
// Instantiate tokenizer
8980
let mut tokenizer = Tokenizer::from_file(args.tokenizer_path)
@@ -158,6 +149,42 @@ fn main() -> Result<(), std::io::Error> {
158149
})
159150
}
160151

152+
/// Validates the parsed command-line arguments in one place.
///
/// Checks run in order and the function panics with a descriptive message on
/// the first violation it finds; it returns normally only when every check
/// passes.
///
/// # Panics
///
/// - `tokenization_workers` is `Some(0)`
/// - `max_concurrent_requests` is `0`
/// - only one of `tls_key_path` / `tls_cert_path` is provided
/// - `tls_client_ca_cert_path` is provided without `tls_cert_path`
/// - `max_prefill_padding` is below `0.0` or above `1.0`
/// - `max_new_tokens` is less than `1`
/// - `max_sequence_length` is less than `2`
fn validate_args(args: &Args) {
153+
// An explicit worker count of zero is rejected; `None` is allowed through.
if args.tokenization_workers == Some(0) {
154+
panic!("tokenization_workers must be > 0");
155+
}
156+
157+
// Zero is no longer accepted as "unlimited": per the commit description,
// this value is used to size a bounded queue.
if args.max_concurrent_requests == 0 {
158+
panic!("max_concurrent_requests must be > 0");
159+
}
160+
161+
// The TLS key and certificate must be supplied together or not at all.
if args.tls_key_path.is_some() != args.tls_cert_path.is_some() {
162+
panic!("tls: must provide both cert and key")
163+
}
164+
165+
// A client CA certificate is only accepted when a server cert is configured.
if args.tls_client_ca_cert_path.is_some() && args.tls_cert_path.is_none() {
166+
panic!("tls: cannot provide client ca cert without keypair")
167+
}
168+
169+
// The padding limit is treated as a fraction and must lie within [0.0, 1.0].
// NOTE(review): a NaN value would pass this check because both comparisons
// evaluate to false — confirm whether NaN can reach here and should be rejected.
if args.max_prefill_padding < 0.0 || args.max_prefill_padding > 1.0 {
170+
panic!(
171+
"max_prefill_padding ({}) must be a percentage in the range [0.0, 1.0]",
172+
args.max_prefill_padding,
173+
)
174+
}
175+
176+
// NOTE(review): the message below reads "... at least 1"; it looks like
// "must be" is missing — confirm the intended wording before changing it.
if args.max_new_tokens < 1 {
177+
panic!("max_new_tokens ({}) at least 1", args.max_new_tokens)
178+
}
179+
180+
// The minimum of 2 leaves room for one input and one output token, as the
// panic message states.
if args.max_sequence_length < 2 {
181+
panic!(
182+
"max_sequence_length ({}) must be at least 2 (1 input + 1 output)",
183+
args.max_sequence_length,
184+
)
185+
}
186+
}
187+
161188
fn write_termination_log(msg: &str) -> Result<(), io::Error> {
162189
// Writes a message to the termination log.
163190
// Creates the logfile if it doesn't exist.

router/src/server.rs

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ impl<'a, B: BatchType> BatchConfigValidator<'a, B> {
184184
fn validate_batch_config(
185185
&self,
186186
max_sequence_length: usize,
187-
max_batch_size: usize,
187+
_max_batch_size: usize,
188188
max_batch_weight: usize,
189189
) {
190190
let single_request_stats = <B>::update_stats(
@@ -284,30 +284,17 @@ async fn do_run<B: BatchType>(
284284
batch_weight_limit,
285285
);
286286

287-
let max_prefill_padding = args.max_prefill_padding;
288-
if max_prefill_padding < 0.0 || max_prefill_padding > 1.0 {
289-
panic!("max_prefill_padding ({}) must be a percentage in the range [0.0, 1.0]", max_prefill_padding)
290-
}
291-
292-
if args.max_new_tokens < 1 {
293-
panic!("max_new_tokens ({}) at least 1", args.max_new_tokens)
294-
}
295-
296-
if args.max_sequence_length < 2 {
297-
panic!("max_sequence_length ({}) must be at least 2 (1 input + 1 output)", args.max_sequence_length)
298-
}
299-
300287
let max_new_tokens = if args.max_new_tokens < args.max_sequence_length {
301288
args.max_new_tokens
302289
} else {
303-
tracing::warn!(
290+
warn!(
304291
"adjusting max_new_tokens ({}) down to max_sequence_length - 1 ({})",
305292
args.max_new_tokens,
306293
args.max_sequence_length-1
307294
);
308295
args.max_sequence_length - 1
309296
};
310-
297+
311298

312299
let tokenizers = AsyncTokenizer::new(
313300
&args.tokenizer, args.tokenization_workers
@@ -326,7 +313,7 @@ async fn do_run<B: BatchType>(
326313
BatchingConfig {
327314
size_limit: args.max_batch_size,
328315
weight_limit: batch_weight_limit,
329-
prefill_padding_limit: max_prefill_padding,
316+
prefill_padding_limit: args.max_prefill_padding,
330317
},
331318
args.max_waiting_tokens,
332319
args.max_concurrent_requests,

0 commit comments

Comments
 (0)