Skip to content

Commit 009a2ba

Browse files
get_max_sequence_length() warning if user MAX_SEQUENCE_LENGTH > model MAX_SEQUENCE_LENGTH (#105)
- Modify `get_max_sequence_length()` to warn if USER-DEFINED `MAX_SEQUENCE_LENGTH` is greater than model's config value - Consolidated multiple return statements into a single return at the end of the function - Combined logging into a single `info!()` call - Introduced variable `result_max_sequence_length` to hold the final value - Added concise docstring for function documentation --------- Signed-off-by: Jefferson Fialho <[email protected]> Co-authored-by: Joe Runde <[email protected]>
1 parent c265390 commit 009a2ba

File tree

1 file changed

+25
-18
lines changed

1 file changed

+25
-18
lines changed

launcher/src/main.rs

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -491,32 +491,39 @@ fn num_cuda_devices() -> Option<usize> {
491491
let n_devices = devices.split(',').count();
492492
Some(n_devices)
493493
}
494+
494495
/// Finds a max sequence length for the model. In priority order:
495496
/// 1. MAX_SEQUENCE_LENGTH set by user
496497
/// 2. The sequence length specified in config.json
497-
/// 3. A default of 2048
498+
/// 3. A default of 2048
499+
/// ### Arguments
500+
/// * `max_sequence_length` - Optional user-defined maximum sequence length.
501+
/// * `config_path` - Path to the model configuration file.
502+
/// ### Returns
503+
/// The effective maximum sequence length to be used.
498504
fn get_max_sequence_length(max_sequence_length: Option<usize>, config_path: &Path) -> usize {
499-
if let Some(max_sequence_length) = max_sequence_length {
500-
info!(
501-
"Using configured max_sequence_length: {}",
502-
max_sequence_length
503-
);
504-
return max_sequence_length;
505-
}
505+
let mut length: Option<usize> = max_sequence_length;
506+
let mut source = "user-defined";
507+
506508
if let Ok(model_config) = get_config_json(config_path) {
507-
if let Some(length) = get_max_sequence_length_from_config(&model_config) {
508-
info!(
509-
"Loaded max_sequence_length from model config.json: {}",
510-
length
511-
);
512-
return length;
509+
if let Some(model_length) = get_max_sequence_length_from_config(&model_config) {
510+
if length.is_some_and(|length| length > model_length) {
511+
warn!("User-defined max_sequence_length ({}) is greater than the model's max_sequence_length ({})",
512+
length.unwrap(), model_length
513+
);
514+
}
515+
length.get_or_insert_with(|| {
516+
source = "model";
517+
model_length
518+
});
513519
}
514520
}
515-
info!(
516-
"Using default max_sequence_length: {}",
521+
let result = length.unwrap_or_else(|| {
522+
source = "default";
517523
DEFAULT_MAX_SEQUENCE_LENGTH
518-
);
519-
DEFAULT_MAX_SEQUENCE_LENGTH
524+
});
525+
info!("Using {} max_sequence_length: {}", source, result);
526+
return result;
520527
}
521528

522529
/// Opens the model's config.json file and reads into serde_json value

0 commit comments

Comments
 (0)