@@ -13,15 +13,15 @@ use std::thread;
 use std::thread::sleep;
 use std::time::{Duration, Instant};
 use std::{fs, io};
-use std::fs::File;
 use std::env::VarError;
 use std::ffi::OsString;
+use std::fs::File;
 use std::os::unix::process::CommandExt;
-use String;
 use tracing::{info, warn};
 
 // In most cases this gives the best performance for inferencing
 const DEFAULT_PYTORCH_CUDA_ALLOC_CONF: &'static str = "expandable_segments:True";
+const DEFAULT_MAX_SEQUENCE_LENGTH: usize = 2048;
 
 /// App Configuration
 #[derive(Parser, Debug, Clone)]
@@ -43,8 +43,8 @@ struct Args {
     num_shard: Option<usize>,
     #[clap(default_value = "96", long, env)]
     max_concurrent_requests: usize,
-    #[clap(default_value = "2048", long, env)]
-    max_sequence_length: usize,
+    #[clap(default_value = None, long, env)]
+    max_sequence_length: Option<usize>,
     #[clap(default_value = "1024", long, env)]
     max_new_tokens: usize,
     #[clap(default_value = "12", long, env)]
@@ -117,14 +117,18 @@ fn main() -> ExitCode {
         ),
         _ => (),
     }
-
+
     // Determine number of shards based on command line arg and env vars
     let num_shard = find_num_shards(args.num_shard);
-
+
     // Resolve fast tokenizer path
     let tokenizer_path = resolve_tokenizer_path(
         &args.model_name, args.revision.as_deref()
     ).expect("Could not find tokenizer for model");
+
+    // Determine max sequence length based on command line arg and env vars
+    let config_json_path = tokenizer_path.replace("tokenizer.json", "config.json");
+    let max_sequence_length = get_max_sequence_length(args.max_sequence_length, &config_json_path);
 
     match env::var("MAX_BATCH_WEIGHT") {
         Ok(max_batch_weight) if !max_batch_weight.trim().is_empty() => {
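
Note that the config.json path is derived from the resolved tokenizer path by plain string replacement. A minimal illustration (the model path is hypothetical):

    let tokenizer_path = String::from("/models/my-model/tokenizer.json");
    let config_json_path = tokenizer_path.replace("tokenizer.json", "config.json");
    assert_eq!(config_json_path, "/models/my-model/config.json");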
@@ -189,7 +193,7 @@ fn main() -> ExitCode {
         args.deployment_framework,
         args.dtype.or(args.dtype_str),
         args.quantize,
-        args.max_sequence_length,
+        max_sequence_length,
         args.max_new_tokens,
         args.max_batch_size,
         args.batch_safety_margin,
@@ -246,7 +250,7 @@ fn main() -> ExitCode {
         "--max-concurrent-requests".to_string(),
         args.max_concurrent_requests.to_string(),
         "--max-sequence-length".to_string(),
-        args.max_sequence_length.to_string(),
+        max_sequence_length.to_string(),
         "--max-new-tokens".to_string(),
         args.max_new_tokens.to_string(),
         "--max-batch-size".to_string(),
@@ -364,6 +368,69 @@ fn num_cuda_devices() -> Option<usize> {
     let n_devices = devices.split(',').count();
     Some(n_devices)
 }
+/// Finds a max sequence length for the model. In priority order:
+/// 1. MAX_SEQUENCE_LENGTH set by the user
+/// 2. The sequence length specified in the model's config.json
+/// 3. A default of 2048
+fn get_max_sequence_length(max_sequence_length: Option<usize>, config_json_path: &String) -> usize {
+    if let Some(max_sequence_length) = max_sequence_length {
+        info!(
+            "Using configured max_sequence_length: {}",
+            max_sequence_length
+        );
+        return max_sequence_length;
+    }
+    if let Ok(model_config) = get_config_json(config_json_path) {
+        if let Some(length) = get_max_sequence_length_from_config(&model_config) {
+            info!(
+                "Loaded max_sequence_length from model config.json: {}",
+                length
+            );
+            return length;
+        }
+    }
+    info!(
+        "Using default max_sequence_length: {}",
+        DEFAULT_MAX_SEQUENCE_LENGTH
+    );
+    DEFAULT_MAX_SEQUENCE_LENGTH
+}
+
+/// Opens the model's config.json file and reads it into a serde_json Value
+fn get_config_json(config_path: &String) -> Result<serde_json::Value, std::io::Error> {
+    let reader = BufReader::new(File::open(config_path)?);
+    Ok(serde_json::from_reader(reader)?)
+}
+
+/// Attempts to find a max sequence length in the model's config.
+/// There is no standard field for this; different model architectures name it differently.
+/// This checks some well-known field names, and returns nothing if none are found.
+/// Referenced from: https://github.com/vllm-project/vllm/blob/923797fea4d80a4dac4409ece3c450b84d5ba001/vllm/config.py#L557-L592
+fn get_max_sequence_length_from_config(model_config: &serde_json::Value) -> Option<usize> {
+    let possible_keys = [
+        // OPT
+        "max_position_embeddings",
+        // GPT-2
+        "n_positions",
+        // MPT
+        "max_seq_len",
+        // ChatGLM2
+        "seq_length",
+        // Others
+        "max_sequence_length",
+        "max_seq_length",
+        "seq_len",
+    ];
+
+    for key in possible_keys {
+        if let Some(value) = model_config.get(key) {
+            if let Some(value) = value.as_u64() {
+                return Some(value as usize);
+            }
+        }
+    }
+    None
+}
 
 fn find_num_shards(num_shard: Option<usize>) -> usize {
     // get the number of shards given `num_gpu` and `num_shard`
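
To make the fallback chain concrete, a minimal sketch (not part of the commit) exercising the helpers added above; the paths and config values are illustrative:

    // 1. A user-supplied value always wins, regardless of config.json:
    assert_eq!(get_max_sequence_length(Some(4096), &"ignored.json".to_string()), 4096);

    // 2. Otherwise a well-known key in config.json is used, e.g. GPT-2's n_positions:
    let config = serde_json::json!({"n_positions": 1024});
    assert_eq!(get_max_sequence_length_from_config(&config), Some(1024));

    // 3. With no override and no readable config.json, the default applies:
    assert_eq!(get_max_sequence_length(None, &"/nonexistent/config.json".to_string()), 2048);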