
Commit c615ee7: Add truncate_status to experience (#407)
Parent: 89dd059

11 files changed: +128 -26 lines

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 1 addition & 1 deletion
@@ -174,7 +174,7 @@ model:
 - `max_response_tokens`: Maximum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`.
 - `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only for `chat` and `generate` methods in `InferenceModel`.
 - `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`. Default is `1`. It must be less than `max_response_tokens`.
-- `enable_prompt_truncation`: Whether to truncate the prompt. Default is `true`. If set to `true`, the prompt will be truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt will not be truncated and there is a risk that the prompt length plus response length exceeds `max_model_len`.
+- `enable_prompt_truncation`: Whether to truncate the prompt. Default is `true`. If set to `true`, the prompt will be truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt will not be truncated and there is a risk that the prompt length plus response length exceeds `max_model_len`. This setting does not take effect in OpenAI API mode.

 ```{tip}
 If you are using the openai API provided by Explorer, only `max_model_len` will take effect, and the values of `max_response_tokens`, `max_prompt_tokens`, and `min_response_tokens` will be ignored. When `max_tokens` is not independently specified, each API call will generate up to `max_model_len - prompt_length` tokens. Therefore, please ensure that the prompt length is less than `max_model_len` when using the API.
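
The budget relationship these options describe can be sketched as follows. This is an editorial illustration only; `fits_model_len` is a hypothetical helper, not part of Trinity, and the numbers match the limits exercised by the new tests later in this commit.

```python
# Hypothetical helper illustrating the token-budget rule described above:
# with enable_prompt_truncation=True the prompt is clipped to max_prompt_tokens,
# so prompt + response cannot exceed max_model_len.
def fits_model_len(prompt_len: int, response_len: int,
                   max_prompt_tokens: int, max_model_len: int,
                   enable_prompt_truncation: bool = True) -> bool:
    if enable_prompt_truncation:
        prompt_len = min(prompt_len, max_prompt_tokens)  # prompt is truncated
    return prompt_len + response_len <= max_model_len

# With max_model_len=20, max_prompt_tokens=5, max_response_tokens=15 (as in the tests below):
assert fits_model_len(30, 15, max_prompt_tokens=5, max_model_len=20)
assert not fits_model_len(30, 15, max_prompt_tokens=5, max_model_len=20,
                          enable_prompt_truncation=False)
```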

docs/sphinx_doc/source_zh/tutorial/trinity_configs.md

Lines changed: 1 addition & 1 deletion
@@ -174,7 +174,7 @@ model:
 - `max_prompt_tokens`: Maximum number of tokens allowed in the input prompt. Only effective for the `chat` and `generate` methods in `InferenceModel`.
 - `max_response_tokens`: Maximum number of tokens allowed in the model-generated response. Only effective for the `chat` and `generate` methods in `InferenceModel`.
 - `min_response_tokens`: Minimum number of tokens allowed in the model-generated response. Only effective for the `chat` and `generate` methods in `InferenceModel`.
-- `enable_prompt_truncation`: Whether to truncate the prompt. Defaults to `true`. If set to `true`, the prompt will be truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt will not be truncated and the combined prompt and response length may exceed `max_model_len`.
+- `enable_prompt_truncation`: Whether to truncate the prompt. Defaults to `true`. If set to `true`, the prompt will be truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt will not be truncated and the combined prompt and response length may exceed `max_model_len`. Does not take effect in OpenAI API mode.

 ```{tip}
 If you are using the openai API provided by Explorer, only `max_model_len` takes effect and the values of `max_response_tokens`, `max_prompt_tokens`, and `min_response_tokens` are ignored. When `max_tokens` is not specified separately, each API call generates at most `max_model_len - prompt_length` tokens, so please make sure the prompt length is less than `max_model_len`.

tests/common/experience_test.py

Lines changed: 1 addition & 1 deletion
@@ -175,7 +175,7 @@ def test_assertions(self):
         # prompt_length must be > 0
         with self.assertRaises(AssertionError):
             Experience(tokens=[1, 2, 3], prompt_length=0)
-        # tokens must be longer than prompt_length for single-turn
+        # tokens must be larger than prompt_length for single-turn
        with self.assertRaises(AssertionError):
             Experience(tokens=[1, 2], prompt_length=2)
         # DPO: tokens must match prompt_length
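
For orientation, a short usage sketch of the assertion exercised here, assuming Trinity is installed and that `Experience` is importable from `trinity.common.experience` (path inferred from this commit's file list); it mirrors the placeholder experience built in `trinity/common/models/vllm_model.py` later in the diff.

```python
# Sketch under the assumptions stated above.
import torch
from trinity.common.experience import Experience

# A single-turn experience normally needs at least one response token.
try:
    Experience(tokens=[1, 2], prompt_length=2)
except AssertionError as err:
    print(err)

# Marking it as prompt-truncated skips that check; its action mask is then all zeros.
Experience(tokens=[1, 2, 3], prompt_length=2,
           logprobs=torch.zeros(1, dtype=torch.float32),
           truncate_status="prompt_truncated", reward=0.0)
```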

tests/common/vllm_test.py

Lines changed: 21 additions & 7 deletions
@@ -224,6 +224,7 @@ async def test_generate(
     [
         (20, 19, None),
         (20, None, 1),
+        (20, 5, 15),
     ],
 )
 class TestModelLen(RayUnittestBaseAysnc):
@@ -240,6 +241,7 @@ def setUp(self):

         self.engines, self.auxiliary_engines = create_inference_models(self.config)
         self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path)

     async def test_model_len(self):
         await self.model_wrapper.prepare()
@@ -248,18 +250,30 @@ async def test_model_len(self):
             {"role": "user", "content": "What's the weather like today?"},
         ]

+        def _check_experience(exp):
+            # check prompt content and length
+            encoded_prompt = self.tokenizer.encode(exp.prompt_text, add_special_tokens=False)
+            self.assertEqual(len(encoded_prompt), exp.prompt_length)
+            self.assertLessEqual(exp.prompt_length, self.config.model.max_prompt_tokens)
+            # check response content and length
+            encoded_response = self.tokenizer.encode(exp.response_text, add_special_tokens=False)
+            self.assertEqual(len(encoded_response), len(exp.tokens) - exp.prompt_length)
+            self.assertLessEqual(
+                len(exp.tokens) - exp.prompt_length, self.config.model.max_response_tokens
+            )
+            # check full sequence
+            self.assertLessEqual(len(exp.tokens), self.config.model.max_model_len)
+
         # For vllm engine, max_prompt_tokens and max_response_tokens work
         response = self.model_wrapper.chat(messages)
         self.assertEqual(len(response), 1)
-        self.assertEqual(len(response[0].tokens), self.config.model.max_model_len)
+        if self.max_prompt_tokens == 5:
+            self.assertEqual(response[0].truncate_status, "prompt_truncated")
+        _check_experience(response[0])
+
         exps = self.model_wrapper.extract_experience_from_history()
         self.assertEqual(len(exps), 1)
-        # check prompt length, response length, max_model_len
-        self.assertEqual(exps[0].prompt_length, self.config.model.max_prompt_tokens)
-        self.assertEqual(
-            len(exps[0].tokens) - exps[0].prompt_length, self.config.model.max_response_tokens
-        )
-        self.assertLessEqual(len(response[0].tokens), self.config.model.max_model_len)
+        _check_experience(exps[0])

         # For openai api, max_prompt_tokens and max_response_tokens do not work
         openai_client = self.model_wrapper.get_openai_client()
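
The new parametrization appears to read as `(max_model_len, max_prompt_tokens, max_response_tokens)`, so `(20, 5, 15)` drives the prompt-truncation path. A self-contained sketch of the length invariants `_check_experience` asserts, with a plain dataclass standing in for the real `Experience` and the tokenizer round-trip omitted:

```python
# Stand-in types; the real test works on Trinity Experiences and a HF tokenizer.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FakeExperience:
    tokens: List[int]
    prompt_length: int
    truncate_status: Optional[str] = None

def check_lengths(exp: FakeExperience, max_prompt_tokens: int,
                  max_response_tokens: int, max_model_len: int) -> None:
    assert exp.prompt_length <= max_prompt_tokens
    assert len(exp.tokens) - exp.prompt_length <= max_response_tokens
    assert len(exp.tokens) <= max_model_len

# Prompt-truncated case from the (20, 5, 15) setting: 5 prompt tokens are kept
# plus a single placeholder response token (see vllm_model.py below).
exp = FakeExperience(tokens=list(range(6)), prompt_length=5,
                     truncate_status="prompt_truncated")
check_lengths(exp, max_prompt_tokens=5, max_response_tokens=15, max_model_len=20)
```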

tests/trainer/trainer_test.py

Lines changed: 41 additions & 0 deletions
@@ -1082,3 +1082,44 @@ def test_trainer(self):
     def tearDown(self):
         # remove dir only when the test passed
         shutil.rmtree(self.config.checkpoint_job_dir)
+
+
+class TestTrainerPromptTruncation(BaseTrainerCase):
+    def test_trainer(self):
+        self.config.model.max_model_len = 20
+        self.config.model.max_prompt_tokens = 5
+        self.config.model.max_response_tokens = 15
+        self.config.model.enable_prompt_truncation = True
+        self.config.algorithm.algorithm_type = "grpo"
+        self.config.algorithm.advantage_fn = "grpo"
+        self.config.algorithm.kl_loss_fn = "none"
+        self.config.algorithm.repeat_times = 2
+        self.config.buffer.batch_size = 4
+        self.config.buffer.total_steps = 2
+        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+        self.config.check_and_update()
+        both(self.config)
+
+        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        rollout_metrics = parser.metric_list("rollout")
+        self.assertTrue(len(rollout_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)
+        actor_metrics = parser.metric_list("actor")
+        self.assertTrue(len(actor_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2)
+        max_prompt_length = parser.metric_values("prompt_length/max")
+        self.assertEqual(max(max_prompt_length), 5)
+        min_prompt_length = parser.metric_values("prompt_length/min")
+        self.assertEqual(min(min_prompt_length), 5)
+        max_response_length = parser.metric_values("response_length/max")
+        self.assertEqual(max(max_response_length), 1)
+        min_response_length = parser.metric_values("response_length/min")
+        self.assertEqual(min(min_response_length), 1)
+        final_loss = parser.metric_values("actor/final_loss")
+        self.assertEqual(final_loss[0], 0.0)
+        grad_norm = parser.metric_values("actor/grad_norm")
+        self.assertEqual(grad_norm[0], 0.0)
+
+    def tearDown(self):
+        # remove dir only when the test passed
+        shutil.rmtree(self.config.checkpoint_job_dir)
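
The expected metrics follow from the truncation path added in `trinity/common/models/vllm_model.py` below: since every GSM8K prompt exceeds 5 tokens, each rollout becomes a prompt-truncated placeholder with 5 prompt tokens, a single response token, and reward 0.0. With identical rewards inside each GRPO group, the group-relative advantages, and hence the loss and gradient norm, collapse to zero. A minimal arithmetic sketch of that last step:

```python
# Illustration only: GRPO-style group-relative advantages for a group whose
# rewards are identical are all zero, which is why actor/final_loss and
# actor/grad_norm are asserted to be 0.0 above.
rewards = [0.0, 0.0]  # repeat_times = 2 rollouts per task, all prompt-truncated
mean = sum(rewards) / len(rewards)
advantages = [r - mean for r in rewards]  # std-normalization keeps zeros at zero
assert advantages == [0.0, 0.0]
```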

trinity/common/config.py

Lines changed: 3 additions & 2 deletions
@@ -457,7 +457,8 @@ class ModelConfig:
     max_response_tokens: Optional[int] = None
     # the minimum number of tokens for the response
     min_response_tokens: int = 1
-    # whether to truncate the prompt; if set to True, the prompt will be truncated to `max_prompt_tokens` tokens.
+    # whether to truncate the prompt; if set to True, the prompt will be truncated to `max_prompt_tokens` tokens;
+    # not applicable for OpenAI API
     enable_prompt_truncation: bool = True

     # lora config
@@ -1192,7 +1193,7 @@ def _check_model(self) -> None:
         if model.enable_prompt_truncation is True:
             if model.max_prompt_tokens is None:
                 raise ValueError(
-                    "When `model.enable_prompt_truncation` is True, `model.max_prompt_tokens` must be set properly."
+                    "When `model.enable_prompt_truncation` is True, `model.max_prompt_tokens` must be set properly. This function does not work with OpenAI API mode."
                 )
             logger.warning(
                 f"`enable_prompt_truncation` is set to True; the prompt will be truncated to `max_prompt_tokens`={model.max_prompt_tokens} tokens if it is too long."

trinity/common/experience.py

Lines changed: 14 additions & 4 deletions
@@ -104,6 +104,9 @@ class Experience:
     token_level_reward: Optional[Tensor] = None  # [resp_length]
     advantages: Optional[Tensor] = None  # [resp_length]
     returns: Optional[Tensor] = None  # [resp_length]
+    truncate_status: Optional[
+        str
+    ] = None  # The status of truncation, e.g., "prompt_truncated", "response_truncated"; Not working for openai api
     info: dict = field(
         default_factory=dict
     )  # Additional information about the experience, can also be used to store custom fields
@@ -140,6 +143,7 @@ def __init__(  # noqa: C901
         token_level_reward=None,
         advantages=None,
         returns=None,
+        truncate_status=None,
         info=None,
         metrics=None,
         prompt_length=1,
@@ -165,10 +169,13 @@ def __init__(  # noqa: C901
             assert (
                 prompt_length > 0
             ), "Prompt length must be greater than 0 for single-turn experiences."
-            assert (
-                len(tokens) > prompt_length
-            ), f"Token ids must be longer than the prompt length. Got len(tokens)={len(tokens)}, prompt_length={prompt_length}."
-            action_mask = torch.ones(len(tokens) - prompt_length, dtype=torch.bool)
+            if truncate_status != "prompt_truncated":
+                assert (
+                    len(tokens) > prompt_length
+                ), f"Token ids must be larger than the prompt length. Got len(tokens)={len(tokens)}, prompt_length={prompt_length}."
+                action_mask = torch.ones(len(tokens) - prompt_length, dtype=torch.bool)
+            else:
+                action_mask = torch.zeros(len(logprobs), dtype=torch.bool)
         elif experience_type == "dpo":
             prompt_length = len(tokens)
         if eid is None:
@@ -196,6 +203,7 @@ def __init__(  # noqa: C901
         self.experience_type = experience_type
         self.info = info or {}
         self.metrics = metrics or {}
+        self.truncate_status = truncate_status
         self.prompt_length = prompt_length
         self.response_text = response_text
         self.prompt_text = prompt_text
@@ -264,6 +272,8 @@ def to_dict(self) -> dict:
             res["rejected_messages"] = self.rejected_messages
         if self.reward is not None:
             res["reward"] = float(self.reward)
+        if self.truncate_status is not None:
+            res["truncate_status"] = self.truncate_status
         return res

     @classmethod
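
A standalone illustration of the new action-mask rule (a sketch, not the real `Experience` constructor): a prompt-truncated experience masks out its single placeholder response token, so it contributes no actions to training, while an ordinary single-turn experience keeps one mask entry per response token.

```python
import torch

def build_action_mask(tokens, prompt_length, logprobs, truncate_status=None):
    # Mirrors the branch added above for single-turn experiences.
    if truncate_status != "prompt_truncated":
        assert len(tokens) > prompt_length
        return torch.ones(len(tokens) - prompt_length, dtype=torch.bool)
    return torch.zeros(len(logprobs), dtype=torch.bool)

normal = build_action_mask([1, 2, 3, 4], prompt_length=2, logprobs=torch.zeros(2))
truncated = build_action_mask([1, 2, 3], prompt_length=2, logprobs=torch.zeros(1),
                              truncate_status="prompt_truncated")
print(normal)     # tensor([True, True])
print(truncated)  # tensor([False])
```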

trinity/common/models/vllm_model.py

Lines changed: 30 additions & 6 deletions
@@ -191,12 +191,36 @@ async def generate(
         """
         if self.tokenizer is None:
             await self._initialize_tokenizer()
+
+        # Tokenize once without truncation to check if truncation is needed
         token_ids = self.tokenizer(  # type: ignore
             prompt,
-            truncation=self.config.enable_prompt_truncation,
-            max_length=self.config.max_prompt_tokens,
+            truncation=False,
             return_tensors="pt",
-        )["input_ids"][0].tolist()
+        )[
+            "input_ids"
+        ][0].tolist()
+
+        # Check if truncation is needed and apply it
+        if self.config.enable_prompt_truncation and self.config.max_prompt_tokens is not None:
+            if len(token_ids) > self.config.max_prompt_tokens:
+                self.logger.warning(
+                    f"Prompt was truncated to {self.config.max_prompt_tokens} tokens"
+                )
+                token_ids = token_ids[: self.config.max_prompt_tokens + 1]  # leave one for response
+                return [
+                    Experience(
+                        tokens=token_ids,
+                        logprobs=torch.zeros(1, dtype=torch.float32),
+                        prompt_length=len(token_ids) - 1,
+                        prompt_text=self.tokenizer.decode(token_ids[:-1]),
+                        response_text=self.tokenizer.decode(token_ids[-1]),
+                        truncate_status="prompt_truncated",
+                        reward=0.0,
+                    )
+                    for i in range(kwargs.get("n", 1))
+                ]
+
         output = await self._generate_internal(
             prompt={"prompt_token_ids": token_ids}, lora_request=lora_request, **kwargs
         )
@@ -397,10 +421,10 @@ async def convert_messages_to_experience(

         # Truncate tokens if they exceed the length limit
         assert token_ids is not None
-        is_truncated = False  # TODO: add to experience itself
+        truncate_status = None
         if self.config.max_model_len is not None and self.config.max_model_len > 0:
             if len(token_ids) > self.config.max_model_len - 1:
-                is_truncated = True
+                truncate_status = "response_truncated"
                 self.logger.warning(
                     f"Warning: {len(token_ids) = } exceeds the length limit {self.config.max_model_len-1 = }"
                 )
@@ -417,7 +441,7 @@ async def convert_messages_to_experience(
             prompt_length=prompt_length,
             action_mask=action_mask[prompt_length:],  # Exclude the prompt tokens
             messages=messages,
-            info={"is_truncated": is_truncated},
+            truncate_status=truncate_status,
         )

     async def shutdown(self):
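
A dependency-free sketch of the truncation path added to `generate` (the real code uses the HF tokenizer and Trinity's `Experience`; here a whitespace split and a plain dict stand in for both, and `truncate_prompt` is a hypothetical helper):

```python
def truncate_prompt(prompt: str, max_prompt_tokens: int, n: int = 1):
    token_ids = prompt.split()  # stand-in for tokenizer(prompt, truncation=False)
    if len(token_ids) <= max_prompt_tokens:
        return None  # no truncation needed; normal generation would continue
    token_ids = token_ids[: max_prompt_tokens + 1]  # keep one slot for the placeholder response
    return [
        {
            "tokens": token_ids,
            "prompt_length": len(token_ids) - 1,
            "response_text": token_ids[-1],
            "truncate_status": "prompt_truncated",
            "reward": 0.0,
        }
        for _ in range(n)
    ]

print(truncate_prompt("a b c d e f g h", max_prompt_tokens=5))
```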

trinity/common/workflows/envs/frozen_lake/workflow.py

Lines changed: 3 additions & 0 deletions
@@ -280,6 +280,7 @@ async def run_async(self) -> List[Experience]:
         self.step_count = 0
         self.action = None
         terminate_reason = None
+        truncate_status = None

         # Initialize messages
         messages = []
@@ -318,6 +319,7 @@ async def run_async(self) -> List[Experience]:
                 self.done = False
                 self.step_rewards.append(0)
                 terminate_reason = "max_tokens_reached"
+                truncate_status = "response_truncated"
                 break

             # Get action from the model
@@ -360,6 +362,7 @@ async def run_async(self) -> List[Experience]:
                 "env_done": 1 if self.done else 0,
                 "test_score": final_reward,
             },
+            truncate_status=truncate_status,
         )
         return [experience]


trinity/common/workflows/workflow.py

Lines changed: 5 additions & 2 deletions
@@ -165,10 +165,12 @@ def set_repeat_times(self, repeat_times, run_id_base):
         self.repeat_times = repeat_times
         self.run_id_base = run_id_base

-    def process_messages_to_experience(self, messages, reward, info={}) -> Experience:
+    def process_messages_to_experience(
+        self, messages, reward, info={}, truncate_status=None
+    ) -> Experience:
         converted_experience = self.model.convert_messages_to_experience(messages)

-        if converted_experience.info.get("is_truncated", False):
+        if converted_experience.truncate_status == "response_truncated":
             reward = 0.0

         tokens = converted_experience.tokens
@@ -188,6 +190,7 @@ def process_messages_to_experience(self, messages, reward, info={}) -> Experience:
             prompt_length=converted_experience.prompt_length,
             prompt_text=converted_experience.prompt_text,
             response_text=converted_experience.response_text,
+            truncate_status=converted_experience.truncate_status or truncate_status,
             reward=reward,
             logprobs=log_probs,
             info=info,
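
A minimal sketch of the reward handling above (hypothetical `resolve` helper, not Trinity code): the reward is zeroed only when the converted experience itself reports `response_truncated`, and the experience-level status takes precedence over one passed in by the workflow (as the frozen_lake workflow does).

```python
from typing import Optional, Tuple

def resolve(reward: float, converted_status: Optional[str],
            workflow_status: Optional[str]) -> Tuple[float, Optional[str]]:
    if converted_status == "response_truncated":
        reward = 0.0
    return reward, converted_status or workflow_status

print(resolve(1.0, "response_truncated", None))  # (0.0, 'response_truncated')
print(resolve(1.0, None, "response_truncated"))  # (1.0, 'response_truncated')
print(resolve(1.0, None, None))                  # (1.0, None)
```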
