Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/unittest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ jobs:
run: |
REPORT=report.json
if [ -f "$REPORT" ]; then
jq '(.results.tests[] | .start, .stop) |= (. * 1000) | (.results.summary.start, .results.summary.stop) |= (. * 1000)' "$REPORT" > "$REPORT.tmp" && mv "$REPORT.tmp" "$REPORT"
jq '(.results.summary.start, .results.summary.stop) |= (. * 1000)' "$REPORT" > "$REPORT.tmp" && mv "$REPORT.tmp" "$REPORT"
fi

- name: Clean checkpoint dir
Expand Down
72 changes: 66 additions & 6 deletions docs/sphinx_doc/source/tutorial/develop_operator.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ Developers can implement and use their own operators by following the steps belo

### Step 1: Implement Operator

The `ExperienceOperator` interface includes only one `process` method. The `ExperiencePipeline` will call this method with a list of `Experience` generated by the `Explorer` in one explore step. The `process` method should return a tuple containing the processed list of `Experience` and a dictionary of metrics for logging.
The `ExperienceOperatorV1` interface includes only one `process` method. The `ExperiencePipeline` will call this method with a list of `Experience` generated by the `Explorer` in one explore step. The `process` method should return a tuple containing the processed list of `Experience` and a dictionary of metrics for logging.

```python
class ExperienceOperator(ABC):
class ExperienceOperatorV1(ABC):

@abstractmethod
def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
async def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
"""Process a list of experiences and return a transformed list.

Args:
Expand All @@ -40,16 +40,16 @@ class ExperienceOperator(ABC):
Here is an implementation of a simple operator that filters out experiences with rewards below a certain threshold:

```python
from trinity.buffer.operators import ExperienceOperator
from trinity.buffer.operators import ExperienceOperatorV1
from trinity.common.experience import Experience


class RewardFilter(ExperienceOperator):
class RewardFilter(ExperienceOperatorV1):

def __init__(self, threshold: float = 0.0) -> None:
self.threshold = threshold

def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
async def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
filtered_exps = [exp for exp in exps if exp.reward >= self.threshold]
metrics = {"filtered_count": len(exps) - len(filtered_exps)}
return filtered_exps, metrics
Expand Down Expand Up @@ -80,3 +80,63 @@ synchronizer:
The `RewardFilter` reduces the number of experiences, which may prevent the trainer from getting enough experiences to start a training step. To avoid this issue, you can use the advanced {ref}`Dynamic Synchronization <Synchronizer>` feature provided by Trinity-RFT as shown in the above configuration file.
The above setting means that the `Explorer` will sync with the `Trainer` every 2 steps and will continue running regardless of how many steps the `Trainer` has completed. This ensures that the `Trainer` can always get enough experiences to start a training step as long as the `Explorer` is running.
```

### Advanced Features

#### Using Auxiliary Models inside Operators

As introduced in [Workflow Development Guide](develop_workflow.md), Trinity-RFT supports deploying auxiliary models and calling them through the OpenAI API. This feature can also be used in operators, which allows you to leverage powerful models to judge and process experiences. This is particularly useful for implementing operators that require complex reasoning or natural language understanding.

Suppose you have the following auxiliary model configuration in the YAML file:

```yaml
explorer:
auxiliary_models:
- model_path: Qwen/Qwen2.5-32B-Instruct
name: qwen2.5-32B
engine_num: 1
tensor_parallel_size: 2
enable_thinking: false
max_prompt_tokens: 12288
max_response_tokens: 12288
max_model_len: 16384
- model_path: Qwen/Qwen3-8B
name: qwen3-8B
engine_num: 2
tensor_parallel_size: 1
enable_thinking: false
max_prompt_tokens: 12288
max_response_tokens: 12288
max_model_len: 16384
```

Trinity-RFT will automatically inject the deployed auxiliary models into the operators as `self.auxiliary_models`, which is a dictionary mapping model names to model instances (`Dict[str, List[openai.AsyncOpenAI]]`).
The key is the `name` of the model in the configuration file, and the number of instances of each model in the list is determined by `engine_num`.
You can call the model's inference API in the `process` method of the operator to get the model's response based on the experience data. Below is an example of how to use the auxiliary model in an operator:


```python
from trinity.buffer.operators import ExperienceOperatorV1
from trinity.common.experience import Experience


class OperatorWithModel(ExperienceOperatorV1):

async def judge_experience(self, exp: Experience) -> bool:
# Extract necessary information from the experience and prepare input for the model
# messages = ...
# Call the model's inference API to get the response
response = await self.auxiliary_models["qwen2.5-32B"][0].chat.completions.create(
model=self.auxiliary_models["qwen2.5-32B"][0].model_path, # Trinity-RFT will automatically set the `model_path` for easy calling
messages=messages,
)
# Process the model's response and update the experience accordingly
# ...
return exp

async def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
# Use the model to process experiences
# For example, you can call the model's generate method to get responses based on the experience data
await asyncio.gather(*(self.judge_experience(exp) for exp in exps))
return exps, {}
```
69 changes: 63 additions & 6 deletions docs/sphinx_doc/source_zh/tutorial/develop_operator.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ Operator 模块负责处理由 Explorer 所生成的轨迹数据(我们称之

### 步骤 1:实现数据处理算子

`ExperienceOperator` 接口仅包含一个 `process` 方法。`ExperiencePipeline` 将调用此方法,传入 `Explorer` 在一次探索步骤中生成的一组 `Experience`。`process` 方法应返回一个元组,包含处理后的 `Experience` 列表和用于日志记录的指标字典。
`ExperienceOperatorV1` 接口仅包含一个 `process` 方法。`ExperiencePipeline` 将调用此方法,传入 `Explorer` 在一次探索步骤中生成的一组 `Experience`。`process` 方法应返回一个元组,包含处理后的 `Experience` 列表和用于日志记录的指标字典。

```python
class ExperienceOperator(ABC):
class ExperienceOperatorV1(ABC):

@abstractmethod
def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
async def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
"""Process a list of experiences and return a transformed list.

Args:
Expand All @@ -41,16 +41,16 @@ class ExperienceOperator(ABC):
以下是一个简单数据处理算子的实现示例,该算子过滤掉奖励低于某一阈值的 experience:

```python
from trinity.buffer.operators import ExperienceOperator
from trinity.buffer.operators import ExperienceOperatorV1
from trinity.common.experience import Experience


class RewardFilter(ExperienceOperator):
class RewardFilter(ExperienceOperatorV1):

def __init__(self, threshold: float = 0.0) -> None:
self.threshold = threshold

def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
async def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
filtered_exps = [exp for exp in exps if exp.reward >= self.threshold]
metrics = {"filtered_count": len(exps) - len(filtered_exps)}
return filtered_exps, metrics
Expand Down Expand Up @@ -89,3 +89,60 @@ synchronizer:
`RewardFilter` 会减少 experience 数量,可能导致 Trainer 无法获得足够的 experience 来启动训练流程。为避免此问题,你可以使用 Trinity-RFT 提供的 {ref}`动态同步 <Synchronizer>` 功能 (`explorer_driven`)。
上述设置意味着 `Explorer` 每运行 2 步就会与 `Trainer` 同步一次,且无论 `Trainer` 当前完成了多少步都会继续运行。这确保了只要 `Explorer` 在运行,`Trainer` 就总能获得足够的 experience 来启动训练步骤。
```

### 进阶特性

#### 在 Operator 中使用辅助模型

如 [工作流开发指南](develop_workflow.md) 所介绍,Trinity-RFT 支持部署辅助模型并通过 OpenAI API 调用它们。该特性同样可以在 Operator 中使用,使你能够利用强大的模型对 experience 进行判断和处理。这对于实现需要复杂推理或自然语言理解的数据处理算子尤其有用。

假设你在 YAML 配置文件中有如下辅助模型配置:

```yaml
explorer:
auxiliary_models:
- model_path: Qwen/Qwen2.5-32B-Instruct
name: qwen2.5-32B
engine_num: 1
tensor_parallel_size: 2
enable_thinking: false
max_prompt_tokens: 12288
max_response_tokens: 12288
max_model_len: 16384
- model_path: Qwen/Qwen3-8B
name: qwen3-8B
engine_num: 2
tensor_parallel_size: 1
enable_thinking: false
max_prompt_tokens: 12288
max_response_tokens: 12288
max_model_len: 16384
```

Trinity-RFT 会自动将已部署的辅助模型以 `self.auxiliary_models` 的形式注入到 Operator 中。该属性是一个字典,键为配置文件中模型的 `name`,值为模型实例列表(`Dict[str, List[openai.AsyncOpenAI]]`),每个模型实例的数量由 `engine_num` 决定。

你可以在算子的 `process` 方法中调用模型的推理 API,根据 experience 数据获得模型的响应。下面是一个在 Operator 中使用辅助模型的示例:

```python
from trinity.buffer.operators import ExperienceOperatorV1
from trinity.common.experience import Experience


class OperatorWithModel(ExperienceOperatorV1):

async def judge_experience(self, exp: Experience) -> bool:
# 从 experience 中提取必要信息并准备模型输入
# messages = ...
# 调用模型推理 API 获取响应
response = await self.auxiliary_models["qwen2.5-32B"][0].chat.completions.create(
model=self.auxiliary_models["qwen2.5-32B"][0].model_path, # Trinity-RFT 会自动设置 model_path,便于调用
messages=messages,
)
# 处理模型响应并根据需要更新 experience
# ...
return exp

async def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
await asyncio.gather(*(self.judge_experience(exp) for exp in exps))
return exps, {}
```
10 changes: 5 additions & 5 deletions tests/buffer/reward_shaping_mapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from trinity.buffer.pipelines.experience_pipeline import ExperienceOperator
from trinity.buffer.operators.experience_operator import create_operators
from trinity.common.config import OperatorConfig
from trinity.common.experience import EID, Experience

Expand All @@ -29,8 +29,8 @@ def get_experiences(task_num: int, repeat_times: int = 1, step_num: int = 1) ->
]


class TestRewardShapingMapper(unittest.TestCase):
def test_basic_usage(self):
class TestRewardShapingMapper(unittest.IsolatedAsyncioTestCase):
async def test_basic_usage(self):
# test input cache
op_configs = [
OperatorConfig(
Expand All @@ -51,7 +51,7 @@ def test_basic_usage(self):
},
)
]
ops = ExperienceOperator.create_operators(op_configs)
ops = create_operators(op_configs)
self.assertEqual(len(ops), 1)

op = ops[0]
Expand All @@ -61,7 +61,7 @@ def test_basic_usage(self):
experiences = get_experiences(
task_num=task_num, repeat_times=repeat_times, step_num=step_num
)
res_exps, metrics = op.process(deepcopy(experiences))
res_exps, metrics = await op.process(deepcopy(experiences))
self.assertEqual(len(res_exps), task_num * repeat_times * step_num)
self.assertIn("reward_diff/mean", metrics)
self.assertIn("reward_diff/min", metrics)
Expand Down
5 changes: 4 additions & 1 deletion tests/cli/launcher_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def setUp(self):
def tearDown(self):
sys.argv = self._orig_argv
shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)
import trinity.utils.log as log

log._ray_logger_ctx.set(None)
log._ray_logger = None

@mock.patch("trinity.cli.launcher.serve")
@mock.patch("trinity.cli.launcher.explore")
Expand Down Expand Up @@ -262,7 +266,6 @@ def test_multi_stage_run(
"/path/to/hf/checkpoint",
)

@unittest.skip("TODO: fix")
@mock.patch("trinity.cli.launcher.load_config")
def test_debug_mode(self, mock_load):
process = multiprocessing.Process(target=debug_inference_model_process)
Expand Down
Loading