Skip to content

Commit 36b8d86

Browse files
prashantgupta24, joerunde, and declark1
authored and committed
♻️ Default max sequence length based on model configuration
Based on precedence order: 1. MAX_SEQUENCE_LENGTH, if specified by the user 2. A field from config.json corresponding to the max sequence length 3. A default of 2048 Signed-off-by: Prashant Gupta <[email protected]> Signed-off-by: Joe Runde <[email protected]> Co-authored-by: Joe Runde <[email protected]> Co-authored-by: Daniel Clark <[email protected]>
1 parent 6403b0c commit 36b8d86

File tree

3 files changed

+77
-8
lines changed

3 files changed

+77
-8
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

launcher/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ description = "Text Generation Launcher"
99
clap = { version = "4.4.18", features = ["derive", "env"] }
1010
ctrlc = { version = "3.4.2", features = ["termination"] }
1111
nix = "0.27.1"
12+
serde_json = "^1.0.113"
1213
tracing = "0.1.40"
1314
tracing-subscriber = { version = "0.3.18", features = ["json"] }
1415

launcher/src/main.rs

Lines changed: 75 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@ use std::thread;
1313
use std::thread::sleep;
1414
use std::time::{Duration, Instant};
1515
use std::{fs, io};
16-
use std::fs::File;
1716
use std::env::VarError;
1817
use std::ffi::OsString;
18+
use std::fs::File;
1919
use std::os::unix::process::CommandExt;
20-
use String;
2120
use tracing::{info, warn};
2221

2322
// In most cases this gives the best performance for inferencing
2423
const DEFAULT_PYTORCH_CUDA_ALLOC_CONF: &'static str = "expandable_segments:True";
24+
const DEFAULT_MAX_SEQUENCE_LENGTH: usize = 2048;
2525

2626
/// App Configuration
2727
#[derive(Parser, Debug, Clone)]
@@ -43,8 +43,8 @@ struct Args {
4343
num_shard: Option<usize>,
4444
#[clap(default_value = "96", long, env)]
4545
max_concurrent_requests: usize,
46-
#[clap(default_value = "2048", long, env)]
47-
max_sequence_length: usize,
46+
#[clap(default_value = None, long, env)]
47+
max_sequence_length: Option<usize>,
4848
#[clap(default_value = "1024", long, env)]
4949
max_new_tokens: usize,
5050
#[clap(default_value = "12", long, env)]
@@ -117,14 +117,18 @@ fn main() -> ExitCode {
117117
),
118118
_ => (),
119119
}
120-
120+
121121
// Determine number of shards based on command line arg and env vars
122122
let num_shard = find_num_shards(args.num_shard);
123-
123+
124124
// Resolve fast tokenizer path
125125
let tokenizer_path = resolve_tokenizer_path(
126126
&args.model_name, args.revision.as_deref()
127127
).expect("Could not find tokenizer for model");
128+
129+
// Determine max sequence length based on command line arg and env vars
130+
let config_json_path = tokenizer_path.replace("tokenizer.json", "config.json");
131+
let max_sequence_length = get_max_sequence_length(args.max_sequence_length, &config_json_path);
128132

129133
match env::var("MAX_BATCH_WEIGHT") {
130134
Ok(max_batch_weight) if !max_batch_weight.trim().is_empty() => {
@@ -189,7 +193,7 @@ fn main() -> ExitCode {
189193
args.deployment_framework,
190194
args.dtype.or(args.dtype_str),
191195
args.quantize,
192-
args.max_sequence_length,
196+
max_sequence_length,
193197
args.max_new_tokens,
194198
args.max_batch_size,
195199
args.batch_safety_margin,
@@ -246,7 +250,7 @@ fn main() -> ExitCode {
246250
"--max-concurrent-requests".to_string(),
247251
args.max_concurrent_requests.to_string(),
248252
"--max-sequence-length".to_string(),
249-
args.max_sequence_length.to_string(),
253+
max_sequence_length.to_string(),
250254
"--max-new-tokens".to_string(),
251255
args.max_new_tokens.to_string(),
252256
"--max-batch-size".to_string(),
@@ -364,6 +368,69 @@ fn num_cuda_devices() -> Option<usize> {
364368
let n_devices = devices.split(',').count();
365369
Some(n_devices)
366370
}
371+
/// Finds a max sequence length for the model. In priority order:
372+
/// 1. MAX_SEQUENCE_LENGTH set by user
373+
/// 2. The sequence length specified in config.json
374+
/// 3. A default of 2048
375+
fn get_max_sequence_length(max_sequence_length: Option<usize>, config_json_path: &String) -> usize {
376+
if let Some(max_sequence_length) = max_sequence_length {
377+
info!(
378+
"Using configured max_sequence_length: {}",
379+
max_sequence_length
380+
);
381+
return max_sequence_length;
382+
}
383+
if let Ok(model_config) = get_config_json(config_json_path) {
384+
if let Some(length) = get_max_sequence_length_from_config(&model_config) {
385+
info!(
386+
"Loaded max_sequence_length from model config.json: {}",
387+
length
388+
);
389+
return length;
390+
}
391+
}
392+
info!(
393+
"Using default max_sequence_length: {}",
394+
DEFAULT_MAX_SEQUENCE_LENGTH
395+
);
396+
DEFAULT_MAX_SEQUENCE_LENGTH
397+
}
398+
399+
/// Opens the model's config.json file and reads into serde_json value
400+
fn get_config_json(config_path: &String) -> Result<serde_json::Value, std::io::Error> {
401+
let reader = BufReader::new(File::open(config_path)?);
402+
Ok(serde_json::from_reader(reader)?)
403+
}
404+
405+
/// Attempts to find a max sequence length from the model's config.
406+
/// There is no standard field for this, different model architectures name it differently.
407+
/// This checks for some well-known field names, and returns nothing if none are found.
408+
/// referenced from: https://github.com/vllm-project/vllm/blob/923797fea4d80a4dac4409ece3c450b84d5ba001/vllm/config.py#L557-L592
409+
fn get_max_sequence_length_from_config(model_config: &serde_json::Value) -> Option<usize> {
410+
let possible_keys = [
411+
// OPT
412+
"max_position_embeddings",
413+
// GPT-2
414+
"n_positions",
415+
// MPT
416+
"max_seq_len",
417+
// ChatGLM2
418+
"seq_length",
419+
// Others
420+
"max_sequence_length",
421+
"max_seq_length",
422+
"seq_len",
423+
];
424+
425+
for key in possible_keys {
426+
if let Some(value) = model_config.get(key) {
427+
if let Some(value) = value.as_u64() {
428+
return Some(value as usize);
429+
}
430+
}
431+
}
432+
None
433+
}
367434

368435
fn find_num_shards(num_shard: Option<usize>) -> usize {
369436
// get the number of shards given `num_gpu` and `num_shard`

0 commit comments

Comments
 (0)