
Commit 8bd9ac3

Add max_prompt_tokens (#202)
1 parent 9731574 commit 8bd9ac3

File tree

4 files changed (+76, -26 lines)


tests/common/vllm_test.py

Lines changed: 26 additions & 7 deletions
@@ -101,13 +101,23 @@ def print_debug(*args):


 @parameterized_class(
-    ("tensor_parallel_size", "engine_num", "use_v1", "repeat_times", "enable_history", "use_async"),
+    (
+        "tensor_parallel_size",
+        "engine_num",
+        "use_v1",
+        "repeat_times",
+        "enable_history",
+        "use_async",
+        "max_model_len",
+    ),
     [
-        (1, 2, False, 2, True, False),
-        (2, 2, False, 1, False, True),
-        (2, 2, True, 2, True, False),
-        (1, 2, True, 1, False, True),
-        (2, 1, True, 3, True, True),
+        (1, 2, False, 2, True, False, None),
+        (1, 2, False, 2, True, True, 20),
+        (1, 2, False, 2, True, False, 20),
+        (2, 2, False, 1, False, True, None),
+        (2, 2, True, 2, True, False, None),
+        (1, 2, True, 1, False, True, None),
+        (2, 1, True, 3, True, True, None),
     ],
 )
 class ModelWrapperTest(RayUnittestBaseAysnc):

@@ -116,13 +126,17 @@ def setUp(self):
         self.config = get_template_config()
         self.config.mode = "explore"
         self.config.model.model_path = get_model_path()
+        self.config.model.max_model_len = self.max_model_len
         self.config.explorer.rollout_model.engine_num = self.engine_num
         self.config.explorer.rollout_model.tensor_parallel_size = self.tensor_parallel_size
         self.config.explorer.rollout_model.use_v1 = self.use_v1
         self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
         self.config.algorithm.repeat_times = self.repeat_times
         self.config.explorer.rollout_model.enable_history = self.enable_history
         self.config.check_and_update()
+        from pprint import pprint
+
+        pprint(self.config)
         self.engines, self.auxiliary_engines = create_inference_models(self.config)
         self.model_wrapper = ModelWrapper(
             self.engines[0], model_type="vllm_async", enable_history=self.enable_history

@@ -191,7 +205,12 @@ async def test_generate(
                 "content": results[0].response_text,
             }
         )
-        exp = self.model_wrapper.convert_messages_to_experience(messages)
+        if self.max_model_len is not None:
+            with self.assertRaises(ValueError):
+                exp = self.model_wrapper.convert_messages_to_experience(messages)
+            return
+        else:
+            exp = self.model_wrapper.convert_messages_to_experience(messages)
         tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path)
         result_dict = tokenizer.apply_chat_template(
             messages,
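
Note on the test matrix: `parameterized_class` binds each element of the header tuple to a class attribute of the same name, which is how the new `max_model_len` value reaches `self.config.model.max_model_len` in `setUp` and drives the expected `ValueError` in `test_generate`. A minimal standalone sketch of that binding (the class name and values below are illustrative, not the repository's test):

import unittest

from parameterized import parameterized_class


@parameterized_class(
    ("engine_num", "max_model_len"),
    [
        (2, None),  # no explicit context-length limit
        (2, 20),    # deliberately tiny limit, as in the new cases above
    ],
)
class ExampleTest(unittest.TestCase):
    def test_attribute_binding(self):
        # Each element of the parameter tuple becomes a class attribute.
        self.assertIn(self.max_model_len, (None, 20))


if __name__ == "__main__":
    unittest.main()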

trinity/common/config.py

Lines changed: 33 additions & 16 deletions
@@ -10,6 +10,7 @@

 from trinity.common.constants import (
     EXPLORER_NAME,
+    MAX_MODEL_LEN,
     TRAINER_NAME,
     OpType,
     PromptType,

@@ -178,7 +179,7 @@ class ModelConfig:
     model_path: str = ""
     critic_model_path: str = ""
     max_model_len: Optional[int] = None
-    max_prompt_tokens: Optional[int] = None  # deprecated
+    max_prompt_tokens: Optional[int] = None
     max_response_tokens: Optional[int] = None
     custom_chat_template: Optional[str] = None

@@ -203,7 +204,7 @@ class InferenceModelConfig:
     # if not set, use `model.max_model_len`
     max_model_len: Optional[int] = None
     # if not set, use `model.max_prompt_tokens`
-    max_prompt_tokens: Optional[int] = None  # deprecated
+    max_prompt_tokens: Optional[int] = None
     # if not set, use `model.max_response_tokens`
     max_response_tokens: Optional[int] = None

@@ -775,24 +776,40 @@ def check_and_update(self) -> None:  # noqa: C901
             self.model.critic_model_path = self.model.model_path

         # check explorer
+        if self.model.max_model_len is None:
+            from transformers import AutoConfig, AutoTokenizer
+            from transformers.tokenization_utils_base import LARGE_INTEGER
+
+            tokenizer = AutoTokenizer.from_pretrained(self.model.model_path)
+            config = AutoConfig.from_pretrained(self.model.model_path)
+            max_model_len = min(
+                getattr(tokenizer, "model_max_length", LARGE_INTEGER),
+                getattr(config, "max_position_embeddings", LARGE_INTEGER),
+            )
+            if max_model_len >= LARGE_INTEGER:
+                max_model_len = MAX_MODEL_LEN
+                logger.warning(
+                    f"Failed to get `max_model_len` from model {self.model.model_path}, use {MAX_MODEL_LEN} instead."
+                )
+            self.model.max_model_len = max_model_len
+        if (
+            self.model.max_prompt_tokens is None
+            or self.model.max_prompt_tokens >= self.model.max_model_len
+        ):
+            self.model.max_prompt_tokens = self.model.max_model_len - 1
+            logger.warning(f"`max_prompt_tokens` is set to {self.model.max_prompt_tokens}.")
+        if (
+            self.model.max_response_tokens is None
+            or self.model.max_response_tokens > self.model.max_model_len
+        ):
+            self.model.max_response_tokens = self.model.max_model_len
+            logger.warning(f"`max_response_tokens` is set to {self.model.max_response_tokens}.")
+        if self.explorer.rollout_model.max_model_len is None:
+            self.explorer.rollout_model.max_model_len = self.model.max_model_len
         if self.explorer.rollout_model.max_prompt_tokens is None:
             self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens
         if self.explorer.rollout_model.max_response_tokens is None:
             self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens
-        if self.explorer.rollout_model.max_model_len is None:
-            self.explorer.rollout_model.max_model_len = self.model.max_model_len
-        if (
-            self.explorer.rollout_model.max_model_len is None
-            and self.explorer.rollout_model.max_prompt_tokens is not None
-            and self.explorer.rollout_model.max_response_tokens is not None
-        ):
-            logger.warning(
-                "`max_prompt_tokens` is deprecated, please set `max_model_len` directly."
-            )
-            self.explorer.rollout_model.max_model_len = (
-                self.explorer.rollout_model.max_prompt_tokens
-                + self.explorer.rollout_model.max_response_tokens
-            )

         # check synchronizer
         self.synchronizer.ray_namespace = self.ray_namespace
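
In plain terms, `check_and_update` now resolves the three limits in a fixed order: `max_model_len` is read from the tokenizer and model config when unset (falling back to the `MAX_MODEL_LEN` constant if neither exposes a usable value), `max_prompt_tokens` is clamped to `max_model_len - 1` so at least one token can be generated, `max_response_tokens` is clamped to `max_model_len`, and the rollout model then inherits whatever is still unset. A minimal standalone sketch of that order, assuming a Hugging Face model path (`resolve_lengths` is an illustrative name, not part of the repository):

from transformers import AutoConfig, AutoTokenizer
from transformers.tokenization_utils_base import LARGE_INTEGER

MAX_MODEL_LEN = 4096  # fallback, mirroring trinity/common/constants.py


def resolve_lengths(model_path, max_model_len=None, max_prompt_tokens=None, max_response_tokens=None):
    if max_model_len is None:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        config = AutoConfig.from_pretrained(model_path)
        # Take the tighter of the tokenizer and config limits; LARGE_INTEGER
        # signals that no usable value was found.
        max_model_len = min(
            getattr(tokenizer, "model_max_length", LARGE_INTEGER),
            getattr(config, "max_position_embeddings", LARGE_INTEGER),
        )
        if max_model_len >= LARGE_INTEGER:
            max_model_len = MAX_MODEL_LEN
    # The prompt must leave room for at least one generated token.
    if max_prompt_tokens is None or max_prompt_tokens >= max_model_len:
        max_prompt_tokens = max_model_len - 1
    # The response may use up to the full context window.
    if max_response_tokens is None or max_response_tokens > max_model_len:
        max_response_tokens = max_model_len
    return max_model_len, max_prompt_tokens, max_response_tokens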

trinity/common/constants.py

Lines changed: 5 additions & 0 deletions
@@ -16,6 +16,11 @@
 PLUGIN_DIRS_ENV_VAR = "TRINITY_PLUGIN_DIRS"


+# constants
+
+MAX_MODEL_LEN = 4096
+
+
 # enumerate types



trinity/common/models/vllm_model.py

Lines changed: 12 additions & 3 deletions
@@ -53,7 +53,7 @@ def __init__(
             temperature=0.0,
             max_tokens=config.max_response_tokens,
             min_tokens=1,
-            truncate_prompt_tokens=config.max_model_len - 1,  # type: ignore [operator]
+            truncate_prompt_tokens=config.max_prompt_tokens,
             skip_special_tokens=True,
             include_stop_str_in_output=False,
             output_kind=RequestOutputKind.FINAL_ONLY,

@@ -100,6 +100,10 @@ def __init__(
         self.api_server_host = None
         self.api_server_port = None

+    async def _initialize_tokenizer(self):
+        self.tokenizer = await self.async_llm.get_tokenizer()
+        self.tokenizer.truncation_side = "left"
+
     async def chat(self, messages: List[Dict], **kwargs) -> Sequence[Experience]:
         """Chat with the model with a list of messages in async.


@@ -111,7 +115,7 @@ async def chat(self, messages: List[Dict], **kwargs) -> Sequence[Experience]:
             A list of experiences.
         """
         if self.tokenizer is None:
-            self.tokenizer = await self.async_llm.get_tokenizer()
+            await self._initialize_tokenizer()
         if self.chat_template is None:
             self.chat_template = self.tokenizer.get_chat_template()
         if messages[-1]["role"] == "assistant":

@@ -141,7 +145,12 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
         Returns:
             A list of experiences.
         """
-        output = await self._generate_internal(prompt=prompt, **kwargs)
+        if self.tokenizer is None:
+            await self._initialize_tokenizer()
+        token_ids = self.tokenizer(  # type: ignore
+            prompt, truncation=True, max_length=self.config.max_prompt_tokens, return_tensors="pt"
+        )["input_ids"][0].tolist()
+        output = await self._generate_internal(prompt={"prompt_token_ids": token_ids}, **kwargs)
         experiences = [
             Experience(
                 tokens=torch.cat(
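
Two guards work together in this file: the sampling parameters now truncate prompts at `config.max_prompt_tokens` instead of `max_model_len - 1`, and `generate` additionally pre-truncates the raw prompt with the tokenizer, keeping the most recent tokens because `truncation_side` is set to "left", before handing token ids to the engine. A minimal standalone sketch of that tokenizer-side step (the model name and limit are illustrative, not taken from the commit):

from transformers import AutoTokenizer

# Illustrative model and limit; the repository reads both from its config.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer.truncation_side = "left"  # keep the most recent tokens when trimming

max_prompt_tokens = 20
prompt = "A long prompt that may exceed the configured prompt budget ..."

token_ids = tokenizer(
    prompt, truncation=True, max_length=max_prompt_tokens, return_tensors="pt"
)["input_ids"][0].tolist()

# These ids are what would be passed on as {"prompt_token_ids": token_ids}
# instead of the raw prompt string.
assert len(token_ids) <= max_prompt_tokens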
