1 change: 1 addition & 0 deletions docs/sphinx_doc/source/index.rst
@@ -18,6 +18,7 @@ Welcome to Trinity-RFT's documentation!

tutorial/example_reasoning_basic.md
tutorial/example_reasoning_advanced.md
tutorial/example_async_mode.md
tutorial/example_multi_turn.md
tutorial/example_dpo.md
tutorial/example_data_functionalities.md
74 changes: 53 additions & 21 deletions docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
@@ -2,7 +2,9 @@

This guide will introduce how to add new task types to Trinity-RFT and provide relevant development guidelines.

> **Note**: Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code.
```{note}
Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code.
```

---

@@ -31,11 +33,11 @@ Before starting development, it's important to understand several core concepts:

### Step 1: Prepare Task Dataset

Each `Task` is a Python dictionary (`Dict[str, Any]`), containing various parameters needed to initialize the `Workflow`. Due to significant differences in initialization parameters across different `Workflows`, the following example uses a math problem scenario.
Each `Task` contains various parameters needed to initialize the `Workflow`. Due to significant differences in initialization parameters across different `Workflows`, the following example uses a math problem scenario.

In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line’s JSON contains `question` and `answer` fields representing the problem description and standard answer, respectively.

```json
```
{"question": "1+1=", "answer": "2"}
{"question": "2+2=", "answer": "4"}
...
@@ -48,25 +50,45 @@
The core of creating a new task type is writing a new `Workflow`, whose base class interface is as follows:

```python
from abc import ABC, abstractmethod
from typing import List, Optional
# import some packages

class Workflow(ABC):

def __init__(self, model: ModelWrapper, **kwargs):
def __init__(
self,
model: ModelWrapper,
task: Task,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
):
self.model = model
self.auxiliary_models = auxiliary_models

@abstractmethod
def run(self) -> List[Experience]:
"""Run the workflow and return a list of Experiences."""
```

Developers can register their own `Workflow` through the `WORKFLOWS.register_module` method, but need to ensure that the name does not conflict with existing `Workflows`.
Developers can register their own `Workflow` through the `WORKFLOWS.register_module` method, but need to ensure that the name does not conflict with existing `Workflow` classes.

```python
# import some packages
from trinity.common.workflows.workflow import WORKFLOWS

@WORKFLOWS.register_module("my_workflow")
class MyWorkflow(Workflow):
pass
```

#### Initialization Parameters
When initializing, `Workflow` receives the following parameters:
- `model`: Provides an API call interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`).
- `kwargs`: Reads one line of data from the `Task` dataset, allowing developers to initialize internal modules such as Agent and Environment within the `Workflow` based on these parameters.
- `model`: The model being trained. It provides an interface similar to OpenAI's, accepting a list of conversation messages and returning the content generated by the LLM, including the reply text `response_text`, the full-sequence token ids `tokens`, the prompt token length `prompt_length`, and the output token log probabilities `logprobs`.
- `task`: A `Task` instance generated from one line of the `Task` dataset. Its `raw_task` field holds the source data as a `Dict`, which can be used to construct the `Workflow` instance, and its `rollout_args` field holds the rollout parameters, such as `n`, `temperature`, `top_k`, and `top_p`.
- `auxiliary_models`: A list of auxiliary models that will not be trained. Each one exposes an OpenAI-compatible API (see the sketch below).
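
The snippet below is a minimal sketch of how a custom `Workflow` might query an auxiliary model. It assumes each entry of `auxiliary_models` is an `openai.OpenAI` client serving a single model whose name can be read via `models.list()`; the helper name `ask_auxiliary` is made up for illustration.

```python
# Hypothetical helper method on a custom Workflow subclass (illustration only).
def ask_auxiliary(self, question: str) -> str:
    client = self.auxiliary_models[0]  # an openai.OpenAI instance
    model_name = client.models.list().data[0].id  # assumption: first (and only) served model
    completion = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": question}],
    )
    return completion.choices[0].message.content
```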

```{tip}
The `model` also provides an OpenAI-compatible API. You can enable it by setting `explorer.enable_openai_api` to `true` in your config file and then call `model.get_openai_client()` to obtain an `openai.OpenAI` instance.
```
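
For example, here is a minimal sketch of using this client inside a workflow's `run` method, assuming `explorer.enable_openai_api` is set to `true` and that the served model name can be taken from `models.list()` (the prompt text is only a placeholder):

```python
# Minimal sketch: calling the trained model through its OpenAI-compatible API.
client = self.model.get_openai_client()  # returns an openai.OpenAI instance
model_name = client.models.list().data[0].id  # assumption: the served model is the one being trained
response = client.chat.completions.create(
    model=model_name,
    messages=[{"role": "user", "content": "1+1="}],
    n=1,
    logprobs=True,
)
answer_text = response.choices[0].message.content
```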

#### Example Code
Below is a simple example demonstrating how to implement a math problem `Workflow`:
@@ -75,10 +97,16 @@
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):

def __init__(self, model: ModelWrapper, **kwargs):
super().__init__(model)
self.question = kwargs.get("question")
self.answer = kwargs.get("answer")
def __init__(self, model: ModelWrapper, task: Task, **kwargs):
        super().__init__(model, task, **kwargs)
self.question = task.raw_task.get("question")
self.answer = task.raw_task.get("answer")

def calculate_reward(self, response: str, truth: str) -> float:
if response == truth:
return 1.0
else:
return 0.0

def run(self) -> List[Experience]:
response = self.model.chat(
@@ -87,15 +115,19 @@ class ExampleWorkflow(Workflow):
"role": "user",
"content": f"Question:\n{self.question}",
}
]
],
n=self.task.rollout_args.repeat_times,
temperature=self.task.rollout_args.temperature,
)
reward: float = calculate_reward(response.response_text, self.answer)
return [Experience(
tokens=response.tokens,
prompt_length=response.prompt_length,
reward=reward,
logprobs=response.logprobs,
)]
reward: float = self.calculate_reward(response.response_text, self.answer)
return [
Experience(
tokens=response.tokens,
prompt_length=response.prompt_length,
reward=reward,
logprobs=response.logprobs,
)
]
```

---
53 changes: 47 additions & 6 deletions tests/common/vllm_test.py
@@ -5,7 +5,7 @@
from transformers import AutoTokenizer

from tests.tools import RayUnittestBase, get_template_config
from trinity.common.models import create_rollout_models
from trinity.common.models import create_inference_models
from trinity.common.models.model import ModelWrapper
from trinity.common.models.utils import (
tokenize_and_mask_messages_default,
@@ -127,6 +127,7 @@ def test_generate(self):
)
self.assertTrue(torch.equal(result_dict["assistant_masks"][0], exp.action_mask))
self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens))
self.assertRaises(ValueError, self.model_wrapper.get_openai_client)


class TestModelWrapperSyncV0(BaseTestModelWrapper, RayUnittestBase):
@@ -139,7 +140,7 @@ def setUp(self):
self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
self.config.explorer.use_v1 = False
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")


@@ -153,7 +154,7 @@ def setUp(self):
self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
self.config.explorer.use_v1 = False
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")


@@ -166,7 +167,7 @@ def setUp(self):
self.config.explorer.tensor_parallel_size = 2
self.config.explorer.use_v1 = False
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")


@@ -180,7 +181,7 @@ def setUp(self):
self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
self.config.explorer.use_v1 = True
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")


@@ -193,10 +194,50 @@ def setUp(self):
self.config.explorer.tensor_parallel_size = 1
self.config.explorer.use_v1 = True
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")


class TestAPIServer(RayUnittestBase):
def setUp(self):
self.config = get_template_config()
self.config.model.model_path = get_model_path()
self.config.explorer.engine_type = "vllm_async"
self.config.explorer.engine_num = 1
self.config.explorer.tensor_parallel_size = 1
self.config.explorer.use_v1 = True
self.config.explorer.chat_template = CHAT_TEMPLATE
self.config.explorer.enable_openai_api = True
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")

def test_api(self):
openai_client = self.model_wrapper.get_openai_client()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is your name?"},
]
response = openai_client.chat.completions.create(
model=self.config.model.model_path, messages=messages, n=1
)
print(response)
self.assertEqual(1, len(response.choices))
self.assertTrue(len(response.choices[0].message.content) > 0)
response = openai_client.chat.completions.create(
model=self.config.model.model_path,
messages=messages,
n=2,
temperature=0.5,
logprobs=True,
top_logprobs=0,
)
print(response)
self.assertEqual(2, len(response.choices))
self.assertTrue(response.choices[0].logprobs is not None)
self.assertEqual(0, len(response.choices[0].logprobs.content[0].top_logprobs))
self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0)


class TestTokenizer(unittest.TestCase):
def test_assistant_token_mask(self):
messages = [
2 changes: 1 addition & 1 deletion tests/explorer/runner_pool_test.py
@@ -21,7 +21,7 @@

@WORKFLOWS.register_module("dummy_workflow")
class DummyWorkflow(Workflow):
def __init__(self, model, task):
def __init__(self, model, task, auxiliary_models):
super().__init__(model, task)
self.error_type = task.task_desc
self.seconds = None
36 changes: 32 additions & 4 deletions trinity/common/config.py
@@ -137,6 +137,25 @@ class ModelConfig:
enable_thinking: bool = False


@dataclass
class InferenceModelConfig:
# TODO: support setting engine_num
model_path: str = ""
tensor_parallel_size: int = 1
use_v1: bool = True
max_prompt_tokens: int = 2048
max_response_tokens: int = 2048
enable_thinking: bool = False
enforce_eager: bool = True
enable_prefix_caching: bool = False
enable_chunked_prefill: bool = False
gpu_memory_utilization: float = 0.9
dtype: str = "bfloat16"
seed: int = 42
chat_template: Optional[str] = None
bundle_indices: str = "" # DO NOT SET this field


@dataclass
class ClusterConfig:
"""Config for the cluster."""
@@ -185,10 +204,10 @@ class BufferConfig:
class ExplorerConfig:
"""Config for explorer."""

# inference engine type, `vllm` or `vllm_async`
engine_type: str = "vllm"
# rollout engine type, `vllm` or `vllm_async`
engine_type: str = "vllm_async"

# number of inference engines
# number of rollout engines
engine_num: int = 1

# number of workflow runners.
@@ -199,7 +218,8 @@ class ExplorerConfig:
# for rollout tokneize
chat_template: Optional[str] = None

# for vLLM
# TODO: move vllm rollout model related args into
# `explorer.rollout_model: InferenceModelConfig`
tensor_parallel_size: int = 1
enable_prefix_caching: bool = False
enforce_eager: bool = True
@@ -210,6 +230,7 @@ class ExplorerConfig:
gpu_memory_utilization: float = 0.9
enable_chunked_prefill: bool = False
use_v1: bool = True
enable_openai_api: bool = False
bundle_indices: str = "" # DO NOT SET this field

# for workflow runner
@@ -218,6 +239,9 @@ class ExplorerConfig:
max_timeout: int = 900 # wait each task for 15 minutes
max_retry_times: int = 2 # retry each task for 2 times if it fails or timeout

# for other models used in the custom workflows
auxiliary_models: List[InferenceModelConfig] = field(default_factory=list)


@dataclass
class TrainerConfig:
@@ -453,6 +477,10 @@ def check_and_update(self) -> None:  # noqa: C901
if not self.model.critic_model_path:
self.model.critic_model_path = self.model.model_path

# check explorer
        if self.explorer.engine_type != "vllm_async" and self.explorer.enable_openai_api:
            raise ValueError("OpenAI API server only supports the `vllm_async` engine.")

# check synchronizer
self.synchronizer.explorer_world_size = (
self.explorer.engine_num * self.explorer.tensor_parallel_size