Commit 00f3b27

ModelWrapper automatically records Experience (#123)
1 parent 0065d9c commit 00f3b27

14 files changed: +829, -777 lines changed

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 4 additions & 2 deletions
@@ -321,9 +321,10 @@ explorer:
   max_retry_times: 2
   env_vars: {}
   rollout_model:
-    engine_type: vllm_async
+    engine_type: vllm
     engine_num: 1
     tensor_parallel_size: 1
+    enable_history: False
   auxiliary_models:
     - model_path: /PATH/TO/MODEL
       tensor_parallel_size: 1
@@ -336,9 +337,10 @@ explorer:
 - `max_timeout`: Maximum time (in seconds) for a workflow to complete.
 - `max_retry_times`: Maximum number of retries for a workflow.
 - `env_vars`: Environment variables to be set for every workflow runners.
-- `rollout_model.engine_type`: Type of inference engine. Options: `vllm_async` (recommended), `vllm`.
+- `rollout_model.engine_type`: Type of inference engine. For now, only `vllm_async` and `vllm` are supported; they have the same meaning and both use the asynchronous engine. In subsequent versions, only `vllm` may be retained for simplicity.
 - `rollout_model.engine_num`: Number of inference engines.
 - `rollout_model.tensor_parallel_size`: Degree of tensor parallelism.
+- `rollout_model.enable_history`: Whether to enable model call history recording. If set to `True`, the model wrapper automatically records the experiences returned by model calls. Periodically extract the history via `extract_experience_from_history` to avoid out-of-memory issues. Default is `False`.
 - `auxiliary_models`: Additional models used for custom workflows.
 - `eval_interval`: Interval (in steps) for evaluating the model.
 - `eval_on_startup`: Whether to evaluate the model on startup. More precisely, at step 0 with the original model, so it will not be triggered when restarting.
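
The following sketch pulls together how the new `enable_history` option is used, based on the config field above and the `ModelWrapper` / `extract_experience_from_history` calls exercised by the tests in this commit. It is illustrative only: the model path is a placeholder, `get_template_config` is a test helper, and a running Ray cluster with GPUs is assumed.

```python
from tests.tools import get_template_config  # test helper used by this commit's tests
from trinity.common.models import create_inference_models
from trinity.common.models.model import ModelWrapper

config = get_template_config()
config.mode = "explore"
config.model.model_path = "/PATH/TO/MODEL"  # placeholder
config.explorer.rollout_model.enable_history = True  # the option documented above
config.check_and_update()

# Create the rollout engines and wrap one of them with history recording enabled.
engines, _auxiliary_engines = create_inference_models(config)
model_wrapper = ModelWrapper(engines[0], model_type="vllm_async", enable_history=True)

results = model_wrapper.chat(
    [{"role": "user", "content": "Hello!"}], n=2, temperature=1.0
)

# Each model call is recorded automatically; drain the history periodically to avoid OOM.
experiences = model_wrapper.extract_experience_from_history()  # clears the history by default
assert len(experiences) == len(results)

# With enable_history=False, extract_experience_from_history() raises a ValueError instead.
```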

docs/sphinx_doc/source/tutorial/trinity_programming_guide.md

Lines changed: 1 addition & 0 deletions
@@ -122,6 +122,7 @@ During initialization, `Workflow` receives the following parameters:

 ```{tip}
 You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
+The `model` field for these OpenAI API calls can be obtained via `openai_client.models.list().data[0].id`.
 ```

 Here's an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization.
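
A minimal sketch of the tip added above, assuming `explorer.rollout_model.enable_openai_api` is set to `true` in the config and `model` is the rollout model handed to the workflow; the prompt content is purely illustrative.

```python
# Obtain an openai.OpenAI client from the rollout model inside a workflow.
openai_client = model.get_openai_client()

# The served model name is not hard-coded; query it from the server as described above.
model_id = openai_client.models.list().data[0].id

response = openai_client.chat.completions.create(
    model=model_id,
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```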

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "trinity-rft"
-version = "0.2.0"
+version = "0.2.1.dev0"
 authors = [
     {name="Trinity-RFT Team", email="[email protected]"},
 ]

tests/common/vllm_test.py

Lines changed: 110 additions & 90 deletions
@@ -2,9 +2,10 @@
 import unittest

 import torch
+from parameterized import parameterized_class
 from transformers import AutoTokenizer

-from tests.tools import RayUnittestBase, get_template_config
+from tests.tools import RayUnittestBase, RayUnittestBaseAysnc, get_template_config
 from trinity.common.models import create_inference_models
 from trinity.common.models.model import ModelWrapper
 from trinity.common.models.utils import (
@@ -82,12 +83,59 @@ def get_model_path() -> str:
 """


-class BaseTestModelWrapper:
-    def test_generate(self):
+@parameterized_class(
+    ("tensor_parallel_size", "engine_num", "use_v1", "repeat_times", "enable_history", "use_async"),
+    [
+        (1, 2, False, 2, True, False),
+        (2, 2, False, 1, False, True),
+        (2, 2, True, 2, True, False),
+        (1, 2, True, 1, False, True),
+        (2, 1, True, 3, True, True),
+    ],
+)
+class ModelWrapperTest(RayUnittestBaseAysnc):
+    def setUp(self):
+        # configure the model
+        self.config = get_template_config()
+        self.config.mode = "explore"
+        self.config.model.model_path = get_model_path()
+        self.config.explorer.rollout_model.engine_num = self.engine_num
+        self.config.explorer.rollout_model.tensor_parallel_size = self.tensor_parallel_size
+        self.config.explorer.rollout_model.use_v1 = self.use_v1
+        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
+        self.config.algorithm.repeat_times = self.repeat_times
+        self.config.explorer.rollout_model.enable_history = self.enable_history
+        self.config.check_and_update()
+        self.engines, self.auxiliary_engines = create_inference_models(self.config)
+        self.model_wrapper = ModelWrapper(
+            self.engines[0], model_type="vllm_async", enable_history=self.enable_history
+        )
+
+    async def test_generate(
+        self,
+    ):
         prompts = ["Hello, world!", "Hello, my name is"]
         n = self.config.algorithm.repeat_times
-        results = self.model_wrapper.generate(prompts, n=n, temperature=1.0)
-        self.assertEqual(len(results), len(prompts) * n)
+        if self.use_async:
+            generate_results = await self.model_wrapper.generate_async(
+                prompts, n=n, temperature=1.0
+            )
+        else:
+            generate_results = self.model_wrapper.generate(prompts, n=n, temperature=1.0)
+        self.assertEqual(len(generate_results), len(prompts) * n)
+        if self.config.explorer.rollout_model.enable_history:
+            history_experiences = self.model_wrapper.extract_experience_from_history(
+                clear_history=False
+            )
+            self.assertEqual(len(history_experiences), len(generate_results))
+            for exp, history_exp in zip(generate_results, history_experiences):
+                self.assertEqual(exp.response_text, history_exp.response_text)
+                self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist())
+                self.assertEqual(exp.prompt_length, history_exp.prompt_length)
+                self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist())
+        else:
+            with self.assertRaises(ValueError):
+                self.model_wrapper.extract_experience_from_history(clear_history=False)
         messages = [
             {"role": "system", "content": "You are a helpful assistant."},
             {"role": "user", "content": "What's the weather like today?"},
@@ -97,15 +145,32 @@ def test_generate(self):
             },
             {"role": "user", "content": "OK, thanks!"},
         ]
-        results = self.model_wrapper.chat(messages, n=n, temperature=1.0)
+        if self.use_async:
+            results = await self.model_wrapper.chat_async(messages, n=n, temperature=1.0)
+        else:
+            results = self.model_wrapper.chat(messages, n=n, temperature=1.0)
         self.assertEqual(len(results), n)
+        if self.config.explorer.rollout_model.enable_history:
+            history_experiences = self.model_wrapper.extract_experience_from_history()
+            self.assertEqual(len(history_experiences) - len(generate_results), len(results))
+            for exp, history_exp in zip(results, history_experiences[len(generate_results) :]):
+                self.assertEqual(exp.response_text, history_exp.response_text)
+                self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist())
+                self.assertEqual(exp.prompt_length, history_exp.prompt_length)
+                self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist())
         for result in results:
             input_logprobs = result.logprobs[: result.prompt_length]
             output_logprobs = result.logprobs[result.prompt_length :]
             self.assertTrue(torch.all(input_logprobs == 0))
             self.assertTrue(torch.any(output_logprobs != 0))
-        logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist())
+        if self.use_async:
+            logprobs = await self.model_wrapper.logprobs_async(results[0].tokens.tolist())
+        else:
+            logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist())
         self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0])
+        if self.config.explorer.rollout_model.enable_history:
+            history_experiences = self.model_wrapper.extract_experience_from_history()
+            self.assertTrue(len(history_experiences) == 0)
         messages.append(
             {
                 "role": "assistant",
@@ -128,84 +193,9 @@ def test_generate(self):
         self.assertTrue(torch.equal(result_dict["assistant_masks"][0], exp.action_mask))
         self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens))
         self.assertRaises(ValueError, self.model_wrapper.get_openai_client)
-
-
-class TestModelWrapperSyncV0(BaseTestModelWrapper, RayUnittestBase):
-    def setUp(self):
-        self.config = get_template_config()
-        self.config.mode = "explore"
-        self.config.model.model_path = get_model_path()
-        self.config.explorer.rollout_model.engine_type = "vllm"
-        self.config.explorer.rollout_model.tensor_parallel_size = 1
-        self.config.explorer.rollout_model.engine_num = 2
-        self.config.explorer.rollout_model.use_v1 = False
-        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
-        self.config.algorithm.repeat_times = 2
-        self.config.check_and_update()
-        self.engines, self.auxiliary_engines = create_inference_models(self.config)
-        self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")
-
-
-class TestModelWrapperAsyncV0(BaseTestModelWrapper, RayUnittestBase):
-    def setUp(self):
-        self.config = get_template_config()
-        self.config.mode = "explore"
-        self.config.model.model_path = get_model_path()
-        self.config.explorer.rollout_model.engine_type = "vllm_async"
-        self.config.explorer.rollout_model.engine_num = 2
-        self.config.explorer.rollout_model.tensor_parallel_size = 1
-        self.config.explorer.rollout_model.use_v1 = False
-        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
-        self.config.algorithm.repeat_times = 2
-        self.config.check_and_update()
-        self.engines, self.auxiliary_engines = create_inference_models(self.config)
-        self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
-
-
-class TestModelWrapperAsyncTPV0(BaseTestModelWrapper, RayUnittestBase):
-    def setUp(self):
-        self.config = get_template_config()
-        self.config.mode = "explore"
-        self.config.model.model_path = get_model_path()
-        self.config.explorer.rollout_model.engine_type = "vllm_async"
-        self.config.explorer.rollout_model.engine_num = 2
-        self.config.explorer.rollout_model.tensor_parallel_size = 2
-        self.config.explorer.rollout_model.use_v1 = False
-        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
-        self.config.check_and_update()
-        self.engines, self.auxiliary_engines = create_inference_models(self.config)
-        self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
-
-
-class TestModelWrapperAsyncTPV1(BaseTestModelWrapper, RayUnittestBase):
-    def setUp(self):
-        self.config = get_template_config()
-        self.config.mode = "explore"
-        self.config.model.model_path = get_model_path()
-        self.config.explorer.rollout_model.engine_type = "vllm_async"
-        self.config.explorer.rollout_model.engine_num = 2
-        self.config.explorer.rollout_model.tensor_parallel_size = 2
-        self.config.explorer.rollout_model.use_v1 = True
-        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
-        self.config.algorithm.repeat_times = 2
-        self.config.check_and_update()
-        self.engines, self.auxiliary_engines = create_inference_models(self.config)
-        self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
-
-
-class TestModelWrapperAsyncV1(BaseTestModelWrapper, RayUnittestBase):
-    def setUp(self):
-        self.config = get_template_config()
-        self.config.mode = "explore"
-        self.config.model.model_path = get_model_path()
-        self.config.explorer.rollout_model.engine_type = "vllm_async"
-        self.config.explorer.rollout_model.engine_num = 2
-        self.config.explorer.rollout_model.tensor_parallel_size = 1
-        self.config.explorer.rollout_model.use_v1 = True
-        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
-        self.config.check_and_update()
-        self.engines, self.auxiliary_engines = create_inference_models(self.config)
-        self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
+        if self.config.explorer.rollout_model.enable_history:
+            history_experiences = self.model_wrapper.extract_experience_from_history()
+            self.assertTrue(len(history_experiences) == 0)


 class TestAPIServer(RayUnittestBase):
@@ -221,21 +211,25 @@ def setUp(self):
         self.config.explorer.rollout_model.enable_openai_api = True
         self.config.check_and_update()
         self.engines, self.auxiliary_engines = create_inference_models(self.config)
-        self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
+        self.model_wrapper = ModelWrapper(
+            self.engines[0], model_type="vllm_async", enable_history=True
+        )
+        self.model_wrapper_no_history = ModelWrapper(
+            self.engines[0], model_type="vllm_async", enable_history=False
+        )

     def test_api(self):
         openai_client = self.model_wrapper.get_openai_client()
         messages = [
             {"role": "system", "content": "You are a helpful assistant."},
             {"role": "user", "content": "What is your name?"},
         ]
-        response = openai_client.chat.completions.create(
-            model=self.config.model.model_path, messages=messages, n=1
-        )
+        model_id = openai_client.models.list().data[0].id
+        response = openai_client.chat.completions.create(model=model_id, messages=messages, n=1)
         self.assertEqual(1, len(response.choices))
         self.assertTrue(len(response.choices[0].message.content) > 0)
         response = openai_client.chat.completions.create(
-            model=self.config.model.model_path,
+            model=model_id,
             messages=messages,
             n=2,
             temperature=0.5,
@@ -246,6 +240,32 @@ def test_api(self):
         self.assertTrue(response.choices[0].logprobs is not None)
         self.assertEqual(0, len(response.choices[0].logprobs.content[0].top_logprobs))
         self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0)
+        self.assertTrue(hasattr(response, "prompt_token_ids"))
+        self.assertTrue(len(response.prompt_token_ids) > 0)
+        self.assertTrue(hasattr(response.choices[0], "token_ids"))
+        self.assertTrue(len(response.choices[0].token_ids) > 0)
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), 3)
+        response = openai_client.chat.completions.create(
+            model=model_id,
+            messages=messages,
+            n=4,
+            temperature=0.5,
+            logprobs=True,
+            top_logprobs=0,
+        )
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), 4)
+        self.assertEqual(len(self.model_wrapper.extract_experience_from_history()), 0)
+        response = self.model_wrapper_no_history.get_openai_client().chat.completions.create(
+            model=model_id, messages=messages, n=2
+        )
+        self.assertEqual(2, len(response.choices))
+        self.assertTrue(hasattr(response.choices[0], "token_ids"))
+        self.assertTrue(len(response.choices[0].token_ids) > 0)
+        with self.assertRaises(ValueError):
+            self.model_wrapper_no_history.extract_experience_from_history()
+        self.assertEqual(len(self.model_wrapper_no_history.history), 0)


 class TestTokenizer(unittest.TestCase):

tests/explorer/scheduler_test.py

Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
 import asyncio
 import time
 import unittest
-from typing import List, Tuple
+from typing import List

 import ray
 import torch
@@ -98,8 +98,8 @@ def init_process_group(
     def has_api_server(self) -> bool:
         return True

-    def api_server_ready(self) -> Tuple[str, str]:
-        return "http://localhosts:12345", "placeholder"
+    def api_server_ready(self) -> str:
+        return "http://localhosts:12345"


 def generate_tasks(

trinity/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 # -*- coding: utf-8 -*-
 """Trinity-RFT (Reinforcement Fine-Tuning)"""

-__version__ = "0.2.0"
+__version__ = "0.2.1.dev0"

trinity/algorithm/sample_strategy/mix_sample_strategy.py

Lines changed: 3 additions & 2 deletions
@@ -4,7 +4,6 @@

 import numpy as np
 import torch
-from verl.trainer.ppo.ray_trainer import DataProto

 from trinity.algorithm.sample_strategy.sample_strategy import (
     SAMPLE_STRATEGY,
@@ -85,7 +84,9 @@ def default_args(cls) -> Dict:
         }


-def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto:
+def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor):
+    from verl.trainer.ppo.ray_trainer import DataProto
+
     attention_mask = experiences.attention_masks
     cumsum = torch.cumsum(attention_mask, dim=-1)
     position_ids = torch.clip(cumsum - 1, 0, None).long()

trinity/common/config.py

Lines changed: 3 additions & 7 deletions
@@ -206,6 +206,9 @@ class InferenceModelConfig:
     # For Qwen3
     enable_thinking: bool = False

+    # For history recording
+    enable_history: bool = False
+
     # For OpenAI API
     enable_openai_api: bool = False

@@ -313,8 +316,6 @@ class ExplorerConfig:
     name: str = EXPLORER_NAME
     # for workflow runner
     # number of workflow runners.
-    # For sync engine (vllm), it should be `1`.
-    # For async engine (vllm_async), it could be a large number.
     runner_per_model: int = 8  # number of runners per each rollout model
     max_timeout: int = 1800  # wait each task for 30 minutes
     max_retry_times: int = 2  # retry each task for 2 times if it fails or timeout
@@ -722,11 +723,6 @@ def check_and_update(self) -> None:  # noqa: C901
         self.model.critic_model_path = self.model.model_path

         # check explorer
-        if (
-            self.explorer.rollout_model.engine_type != "vllm_async"
-            and self.explorer.rollout_model.enable_openai_api
-        ):
-            raise ValueError("OpenAI API server only support `vllm_async` engine.")
         if self.explorer.rollout_model.max_prompt_tokens is None:
             self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens
         if self.explorer.rollout_model.max_response_tokens is None:
