2 changes: 2 additions & 0 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -223,6 +223,8 @@ The configuration for each task dataset is defined as follows:
- `temperature`: The temperature for sampling.
- `default_workflow_type`: Type of workflow logic applied to this dataset. If not specified, the `buffer.default_workflow_type` is used.
- `default_reward_fn_type`: Reward function used during exploration. If not specified, the `buffer.default_reward_fn_type` is used.
- `workflow_args`: A dictionary of arguments used to supplement dataset-level parameters.


### Trainer Input

150 changes: 149 additions & 1 deletion docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
@@ -45,6 +45,15 @@ To handle differences in `Task` contents, Trinity-RFT provides a unified `Task`
- **`raw_task`** (`Dict`): A record of raw data in `Dict` format. For highly customized workflows, you can directly use `raw_task` to initialize your `Workflow` instance without relying on the following fields.
- **`format_args`** ({class}`trinity.common.config.FormatConfig`): Parameters to facilitate the construction of `Workflow` instances. For example, the `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. These settings come from the YAML configuration file and can be set in `buffer.explorer_input.task_set.format`.
- **`rollout_args`** ({class}`trinity.common.config.GenerationConfig`): Parameters that control the rollout process, such as `temperature`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.rollout_args`.
- **`workflow_args`** (`Dict`): A dictionary of parameters to facilitate the construction of `Workflow` instances. Being a free-form dictionary, it offers more flexibility than `format_args` and `rollout_args`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.workflow_args`. Normally, you do not need to set this field.

```{tip}
`workflow`, `workflow_args` and `raw_task` provide different levels of customization.

- `workflow` provides the global settings for all tasks that use the same workflow. (Global Level)
- `workflow_args` can be set per task dataset, allowing different task datasets that use the same workflow to behave differently. (Dataset Level)
- `raw_task` customizes the behavior of each individual task, which is the most flexible option. (Data Sample Level)
```
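
For illustration, here is a minimal sketch of how a custom workflow can combine all three levels. This is not a built-in workflow; the `question` and `output_format` keys are made up for this example.

```python
from trinity.common.workflows import Workflow
from trinity.common.workflows.workflow import Task


class FlexibleWorkflow(Workflow):
    def __init__(self, model, task: Task, auxiliary_models=None):
        super().__init__(model, task, auxiliary_models)
        # Data Sample Level: per-sample content from the raw record
        self.question = task.raw_task.get("question")
        # Dataset Level: knobs shared by every task in this task set
        self.output_format = (task.workflow_args or {}).get("output_format", "text")
        # Rollout parameters such as temperature still come from rollout_args
        self.temperature = task.rollout_args.temperature
```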

In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line contains JSON with `question` and `answer` fields representing the problem description and standard answer, respectively. For example:

@@ -111,7 +120,7 @@ During initialization, `Workflow` receives the following parameters:
You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
```

Heres an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization.
Here's an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization.

```python
class ExampleWorkflow(Workflow):
@@ -188,6 +197,25 @@ class ExampleWorkflow(Workflow):
pass
```

For workflows intended to be contributed to the Trinity-RFT project, place the above code in the `trinity/common/workflows` folder, e.g., `trinity/common/workflows/example_workflow.py`, and add the following lines to `trinity/common/workflows/__init__.py`:

```python
# existing import lines
from .example_workflow import ExampleWorkflow

__all__ = [
# existing __all__ lines
"ExampleWorkflow",
]
```

For workflows not intended to be contributed to the Trinity-RFT project, simply place the code in `trinity/plugins`. Trinity-RFT automatically detects and loads all custom modules in this folder.

```{tip}
You can specify the directory where your custom modules are located by setting `--plugin-dir` when starting Trinity-RFT. If you don't specify `--plugin-dir`, Trinity-RFT will use `<Trinity_RFT_ROOT_DIR>/trinity/plugins` as the default directory.
```
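
As a concrete illustration, a plugin module placed in `trinity/plugins` can register itself with the workflow registry. The sketch below mirrors the pattern used by the unit tests; the class body is a placeholder rather than a real rollout:

```python
# trinity/plugins/my_workflow.py -- picked up automatically at startup
from typing import List

from trinity.common.workflows import WORKFLOWS, Workflow


@WORKFLOWS.register_module("my_workflow")
class MyWorkflow(Workflow):
    def __init__(self, model, task, auxiliary_models=None):
        super().__init__(model, task, auxiliary_models)

    def run(self) -> List:
        # Replace with real rollout logic that returns responses/experiences.
        return ["Hello world"]
```

Once loaded, the workflow can be referenced by its registered name (here `my_workflow`), e.g., through `buffer.explorer_input.task_set.default_workflow_type` in the YAML config.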


#### Avoid Re-initialization

For heavy workflows, re-initializing the workflow for every task can incur extra computational cost.
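
As a rough sketch of the reuse pattern, based on the `resettable` property and `reset` method exercised in the test suite (the expensive-setup placeholder is hypothetical), a workflow can declare itself resettable and refresh only its per-task state:

```python
from trinity.common.workflows import Workflow
from trinity.common.workflows.workflow import Task


class ReusableWorkflow(Workflow):
    def __init__(self, model, task: Task, auxiliary_models=None):
        super().__init__(model, task, auxiliary_models)
        self.heavy_resource = ...  # hypothetical expensive setup, done only once
        self.reset(task)

    @property
    def resettable(self) -> bool:
        # Signals that this instance can be reused for subsequent tasks
        return True

    def reset(self, task: Task):
        # Refresh only the cheap per-task state; heavy resources stay alive
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
```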
@@ -286,6 +314,126 @@ trinity run --config <your_yaml_file>

---

## Adding New Config Entries for the Config Generator (Advanced)

### Step 0: Understanding Streamlit

Before adding new parameters to the Config Generator page, familiarize yourself with the relevant APIs and mechanisms of [Streamlit](https://docs.streamlit.io/develop/api-reference). This project primarily uses Streamlit's input widgets and stores user-entered parameters in `st.session_state`.
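
For orientation, the core pattern is simply that a widget created with `key=...` stores its value in `st.session_state`, where later code can read it back. A standalone sketch (not project code):

```python
import streamlit as st

# The widget writes its current value into st.session_state["demo_batch_size"].
st.number_input("Batch Size", min_value=1, value=32, key="demo_batch_size")

# Any later code can read (or overwrite) the stored value.
st.write("Current batch size:", st.session_state["demo_batch_size"])
```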

### Step 1: Implement New Config Entries

To illustrate the process of creating a new parameter setting for the Config Generator page, we will use `train_batch_size` as an example.

1. Determine the appropriate scope for the parameter. Currently, parameters are categorized into four files:
- `trinity/manager/config_registry/buffer_config_manager.py`
- `trinity/manager/config_registry/explorer_config_manager.py`
- `trinity/manager/config_registry/model_config_manager.py`
- `trinity/manager/config_registry/trainer_config_manager.py`

In this case, `train_batch_size` should be placed in the `buffer_config_manager.py` file.

2. Create a parameter-setting function using Streamlit. The function name must start with `set_`; the remainder of the name becomes the config name.

3. Decorate the parameter setting function with the `CONFIG_GENERATORS.register_config` decorator. This decorator requires the following information:
- Default value of the parameter
- Visibility condition (if applicable)
- Additional config parameters (if needed)

```{note}
The `CONFIG_GENERATORS.register_config` decorator automatically passes `key=config_name` as an argument to the registered configuration function. Ensure that your function accepts this keyword argument.
```

For `train_batch_size`, we will use the following settings:
- Default value: 96
- Visibility condition: `lambda: st.session_state["trainer_gpu_num"] > 0`
- Additional config: `{"_train_batch_size_per_gpu": 16}`


Here's the complete code for the `train_batch_size` parameter:

```python
@CONFIG_GENERATORS.register_config(
default_value=96,
visible=lambda: st.session_state["trainer_gpu_num"] > 0,
other_configs={"_train_batch_size_per_gpu": 16},
)
def set_train_batch_size(**kwargs):
key = kwargs.get("key")
trainer_gpu_num = st.session_state["trainer_gpu_num"]
st.session_state[key] = (
st.session_state["_train_batch_size_per_gpu"] * st.session_state["trainer_gpu_num"]
)

def on_change():
st.session_state["_train_batch_size_per_gpu"] = max(
st.session_state[key] // st.session_state["trainer_gpu_num"], 1
)

st.number_input(
"Train Batch Size",
min_value=trainer_gpu_num,
step=trainer_gpu_num,
help=_str_for_train_batch_size(),
on_change=on_change,
**kwargs,
)
```

If the parameter requires validation, create a check function. For `train_batch_size`, we need to ensure it is divisible by `trainer_gpu_num`. If not, a warning should be displayed, and the parameter should be added to `unfinished_fields`.

Decorate the check function with the `CONFIG_GENERATORS.register_check` decorator:

```python
@CONFIG_GENERATORS.register_check()
def check_train_batch_size(unfinished_fields: set, key: str):
if st.session_state[key] % st.session_state["trainer_gpu_num"] != 0:
unfinished_fields.add(key)
st.warning(_str_for_train_batch_size())
```

```{note}
The `CONFIG_GENERATORS.register_check` decorator automatically receives `key=config_name` and `unfinished_fields=self.unfinished_fields` as arguments. Ensure your function accepts these keyword arguments.
```

### Step 2: Integrating New Parameters into `config_manager.py`

To successfully integrate new parameters into the `config_manager.py` file, please adhere to the following procedure:

1. Parameter Categorization:
Determine the appropriate section for the new parameter based on its functionality. The config generator page is structured into two primary modes:
- Beginner Mode: Comprises "Essential Configs" and "Important Configs" sections.
- Expert Mode: Includes "Model", "Buffer", "Explorer and Synchronizer", and "Trainer" sections.

2. Parameter Addition:
Incorporate the new parameter into the relevant section using the `self.get_configs` method within the `ConfigManager` class.

Example:
```python
class ConfigManager:
def _expert_buffer_part(self):
self.get_configs("total_epochs", "train_batch_size")
```

3. YAML File Integration:
Locate the appropriate position for the new parameter within the YAML file structure. This should be done in the `generate_config` function and its associated sub-functions.

4. Parameter Value Assignment:
Utilize `st.session_state` to retrieve the parameter value from the config generator page and assign it to the corresponding field in the YAML.

Example:
```python
class ConfigManager:
def _gen_buffer_config(self):
buffer_config = {
"batch_size": st.session_state["train_batch_size"],
# Additional configuration parameters
}
```

By following these steps, you ensure that new parameters appear on the Config Generator page and are correctly written into the generated configuration file.

---

## Check Code Style

Before submitting the code, make sure it passes the code style check. Follow these steps:
4 changes: 4 additions & 0 deletions tests/buffer/sql_test.py
@@ -1,6 +1,7 @@
import os
import unittest

import ray
import torch

from trinity.buffer.reader.sql_reader import SQLReader
@@ -22,6 +23,7 @@ def test_create_sql_buffer(self) -> None:
algorithm_type="ppo",
path=f"sqlite:///{db_path}",
storage_type=StorageType.SQL,
wrap_in_ray=True,
)
config = BufferConfig(
max_retry_times=3,
@@ -45,3 +47,5 @@ def test_create_sql_buffer(self) -> None:
for _ in range(total_num // read_batch_size):
exps = sql_reader.read()
self.assertEqual(len(exps), read_batch_size)
db_wrapper = ray.get_actor("sql-test_buffer")
self.assertIsNotNone(db_wrapper)
3 changes: 1 addition & 2 deletions tests/explorer/explorer_test.py
@@ -12,7 +12,6 @@
get_unittest_dataset_config,
)
from trinity.cli.launcher import explore
from trinity.common.constants import MonitorType


class BaseExplorerCase(RayUnittestBase):
@@ -23,7 +22,7 @@ def setUp(self):
self.config.model.model_path = get_model_path()
self.config.explorer.rollout_model.engine_type = "vllm_async"
self.config.algorithm.repeat_times = 2
self.config.monitor.monitor_type = MonitorType.TENSORBOARD
self.config.monitor.monitor_type = "tensorboard"
self.config.project = "Trinity-unittest"
self.config.checkpoint_root_dir = get_checkpoint_path()
self.config.synchronizer.sync_interval = 2
6 changes: 3 additions & 3 deletions tests/explorer/runner_pool_test.py
@@ -2,7 +2,7 @@
import os
import time
import unittest
from typing import List
from typing import List, Tuple

import ray
import torch
@@ -87,8 +87,8 @@ def init_process_group(
def has_api_server(self) -> bool:
return True

def api_server_ready(self) -> str:
return "http://localhosts:12345"
def api_server_ready(self) -> Tuple[str, str]:
return "http://localhosts:12345", "placeholder"


class RunnerPoolTest(unittest.TestCase):
Expand Down
44 changes: 43 additions & 1 deletion tests/explorer/workflow_test.py
@@ -5,7 +5,7 @@
from unittest.mock import MagicMock

from tests.tools import get_unittest_dataset_config
from trinity.common.workflows import MathWorkflow
from trinity.common.workflows import MathWorkflow, Workflow
from trinity.common.workflows.workflow import Task


@@ -15,6 +15,33 @@ class MockResponse:
reward: float = 0.0


class DummyWorkflow(Workflow):
def __init__(self, model, task: Task, auxiliary_models=None):
super().__init__(model, task, auxiliary_models)
self.obj = task.raw_task
self.output_format = task.workflow_args["output_format"]

@property
def resettable(self):
return True

def reset(self, task: Task):
self.obj = task.raw_task
self.output_format = task.workflow_args["output_format"]

def run(self):
if self.output_format == "json":
import json

return [json.dumps(self.obj)]
elif self.output_format == "yaml":
import yaml

return [yaml.safe_dump(self.obj)]
else:
raise ValueError("Invalid output format")


class WorkflowTest(unittest.TestCase):
def test_math_workflow(self) -> None:
model = MagicMock()
@@ -150,3 +177,18 @@ def test_gsm8k_workflow(self) -> None:
self.assertEqual(experiences[1].reward, -0.1)
self.assertEqual(experiences[2].reward, -0.1)
self.assertEqual(experiences[3].reward, 1.1)

def test_workflow_resettable(self) -> None:
model = MagicMock()
json_task = Task(
workflow=DummyWorkflow, raw_task={"a": 1}, workflow_args={"output_format": "json"}
)
yaml_task = Task(
workflow=DummyWorkflow, raw_task={"a": 1}, workflow_args={"output_format": "yaml"}
)
workflow = json_task.to_workflow(model)
answer = workflow.run()
self.assertEqual(answer[0], '{"a": 1}')
workflow.reset(yaml_task)
answer = workflow.run()
self.assertEqual(answer[0], "a: 1\n")
4 changes: 2 additions & 2 deletions tests/trainer/trainer_test.py
@@ -15,7 +15,7 @@
get_unittest_dataset_config,
)
from trinity.cli.launcher import bench, both, train
from trinity.common.constants import MonitorType, SyncMethod
from trinity.common.constants import SyncMethod


class BaseTrainerCase(RayUnittestBase):
@@ -30,7 +30,7 @@ def setUp(self):
self.config.explorer.rollout_model.use_v1 = False
self.config.project = "Trainer-unittest"
self.config.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
self.config.monitor.monitor_type = MonitorType.TENSORBOARD
self.config.monitor.monitor_type = "tensorboard"
self.config.checkpoint_root_dir = get_checkpoint_path()
self.config.synchronizer.sync_interval = 2
self.config.synchronizer.sync_method = SyncMethod.NCCL
Empty file added tests/utils/__init__.py
34 changes: 34 additions & 0 deletions tests/utils/plugin_test.py
@@ -0,0 +1,34 @@
import unittest
from pathlib import Path

import ray

from trinity.common.workflows import WORKFLOWS
from trinity.utils.plugin_loader import load_plugins


@ray.remote
class PluginActor:
def run(self):
my_plugin_cls = WORKFLOWS.get("my_workflow")
return my_plugin_cls(None, None).run()


class TestPluginLoader(unittest.TestCase):
def test_load_plugins(self):
ray.init(ignore_reinit_error=True)
my_plugin_cls = WORKFLOWS.get("my_workflow")
self.assertIsNone(my_plugin_cls)
load_plugins(Path(__file__).resolve().parent / "plugins")
my_plugin_cls = WORKFLOWS.get("my_workflow")
self.assertIsNotNone(my_plugin_cls)
my_plugin = my_plugin_cls(None, None, None)
self.assertTrue(my_plugin.__module__.startswith("trinity.plugins"))
res = my_plugin.run()
self.assertEqual(res[0], "Hello world")
self.assertEqual(res[1], "Hi")
remote_plugin = PluginActor.remote()
remote_res = ray.get(remote_plugin.run.remote())
self.assertEqual(remote_res[0], "Hello world")
self.assertEqual(remote_res[1], "Hi")
ray.shutdown(_exiting_interpreter=True)
Empty file added tests/utils/plugins/__init__.py
12 changes: 12 additions & 0 deletions tests/utils/plugins/my_workflow.py
@@ -0,0 +1,12 @@
from typing import List

from trinity.common.workflows import WORKFLOWS, Workflow


@WORKFLOWS.register_module("my_workflow")
class MyWorkflow(Workflow):
def __init__(self, model, task, auxiliary_models=None):
super().__init__(model, task, auxiliary_models)

def run(self) -> List:
return ["Hello world", "Hi"]
2 changes: 1 addition & 1 deletion trinity/buffer/buffer.py
Expand Up @@ -46,7 +46,7 @@ def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig
file_read_type = algorithm_type
else:
file_read_type = "rollout"
return FILE_READERS.get(file_read_type)(storage_config, buffer_config)
return FILE_READERS.get(file_read_type)(storage_config, buffer_config) # type: ignore
else:
raise ValueError(f"{storage_config.storage_type} not supported.")

Expand Down