Commit 1f4cfbe

feat: allow configuration of the max soft prompt length (#33)
Instead of defaulting to a hard-coded 256, the default soft prompt length is now 50% of the max sequence length. The env var MAX_PROMPT_PREFIX_LENGTH can be used to override this default if desired.

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: TRAVIS JOHNSON <[email protected]>
1 parent ac1f655 commit 1f4cfbe
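
For reference, the resulting behaviour can be summarised in a small standalone sketch. This is an illustration only, not code from the commit: the helper name resolve_max_prompt_prefix_length is made up for the example, and the real logic lives in Model.__init__ in model.py below.

import math
import os


def resolve_max_prompt_prefix_length(max_seq_length: int) -> int:
    # New default: 50% of the max sequence length (rounded up), replacing the old hard-coded 256
    max_prompt_prefix_length = math.ceil(max_seq_length * 0.5)

    # Optional override via the MAX_PROMPT_PREFIX_LENGTH environment variable
    env_value = os.getenv("MAX_PROMPT_PREFIX_LENGTH")
    if env_value:
        try:
            override = int(env_value)
        except ValueError as exc:
            raise ValueError("Invalid value for MAX_PROMPT_PREFIX_LENGTH") from exc
        # Mirrors the validation in model.py: the override may not exceed max sequence length - 1
        if override > max_seq_length - 1:
            raise ValueError(
                f"MAX_PROMPT_PREFIX_LENGTH ({override}) cannot be larger than "
                f"the max sequence length - 1 ({max_seq_length - 1})"
            )
        max_prompt_prefix_length = override

    return max_prompt_prefix_length


# With a 2048-token max sequence length and no override set, the prefix budget is 1024 tokens
print(resolve_max_prompt_prefix_length(2048))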

File tree: 4 files changed, +24 −8 lines

server/text_generation_server/models/causal_lm.py

Lines changed: 1 addition & 1 deletion
@@ -559,7 +559,7 @@ def __init__(
             model_path, AutoModelForCausalLM, dtype, quantize, model_config, max_sequence_length
         )
 
-        super(CausalLM, self).__init__(inference_engine, dtype)
+        super(CausalLM, self).__init__(inference_engine, dtype, max_sequence_length)
 
         if self.model.config.pad_token_id is not None:
             self.tokenizer.pad_token_id = self.model.config.pad_token_id

server/text_generation_server/models/flash_causal_lm.py

Lines changed: 1 addition & 1 deletion
@@ -385,7 +385,7 @@ def __init__(
             model_path, auto_model_class, dtype, quantize, model_config, max_sequence_length
         )
 
-        super(FlashCausalLM, self).__init__(inference_engine, dtype)
+        super(FlashCausalLM, self).__init__(inference_engine, dtype, max_sequence_length)
         self.use_position_ids = True
 
         if self.model.config.pad_token_id is not None:

server/text_generation_server/models/model.py

Lines changed: 21 additions & 5 deletions
@@ -1,4 +1,5 @@
 import inspect
+import math
 import os
 import types
 
@@ -19,9 +20,6 @@
 
 B = TypeVar("B", bound=Batch)
 
-# TODO make configurable, possibly based on configured max seq length
-MAX_PROMPT_PREFIX_LENGTH = 256
-
 CUDA_PAD_TO_MULT_OF_8 = os.getenv("CUDA_PAD_TO_MULT_OF_8", "true").lower() != "false"
 PT2_COMPILE = os.getenv("PT2_COMPILE", "false").lower() != "false"
 
@@ -33,7 +31,7 @@
 
 
 class Model(ABC):
-    def __init__(self, engine: BaseInferenceEngine, dtype: torch.dtype):
+    def __init__(self, engine: BaseInferenceEngine, dtype: torch.dtype, max_seq_length: Optional[int] = None):
         self.engine = engine
         self.config, self.tokenizer, self.model = engine.get_components()
         self.device = engine.get_device()
@@ -50,6 +48,24 @@ def __init__(self, engine: BaseInferenceEngine, dtype: torch.dtype):
 
         if prompt_prefix_supported:
             # Set up prefix cache
+
+            if max_seq_length is None:
+                # shouldn't be None, but just in case since the parameter is passed through as Optional
+                max_seq_length = 2048
+
+            # default value to 50% of the max sequence length
+            max_prompt_prefix_length = math.ceil(max_seq_length * 0.5)
+            if (max_prompt_prefix_env_var := os.getenv("MAX_PROMPT_PREFIX_LENGTH")):
+                try:
+                    max_prompt_prefix_env_var = int(max_prompt_prefix_env_var)
+                except ValueError as exc:
+                    raise ValueError("Invalid value for MAX_PROMPT_PREFIX_LENGTH") from exc
+
+                if max_prompt_prefix_env_var > max_seq_length - 1:
+                    raise ValueError(f"Value for the MAX_PROMPT_PREFIX_LENGTH ({max_prompt_prefix_env_var}) cannot be larger than the max sequence length - 1 ({max_seq_length - 1})")
+
+                max_prompt_prefix_length = max_prompt_prefix_env_var
+
             decoder_start_token_id = self.model.config.decoder_start_token_id
             if decoder_start_token_id is None:
                 decoder_start_token_id = self.tokenizer.bos_token_id
@@ -65,7 +81,7 @@ def __init__(self, engine: BaseInferenceEngine, dtype: torch.dtype):
             self.prefix_cache = PrefixCache(
                 device=self.device,
                 dtype=dtype,
-                max_length=MAX_PROMPT_PREFIX_LENGTH,
+                max_length=max_prompt_prefix_length,
                 encoder_decoder=self.model.config.is_encoder_decoder,
                 return_zero=return_zero,
                 decoder_start_tok_embedding=self.word_embeddings(
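
The override is read from the process environment when the model is constructed, so it only needs to be set for the serving process. A quick illustration of the resulting values, assuming a 2048-token max sequence length (the numbers are an example, not taken from the repository):

import math

max_seq_length = 2048

# MAX_PROMPT_PREFIX_LENGTH unset  -> default of ceil(2048 * 0.5) == 1024 prefix tokens
assert math.ceil(max_seq_length * 0.5) == 1024

# MAX_PROMPT_PREFIX_LENGTH=512   -> prefix cache capped at 512 tokens
# MAX_PROMPT_PREFIX_LENGTH=4096  -> rejected: 4096 > max_seq_length - 1 (2047)
# MAX_PROMPT_PREFIX_LENGTH=abc   -> rejected: "Invalid value for MAX_PROMPT_PREFIX_LENGTH"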

server/text_generation_server/models/seq2seq_lm.py

Lines changed: 1 addition & 1 deletion
@@ -557,7 +557,7 @@ def __init__(
         inference_engine = get_inference_engine_class(deployment_framework)(
             model_path, AutoModelForSeq2SeqLM, dtype, quantize, model_config, max_sequence_length
         )
-        super(Seq2SeqLM, self).__init__(inference_engine, dtype)
+        super(Seq2SeqLM, self).__init__(inference_engine, dtype, max_sequence_length)
 
         bos_token_id = self.model.config.decoder_start_token_id
         if bos_token_id is None:
