Skip to content

Commit 8ff0bf5

Browse files
authored
Make max_new_tokens optional, default to max_total_tokens - input_length (#353)
1 parent f474be2 commit 8ff0bf5

File tree

5 files changed

+50
-51
lines changed

5 files changed

+50
-51
lines changed

clients/python/lorax/client.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def generate(
6767
merged_adapters: Optional[MergedAdapters] = None,
6868
api_token: Optional[str] = None,
6969
do_sample: bool = False,
70-
max_new_tokens: int = 20,
70+
max_new_tokens: Optional[int] = None,
7171
ignore_eos_token: bool = False,
7272
best_of: Optional[int] = None,
7373
repetition_penalty: Optional[float] = None,
@@ -101,7 +101,7 @@ def generate(
101101
API token for accessing private adapters
102102
do_sample (`bool`):
103103
Activate logits sampling
104-
max_new_tokens (`int`):
104+
max_new_tokens (`Optional[int]`):
105105
Maximum number of generated tokens
106106
ignore_eos_token (`bool`):
107107
Whether to ignore EOS tokens during generation
@@ -201,7 +201,7 @@ def generate_stream(
201201
merged_adapters: Optional[MergedAdapters] = None,
202202
api_token: Optional[str] = None,
203203
do_sample: bool = False,
204-
max_new_tokens: int = 20,
204+
max_new_tokens: Optional[int] = None,
205205
ignore_eos_token: bool = False,
206206
repetition_penalty: Optional[float] = None,
207207
return_full_text: bool = False,
@@ -232,7 +232,7 @@ def generate_stream(
232232
API token for accessing private adapters
233233
do_sample (`bool`):
234234
Activate logits sampling
235-
max_new_tokens (`int`):
235+
max_new_tokens (`Optional[int]`):
236236
Maximum number of generated tokens
237237
ignore_eos_token (`bool`):
238238
Whether to ignore EOS tokens during generation
@@ -388,7 +388,7 @@ async def generate(
388388
merged_adapters: Optional[MergedAdapters] = None,
389389
api_token: Optional[str] = None,
390390
do_sample: bool = False,
391-
max_new_tokens: int = 20,
391+
max_new_tokens: Optional[int] = None,
392392
ignore_eos_token: bool = False,
393393
best_of: Optional[int] = None,
394394
repetition_penalty: Optional[float] = None,
@@ -422,7 +422,7 @@ async def generate(
422422
API token for accessing private adapters
423423
do_sample (`bool`):
424424
Activate logits sampling
425-
max_new_tokens (`int`):
425+
max_new_tokens (`Optional[int]`):
426426
Maximum number of generated tokens
427427
ignore_eos_token (`bool`):
428428
Whether to ignore EOS tokens during generation
@@ -517,7 +517,7 @@ async def generate_stream(
517517
merged_adapters: Optional[MergedAdapters] = None,
518518
api_token: Optional[str] = None,
519519
do_sample: bool = False,
520-
max_new_tokens: int = 20,
520+
max_new_tokens: Optional[int] = None,
521521
ignore_eos_token: bool = False,
522522
repetition_penalty: Optional[float] = None,
523523
return_full_text: bool = False,
@@ -550,7 +550,7 @@ async def generate_stream(
550550
API token for accessing private adapters
551551
do_sample (`bool`):
552552
Activate logits sampling
553-
max_new_tokens (`int`):
553+
max_new_tokens (`Optional[int]`):
554554
Maximum number of generated tokens
555555
ignore_eos_token (`bool`):
556556
Whether to ignore EOS tokens during generation

clients/python/lorax/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class Parameters(BaseModel):
7979
# Activate logits sampling
8080
do_sample: bool = False
8181
# Maximum number of generated tokens
82-
max_new_tokens: int = 20
82+
max_new_tokens: Optional[int] = None
8383
# Whether to ignore the EOS token during generation
8484
ignore_eos_token: bool = False
8585
# The parameter for repetition penalty. 1.0 means no penalty.

docs/reference/openapi.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -745,9 +745,9 @@
745745
"max_new_tokens": {
746746
"type": "integer",
747747
"format": "int32",
748-
"default": "20",
748+
"default": "null",
749+
"nullable": true,
749750
"minimum": 0.0,
750-
"exclusiveMaximum": 512.0,
751751
"exclusiveMinimum": 0.0
752752
},
753753
"ignore_eos_token": {

router/src/lib.rs

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,9 @@ pub(crate) struct GenerateParameters {
220220
#[serde(default)]
221221
#[schema(default = "false", example = true)]
222222
pub do_sample: bool,
223-
#[serde(default = "default_max_new_tokens")]
224-
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
225-
pub max_new_tokens: u32,
223+
#[serde(default)]
224+
#[schema(exclusive_minimum = 0, default = "null")]
225+
pub max_new_tokens: Option<u32>,
226226
#[serde(default)]
227227
#[schema(default = "false", example = true)]
228228
pub ignore_eos_token: bool,
@@ -267,10 +267,6 @@ pub(crate) struct GenerateParameters {
267267
pub response_format: Option<ResponseFormat>,
268268
}
269269

270-
fn default_max_new_tokens() -> u32 {
271-
20
272-
}
273-
274270
fn default_parameters() -> GenerateParameters {
275271
GenerateParameters {
276272
adapter_id: None,
@@ -284,7 +280,7 @@ fn default_parameters() -> GenerateParameters {
284280
top_p: None,
285281
typical_p: None,
286282
do_sample: false,
287-
max_new_tokens: default_max_new_tokens(),
283+
max_new_tokens: None,
288284
ignore_eos_token: false,
289285
return_full_text: None,
290286
stop: Vec::new(),
@@ -621,10 +617,7 @@ impl From<CompletionRequest> for CompatGenerateRequest {
621617
top_p: req.top_p,
622618
typical_p: None,
623619
do_sample: !req.n.is_none(),
624-
max_new_tokens: req
625-
.max_tokens
626-
.map(|x| x as u32)
627-
.unwrap_or(default_max_new_tokens()),
620+
max_new_tokens: req.max_tokens.map(|x| x as u32),
628621
ignore_eos_token: req.ignore_eos_token.unwrap_or(false),
629622
return_full_text: req.echo,
630623
stop: req.stop,
@@ -658,10 +651,7 @@ impl From<ChatCompletionRequest> for CompatGenerateRequest {
658651
top_p: req.top_p,
659652
typical_p: None,
660653
do_sample: !req.n.is_none(),
661-
max_new_tokens: req
662-
.max_tokens
663-
.map(|x| x as u32)
664-
.unwrap_or(default_max_new_tokens()),
654+
max_new_tokens: req.max_tokens.map(|x| x as u32),
665655
ignore_eos_token: req.ignore_eos_token.unwrap_or(false),
666656
return_full_text: None,
667657
stop: req.stop,

router/src/validation.rs

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ impl Validation {
6565
&self,
6666
inputs: String,
6767
truncate: Option<usize>,
68-
max_new_tokens: u32,
68+
max_new_tokens: Option<u32>,
6969
) -> Result<(String, usize), ValidationError> {
7070
// If we have a fast tokenizer
7171
if let Some(sender) = &self.sender {
@@ -81,16 +81,18 @@ impl Validation {
8181
// Unwrap is safe here
8282
let (inputs, input_length) = response_receiver.await.unwrap()?;
8383

84-
// Get total tokens
85-
let total_tokens = input_length + max_new_tokens as usize;
86-
87-
// Validate MaxTotalTokens
88-
if total_tokens > self.max_total_tokens {
89-
return Err(ValidationError::MaxTotalTokens(
90-
self.max_total_tokens,
91-
input_length,
92-
max_new_tokens,
93-
));
84+
if let Some(max_new_tokens) = max_new_tokens {
85+
// Get total tokens
86+
let total_tokens = input_length + max_new_tokens as usize;
87+
88+
// Validate MaxTotalTokens
89+
if total_tokens > self.max_total_tokens {
90+
return Err(ValidationError::MaxTotalTokens(
91+
self.max_total_tokens,
92+
input_length,
93+
max_new_tokens,
94+
));
95+
}
9496
}
9597

9698
// Validate InputLength
@@ -111,12 +113,13 @@ impl Validation {
111113
// We make sure that truncate + max_new_tokens <= self.max_total_tokens
112114
let input_length = truncate.unwrap_or(self.max_input_length);
113115

114-
// Validate MaxNewTokens
115-
if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
116-
return Err(ValidationError::MaxNewTokens(
117-
self.max_total_tokens - self.max_input_length,
118-
max_new_tokens,
119-
));
116+
if let Some(max_new_tokens) = max_new_tokens {
117+
if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
118+
return Err(ValidationError::MaxNewTokens(
119+
self.max_total_tokens - self.max_input_length,
120+
max_new_tokens,
121+
));
122+
}
120123
}
121124

122125
Ok((inputs, input_length))
@@ -231,7 +234,7 @@ impl Validation {
231234
})
232235
.unwrap_or(Ok(0))?;
233236

234-
if max_new_tokens == 0 {
237+
if max_new_tokens.is_some() && max_new_tokens.unwrap() == 0 {
235238
return Err(ValidationError::NegativeMaxNewTokens);
236239
}
237240

@@ -294,13 +297,19 @@ impl Validation {
294297
schema,
295298
return_k_alternatives,
296299
};
300+
301+
let effective_max_new_tokens =
302+
max_new_tokens.unwrap_or((self.max_total_tokens - input_length) as u32);
297303
let stopping_parameters = StoppingCriteriaParameters {
298-
max_new_tokens,
304+
max_new_tokens: effective_max_new_tokens,
299305
stop_sequences,
300306
ignore_eos_token,
301307
};
302308

303-
metrics::histogram!("lorax_request_max_new_tokens", max_new_tokens as f64);
309+
metrics::histogram!(
310+
"lorax_request_max_new_tokens",
311+
effective_max_new_tokens as f64
312+
);
304313

305314
Ok(ValidGenerateRequest {
306315
inputs,
@@ -461,7 +470,7 @@ mod tests {
461470
max_total_tokens,
462471
);
463472

464-
let max_new_tokens = 10;
473+
let max_new_tokens = Some(10);
465474
match validation
466475
.validate_input("Hello".to_string(), None, max_new_tokens)
467476
.await
@@ -488,7 +497,7 @@ mod tests {
488497
max_total_tokens,
489498
);
490499

491-
let max_new_tokens = 10;
500+
let max_new_tokens = Some(10);
492501
match validation
493502
.validate_input("Hello".to_string(), None, max_new_tokens)
494503
.await
@@ -588,7 +597,7 @@ mod tests {
588597
inputs: "Hello".to_string(),
589598
parameters: GenerateParameters {
590599
top_p: Some(0.99),
591-
max_new_tokens: 1,
600+
max_new_tokens: Some(1),
592601
..default_parameters()
593602
},
594603
},
@@ -614,7 +623,7 @@ mod tests {
614623
inputs: "Hello".to_string(),
615624
parameters: GenerateParameters {
616625
top_p: None,
617-
max_new_tokens: 1,
626+
max_new_tokens: Some(1),
618627
..default_parameters()
619628
},
620629
},

0 commit comments

Comments (0)