
Commit e3e1ad1

Async model support OpenAI compatible API (#41)
1 parent 03e97fc commit e3e1ad1

File tree: 18 files changed, +491 -152 lines

docs/sphinx_doc/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ Welcome to Trinity-RFT's documentation!
 
    tutorial/example_reasoning_basic.md
    tutorial/example_reasoning_advanced.md
+   tutorial/example_async_mode.md
    tutorial/example_multi_turn.md
    tutorial/example_dpo.md
    tutorial/example_data_functionalities.md

docs/sphinx_doc/source/tutorial/trinity_programming_guide.md

Lines changed: 53 additions & 21 deletions
@@ -2,7 +2,9 @@
 
 This guide will introduce how to add new task types to Trinity-RFT and provide relevant development guidelines.
 
-> **Note**: Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code.
+```{note}
+Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code.
+```
 
 ---
 
@@ -31,11 +33,11 @@ Before starting development, it's important to understand several core concepts:
 
 ### Step 1: Prepare Task Dataset
 
-Each `Task` is a Python dictionary (`Dict[str, Any]`), containing various parameters needed to initialize the `Workflow`. Due to significant differences in initialization parameters across different `Workflows`, the following example uses a math problem scenario.
+Each `Task` contains various parameters needed to initialize the `Workflow`. Due to significant differences in initialization parameters across different `Workflows`, the following example uses a math problem scenario.
 
 In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line’s JSON contains `question` and `answer` fields representing the problem description and standard answer, respectively.
 
-```json
+```
 {"question": "1+1=", "answer": "2"}
 {"question": "2+2=", "answer": "4"}
 ...
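
Side note on the format above: a minimal hedged sketch of reading such a dataset (the file name `tasks.jsonl` is a placeholder; in Trinity-RFT each parsed line becomes the `raw_task` dict of a `Task`):

```python
import json

# Parse the jsonl task dataset sketched above; each line yields the dict
# that a Workflow later receives as `task.raw_task`.
with open("tasks.jsonl", encoding="utf-8") as f:
    raw_tasks = [json.loads(line) for line in f if line.strip()]

print(raw_tasks[0]["question"], "->", raw_tasks[0]["answer"])
```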
@@ -48,25 +50,45 @@ In the math problem scenario, the `Task` dataset can be a `jsonl` file, where ea
 The core of creating a new task type is writing a new `Workflow`, whose base class interface is as follows:
 
 ```python
-from abc import ABC
-from typing import List
+# import some packages
 
 class Workflow(ABC):
 
-    def __init__(self, model: ModelWrapper, **kwargs):
+    def __init__(
+        self,
+        model: ModelWrapper,
+        task: Task,
+        auxiliary_models: Optional[List[openai.OpenAI]] = None,
+    ):
         self.model = model
+        self.auxiliary_models = auxiliary_models
 
     @abstractmethod
     def run(self) -> List[Experience]:
         """Run the workflow and return a list of Experiences."""
 ```
 
-Developers can register their own `Workflow` through the `WORKFLOWS.register_module` method, but need to ensure that the name does not conflict with existing `Workflows`.
+Developers can register their own `Workflow` through the `WORKFLOWS.register_module` method, but need to ensure that the name does not conflict with existing `Workflow` classes.
+
+```python
+# import some packages
+from trinity.common.workflows.workflow import WORKFLOWS
+
+@WORKFLOWS.register_module("my_workflow")
+class MyWorkflow(Workflow):
+    pass
+```
 
 #### Initialization Parameters
 When initializing, `Workflow` receives the following parameters:
-- `model`: Provides an API call interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`).
-- `kwargs`: Reads one line of data from the `Task` dataset, allowing developers to initialize internal modules such as Agent and Environment within the `Workflow` based on these parameters.
+- `model`: The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`).
+- `task`: An instance of `Task`, generated from one line of the `Task` dataset. Its `raw_task` field contains the source data as a `Dict`, which can be used to construct the `Workflow` instance.
+  Its `rollout_args` field contains the parameters for the rollout process, such as `n`, `temperature`, `top_k` and `top_p`.
+- `auxiliary_models`: A list of auxiliary models that will not be trained. All of them provide an OpenAI compatible API.
+
+```{tip}
+The `model` also provides an OpenAI compatible API. You can switch to it by setting `explorer.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance (a usage sketch follows this file's diff).
+```
 
 #### Example Code
 Below is a simple example demonstrating how to implement a math problem `Workflow`:
@@ -75,10 +97,16 @@ Below is a simple example demonstrating how to implement a math problem `Workflo
 @WORKFLOWS.register_module("example_workflow")
 class ExampleWorkflow(Workflow):
 
-    def __init__(self, model: ModelWrapper, **kwargs):
-        super().__init__(model)
-        self.question = kwargs.get("question")
-        self.answer = kwargs.get("answer")
+    def __init__(self, model: ModelWrapper, task: Task, **kwargs):
+        super().__init__(model, task, **kwargs)
+        self.question = task.raw_task.get("question")
+        self.answer = task.raw_task.get("answer")
+
+    def calculate_reward(self, response: str, truth: str) -> float:
+        if response == truth:
+            return 1.0
+        else:
+            return 0.0
 
     def run(self) -> List[Experience]:
         response = self.model.chat(
@@ -87,15 +115,19 @@ class ExampleWorkflow(Workflow):
                 "role": "user",
                 "content": f"Question:\n{self.question}",
             }
-        ]
+            ],
+            n=self.task.rollout_args.repeat_times,
+            temperature=self.task.rollout_args.temperature,
         )
-        reward: float = calculate_reward(response.response_text, self.answer)
-        return [Experience(
-            tokens=response.tokens,
-            prompt_length=response.prompt_length,
-            reward=reward,
-            logprobs=response.logprobs,
-        )]
+        reward: float = self.calculate_reward(response.response_text, self.answer)
+        return [
+            Experience(
+                tokens=response.tokens,
+                prompt_length=response.prompt_length,
+                reward=reward,
+                logprobs=response.logprobs,
+            )
+        ]
 ```
 
 ---
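
The `{tip}` added above mentions `model.get_openai_client()` but shows no call site. Below is a minimal hedged sketch, not code from this commit: it assumes `explorer.enable_openai_api` is `true` and, following the `TestAPIServer` test added further down, addresses the served model by its model path. The helper name and arguments are illustrative placeholders.

```python
# Hedged sketch of using the rollout model's OpenAI compatible API from a
# workflow; names here are illustrative, not code from this commit.
def ask_via_openai_api(model_wrapper, model_path: str, question: str) -> str:
    # get_openai_client() returns an openai.OpenAI instance when the API
    # server is enabled; the updated vllm_test.py asserts it raises
    # ValueError when it is not.
    client = model_wrapper.get_openai_client()
    response = client.chat.completions.create(
        model=model_path,  # TestAPIServer passes the configured model path
        messages=[{"role": "user", "content": f"Question:\n{question}"}],
        n=1,
    )
    return response.choices[0].message.content
```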

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ dependencies = [
     "flask",
     "requests",
     "tensorboard",
+    "openai",
 ]
 
 [project.scripts]

tests/common/vllm_test.py

Lines changed: 45 additions & 6 deletions
@@ -5,7 +5,7 @@
 from transformers import AutoTokenizer
 
 from tests.tools import RayUnittestBase, get_template_config
-from trinity.common.models import create_rollout_models
+from trinity.common.models import create_inference_models
 from trinity.common.models.model import ModelWrapper
 from trinity.common.models.utils import (
     tokenize_and_mask_messages_default,
@@ -127,6 +127,7 @@ def test_generate(self):
         )
         self.assertTrue(torch.equal(result_dict["assistant_masks"][0], exp.action_mask))
         self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens))
+        self.assertRaises(ValueError, self.model_wrapper.get_openai_client)
 
 
 class TestModelWrapperSyncV0(BaseTestModelWrapper, RayUnittestBase):
@@ -139,7 +140,7 @@ def setUp(self):
         self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
         self.config.explorer.use_v1 = False
         self.config.explorer.chat_template = CHAT_TEMPLATE
-        self.engines = create_rollout_models(self.config)
+        self.engines, self.auxiliary_engines = create_inference_models(self.config)
         self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")
 
 
@@ -153,7 +154,7 @@ def setUp(self):
         self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
         self.config.explorer.use_v1 = False
         self.config.explorer.chat_template = CHAT_TEMPLATE
-        self.engines = create_rollout_models(self.config)
+        self.engines, self.auxiliary_engines = create_inference_models(self.config)
         self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
 
 
@@ -166,7 +167,7 @@ def setUp(self):
         self.config.explorer.tensor_parallel_size = 2
         self.config.explorer.use_v1 = False
         self.config.explorer.chat_template = CHAT_TEMPLATE
-        self.engines = create_rollout_models(self.config)
+        self.engines, self.auxiliary_engines = create_inference_models(self.config)
         self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
 
 
@@ -180,7 +181,7 @@ def setUp(self):
         self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
         self.config.explorer.use_v1 = True
         self.config.explorer.chat_template = CHAT_TEMPLATE
-        self.engines = create_rollout_models(self.config)
+        self.engines, self.auxiliary_engines = create_inference_models(self.config)
         self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
 
 
@@ -193,10 +194,48 @@ def setUp(self):
         self.config.explorer.tensor_parallel_size = 1
         self.config.explorer.use_v1 = True
         self.config.explorer.chat_template = CHAT_TEMPLATE
-        self.engines = create_rollout_models(self.config)
+        self.engines, self.auxiliary_engines = create_inference_models(self.config)
         self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
 
 
+class TestAPIServer(RayUnittestBase):
+    def setUp(self):
+        self.config = get_template_config()
+        self.config.model.model_path = get_model_path()
+        self.config.explorer.engine_type = "vllm_async"
+        self.config.explorer.engine_num = 1
+        self.config.explorer.tensor_parallel_size = 1
+        self.config.explorer.use_v1 = True
+        self.config.explorer.chat_template = CHAT_TEMPLATE
+        self.config.explorer.enable_openai_api = True
+        self.engines, self.auxiliary_engines = create_inference_models(self.config)
+        self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
+
+    def test_api(self):
+        openai_client = self.model_wrapper.get_openai_client()
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What is your name?"},
+        ]
+        response = openai_client.chat.completions.create(
+            model=self.config.model.model_path, messages=messages, n=1
+        )
+        self.assertEqual(1, len(response.choices))
+        self.assertTrue(len(response.choices[0].message.content) > 0)
+        response = openai_client.chat.completions.create(
+            model=self.config.model.model_path,
+            messages=messages,
+            n=2,
+            temperature=0.5,
+            logprobs=True,
+            top_logprobs=0,
+        )
+        self.assertEqual(2, len(response.choices))
+        self.assertTrue(response.choices[0].logprobs is not None)
+        self.assertEqual(0, len(response.choices[0].logprobs.content[0].top_logprobs))
+        self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0)
+
+
 class TestTokenizer(unittest.TestCase):
     def test_assistant_token_mask(self):
         messages = [

tests/explorer/runner_pool_test.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 
 @WORKFLOWS.register_module("dummy_workflow")
 class DummyWorkflow(Workflow):
-    def __init__(self, model, task):
+    def __init__(self, model, task, auxiliary_models):
         super().__init__(model, task)
         self.error_type = task.task_desc
         self.seconds = None

trinity/common/config.py

Lines changed: 32 additions & 4 deletions
@@ -137,6 +137,25 @@ class ModelConfig:
     enable_thinking: bool = False
 
 
+@dataclass
+class InferenceModelConfig:
+    # TODO: support setting engine_num
+    model_path: str = ""
+    tensor_parallel_size: int = 1
+    use_v1: bool = True
+    max_prompt_tokens: int = 2048
+    max_response_tokens: int = 2048
+    enable_thinking: bool = False
+    enforce_eager: bool = True
+    enable_prefix_caching: bool = False
+    enable_chunked_prefill: bool = False
+    gpu_memory_utilization: float = 0.9
+    dtype: str = "bfloat16"
+    seed: int = 42
+    chat_template: Optional[str] = None
+    bundle_indices: str = ""  # DO NOT SET this field
+
+
 @dataclass
 class ClusterConfig:
     """Config for the cluster."""
@@ -185,10 +204,10 @@ class BufferConfig:
 class ExplorerConfig:
     """Config for explorer."""
 
-    # inference engine type, `vllm` or `vllm_async`
-    engine_type: str = "vllm"
+    # rollout engine type, `vllm` or `vllm_async`
+    engine_type: str = "vllm_async"
 
-    # number of inference engines
+    # number of rollout engines
     engine_num: int = 1
 
     # number of workflow runners.
@@ -199,7 +218,8 @@ class ExplorerConfig:
     # for rollout tokneize
     chat_template: Optional[str] = None
 
-    # for vLLM
+    # TODO: move vllm rollout model related args into
+    # `explorer.rollout_model: InferenceModelConfig`
     tensor_parallel_size: int = 1
     enable_prefix_caching: bool = False
     enforce_eager: bool = True
@@ -210,6 +230,7 @@ class ExplorerConfig:
     gpu_memory_utilization: float = 0.9
     enable_chunked_prefill: bool = False
     use_v1: bool = True
+    enable_openai_api: bool = False
     bundle_indices: str = ""  # DO NOT SET this field
 
     # for workflow runner
@@ -218,6 +239,9 @@ class ExplorerConfig:
     max_timeout: int = 900  # wait each task for 15 minutes
     max_retry_times: int = 2  # retry each task for 2 times if it fails or timeout
 
+    # for other models used in the custom workflows
+    auxiliary_models: List[InferenceModelConfig] = field(default_factory=list)
+
 
 @dataclass
 class TrainerConfig:
@@ -453,6 +477,10 @@ def check_and_update(self) -> None:  # noqa: C901
         if not self.model.critic_model_path:
             self.model.critic_model_path = self.model.model_path
 
+        # check explorer
+        if self.explorer.engine_type != "vllm_async" and self.explorer.enable_openai_api:
+            raise ValueError("OpenAI API server only supports the `vllm_async` engine.")
+
         # check synchronizer
         self.synchronizer.explorer_world_size = (
             self.explorer.engine_num * self.explorer.tensor_parallel_size
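
To see how the new fields fit together, here is a hedged sketch that constructs the updated dataclasses directly (values are placeholders; real runs would normally set them via a YAML config file):

```python
# Hedged sketch, not from this commit: wiring the new explorer options.
from trinity.common.config import ExplorerConfig, InferenceModelConfig

explorer = ExplorerConfig(
    engine_type="vllm_async",  # the check above rejects enable_openai_api otherwise
    enable_openai_api=True,    # exposes model.get_openai_client() to workflows
    auxiliary_models=[
        # an untrained helper model served behind an OpenAI compatible API;
        # the model path is a placeholder
        InferenceModelConfig(model_path="/path/to/auxiliary/model"),
    ],
)
```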
