1 change: 1 addition & 0 deletions docs/sphinx_doc/source/index.rst
@@ -18,6 +18,7 @@ Welcome to Trinity-RFT's documentation!

tutorial/example_reasoning_basic.md
tutorial/example_reasoning_advanced.md
tutorial/example_async_mode.md
tutorial/example_multi_turn.md
tutorial/example_dpo.md
tutorial/example_data_functionalities.md
74 changes: 53 additions & 21 deletions docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
@@ -2,7 +2,9 @@

This guide will introduce how to add new task types to Trinity-RFT and provide relevant development guidelines.

> **Note**: Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code.
```{note}
Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code.
```

---

@@ -31,11 +33,11 @@ Before starting development, it's important to understand several core concepts:

### Step 1: Prepare Task Dataset

Each `Task` is a Python dictionary (`Dict[str, Any]`), containing various parameters needed to initialize the `Workflow`. Due to significant differences in initialization parameters across different `Workflows`, the following example uses a math problem scenario.
Each `Task` contains various parameters needed to initialize the `Workflow`. Due to significant differences in initialization parameters across different `Workflows`, the following example uses a math problem scenario.

In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line’s JSON contains `question` and `answer` fields representing the problem description and standard answer, respectively.

```json
```
{"question": "1+1=", "answer": "2"}
{"question": "2+2=", "answer": "4"}
...
@@ -48,25 +50,45 @@ In the math problem scenario, the `Task` dataset can be a `jsonl` file, where ea
The core of creating a new task type is writing a new `Workflow`, whose base class interface is as follows:

```python
from abc import ABC
from typing import List
# import some packages

class Workflow(ABC):

def __init__(self, model: ModelWrapper, **kwargs):
def __init__(
self,
model: ModelWrapper,
task: Task,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
):
self.model = model
self.auxiliary_models = auxiliary_models

@abstractmethod
def run(self) -> List[Experience]:
"""Run the workflow and return a list of Experiences."""
```

Developers can register their own `Workflow` through the `WORKFLOWS.register_module` method, but need to ensure that the name does not conflict with existing `Workflows`.
Developers can register their own `Workflow` through the `WORKFLOWS.register_module` method, but need to ensure that the name does not conflict with existing `Workflow` classes.

```python
# import some packages
from trinity.common.workflows.workflow import WORKFLOWS

@WORKFLOWS.register_module("my_workflow")
class MyWorkflow(Workflow):
pass
```

#### Initialization Parameters
When initializing, `Workflow` receives the following parameters:
- `model`: Provides an API call interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`).
- `kwargs`: Reads one line of data from the `Task` dataset, allowing developers to initialize internal modules such as Agent and Environment within the `Workflow` based on these parameters.
- `model`: The model being trained. It provides an interface similar to OpenAI's, receiving a list of conversation messages and returning the content generated by the LLM (including the reply text `response_text`, the full-sequence token ids `tokens`, the prompt token length `prompt_length`, and a list of output token logprobs `logprobs`).
- `task`: An instance of `Task`, generated from one line of the `Task` dataset. Its `raw_task` field contains the source data as a `Dict`, which can be used to construct the `Workflow` instance; its `rollout_args` field contains the parameters for the rollout process, such as `n`, `temperature`, `top_k`, and `top_p`.
- `auxiliary_models`: A list of auxiliary models that will not be trained. Each of them provides an OpenAI-compatible API.

```{tip}
The `model` also provides an OpenAI-compatible API. You can enable it by setting `explorer.enable_openai_api` to `true` in your config file, then call `model.get_openai_client()` to get an `openai.OpenAI` instance.
```
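
For concreteness, here is a minimal sketch (not from the original guide) of how that client might be used inside a workflow's `run` method. It assumes `explorer.enable_openai_api` is set to `true`, and `model_path` is a hypothetical placeholder for the value of `model.model_path` in your config:

```python
# Sketch only: requires `explorer.enable_openai_api: true` in the config.
client = self.model.get_openai_client()  # an `openai.OpenAI` instance
completion = client.chat.completions.create(
    model=model_path,  # assumption: same value as `model.model_path` in your config
    messages=[{"role": "user", "content": f"Question:\n{self.question}"}],
    n=1,
)
reply_text = completion.choices[0].message.content
```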

#### Example Code
Below is a simple example demonstrating how to implement a math problem `Workflow`:
@@ -75,10 +97,16 @@ Below is a simple example demonstrating how to implement a math problem `Workflo
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):

def __init__(self, model: ModelWrapper, **kwargs):
super().__init__(model)
self.question = kwargs.get("question")
self.answer = kwargs.get("answer")
    def __init__(self, model: ModelWrapper, task: Task, **kwargs):
        super().__init__(model, task, **kwargs)
        self.task = task
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")

def calculate_reward(self, response: str, truth: str) -> float:
if response == truth:
return 1.0
else:
return 0.0

def run(self) -> List[Experience]:
response = self.model.chat(
@@ -87,15 +115,19 @@ class ExampleWorkflow(Workflow):
"role": "user",
"content": f"Question:\n{self.question}",
}
]
],
n=self.task.rollout_args.repeat_times,
temperature=self.task.rollout_args.temperature,
)
reward: float = calculate_reward(response.response_text, self.answer)
return [Experience(
tokens=response.tokens,
prompt_length=response.prompt_length,
reward=reward,
logprobs=response.logprobs,
)]
reward: float = self.calculate_reward(response.response_text, self.answer)
return [
Experience(
tokens=response.tokens,
prompt_length=response.prompt_length,
reward=reward,
logprobs=response.logprobs,
)
]
```
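
Exact string matching is brittle for math answers. The following is a hypothetical variant of `calculate_reward` (not part of Trinity-RFT) that strips whitespace and falls back to a numeric comparison when both sides parse as numbers:

```python
def calculate_reward(self, response: str, truth: str) -> float:
    # Hypothetical variant: normalize whitespace, then try a numeric comparison
    # before falling back to exact string equality.
    response, truth = response.strip(), truth.strip()
    try:
        return 1.0 if float(response) == float(truth) else 0.0
    except ValueError:
        return 1.0 if response == truth else 0.0
```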

---
1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ dependencies = [
"flask",
"requests",
"tensorboard",
"openai",
]

[project.scripts]
51 changes: 45 additions & 6 deletions tests/common/vllm_test.py
@@ -5,7 +5,7 @@
from transformers import AutoTokenizer

from tests.tools import RayUnittestBase, get_template_config
from trinity.common.models import create_rollout_models
from trinity.common.models import create_inference_models
from trinity.common.models.model import ModelWrapper
from trinity.common.models.utils import (
tokenize_and_mask_messages_default,
@@ -127,6 +127,7 @@ def test_generate(self):
)
self.assertTrue(torch.equal(result_dict["assistant_masks"][0], exp.action_mask))
self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens))
self.assertRaises(ValueError, self.model_wrapper.get_openai_client)


class TestModelWrapperSyncV0(BaseTestModelWrapper, RayUnittestBase):
@@ -139,7 +140,7 @@ def setUp(self):
self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
self.config.explorer.use_v1 = False
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")


@@ -153,7 +154,7 @@ def setUp(self):
self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
self.config.explorer.use_v1 = False
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")


@@ -166,7 +167,7 @@ def setUp(self):
self.config.explorer.tensor_parallel_size = 2
self.config.explorer.use_v1 = False
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")


@@ -180,7 +181,7 @@ def setUp(self):
self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
self.config.explorer.use_v1 = True
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")


@@ -193,10 +194,48 @@ def setUp(self):
self.config.explorer.tensor_parallel_size = 1
self.config.explorer.use_v1 = True
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")


class TestAPIServer(RayUnittestBase):
def setUp(self):
self.config = get_template_config()
self.config.model.model_path = get_model_path()
self.config.explorer.engine_type = "vllm_async"
self.config.explorer.engine_num = 1
self.config.explorer.tensor_parallel_size = 1
self.config.explorer.use_v1 = True
self.config.explorer.chat_template = CHAT_TEMPLATE
self.config.explorer.enable_openai_api = True
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")

def test_api(self):
openai_client = self.model_wrapper.get_openai_client()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is your name?"},
]
response = openai_client.chat.completions.create(
model=self.config.model.model_path, messages=messages, n=1
)
self.assertEqual(1, len(response.choices))
self.assertTrue(len(response.choices[0].message.content) > 0)
response = openai_client.chat.completions.create(
model=self.config.model.model_path,
messages=messages,
n=2,
temperature=0.5,
logprobs=True,
top_logprobs=0,
)
self.assertEqual(2, len(response.choices))
self.assertTrue(response.choices[0].logprobs is not None)
self.assertEqual(0, len(response.choices[0].logprobs.content[0].top_logprobs))
self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0)


class TestTokenizer(unittest.TestCase):
def test_assistant_token_mask(self):
messages = [
2 changes: 1 addition & 1 deletion tests/explorer/runner_pool_test.py
@@ -21,7 +21,7 @@

@WORKFLOWS.register_module("dummy_workflow")
class DummyWorkflow(Workflow):
def __init__(self, model, task):
def __init__(self, model, task, auxiliary_models):
super().__init__(model, task)
self.error_type = task.task_desc
self.seconds = None
36 changes: 32 additions & 4 deletions trinity/common/config.py
@@ -137,6 +137,25 @@ class ModelConfig:
enable_thinking: bool = False


@dataclass
class InferenceModelConfig:
# TODO: support setting engine_num
model_path: str = ""
tensor_parallel_size: int = 1
use_v1: bool = True
max_prompt_tokens: int = 2048
max_response_tokens: int = 2048
enable_thinking: bool = False
enforce_eager: bool = True
enable_prefix_caching: bool = False
enable_chunked_prefill: bool = False
gpu_memory_utilization: float = 0.9
dtype: str = "bfloat16"
seed: int = 42
chat_template: Optional[str] = None
bundle_indices: str = "" # DO NOT SET this field


@dataclass
class ClusterConfig:
"""Config for the cluster."""
@@ -185,10 +204,10 @@ class BufferConfig:
class ExplorerConfig:
"""Config for explorer."""

# inference engine type, `vllm` or `vllm_async`
engine_type: str = "vllm"
# rollout engine type, `vllm` or `vllm_async`
engine_type: str = "vllm_async"

# number of inference engines
# number of rollout engines
engine_num: int = 1

# number of workflow runners.
@@ -199,7 +218,8 @@ class ExplorerConfig:
    # for rollout tokenization
chat_template: Optional[str] = None

# for vLLM
# TODO: move vllm rollout model related args into
# `explorer.rollout_model: InferenceModelConfig`
tensor_parallel_size: int = 1
enable_prefix_caching: bool = False
enforce_eager: bool = True
@@ -210,6 +230,7 @@ gpu_memory_utilization: float = 0.9
gpu_memory_utilization: float = 0.9
enable_chunked_prefill: bool = False
use_v1: bool = True
enable_openai_api: bool = False
bundle_indices: str = "" # DO NOT SET this field

# for workflow runner
@@ -218,6 +239,9 @@ max_timeout: int = 900 # wait each task for 15 minutes
max_timeout: int = 900 # wait each task for 15 minutes
max_retry_times: int = 2 # retry each task for 2 times if it fails or timeout

# for other models used in the custom workflows
auxiliary_models: List[InferenceModelConfig] = field(default_factory=list)


@dataclass
class TrainerConfig:
@@ -453,6 +477,10 @@ def check_and_update(self) -> None: # noqa: C901
if not self.model.critic_model_path:
self.model.critic_model_path = self.model.model_path

# check explorer
if self.explorer.engine_type != "vllm_asyc" and self.explorer.enable_openai_api:
raise ValueError("OpenAI API server only support `vllm_async` engine.")

# check synchronizer
self.synchronizer.explorer_world_size = (
self.explorer.engine_num * self.explorer.tensor_parallel_size