
Commit 9e72996

Fix Conflicts with main (#75)

1 parent 48f596a commit 9e72996


42 files changed: +2451 -1494 lines changed

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 2 additions & 0 deletions
@@ -223,6 +223,8 @@ The configuration for each task dataset is defined as follows:
 - `temperature`: The temperature for sampling.
 - `default_workflow_type`: Type of workflow logic applied to this dataset. If not specified, the `buffer.default_workflow_type` is used.
 - `default_reward_fn_type`: Reward function used during exploration. If not specified, the `buffer.default_reward_fn_type` is used.
+- `workflow_args`: A dictionary of arguments used to supplement dataset-level parameters.
+
 
 ### Trainer Input
 
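As a quick illustration of where the new field sits, here is a sketch of a `task_set` entry expressed as a Python dict mirroring the YAML structure (paths and values are illustrative assumptions, not part of this commit):

```python
# Hypothetical task_set entry (normally written in YAML under
# buffer.explorer_input.task_set); only keys described in the docs above.
task_set = {
    "path": "/path/to/math_tasks.jsonl",         # assumed dataset path
    "default_workflow_type": "math_workflow",    # workflow applied to this dataset
    "default_reward_fn_type": "math_reward",     # reward fn used during exploration
    "workflow_args": {"output_format": "json"},  # supplements dataset-level parameters
}
```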
docs/sphinx_doc/source/tutorial/trinity_programming_guide.md

Lines changed: 149 additions & 1 deletion
@@ -45,6 +45,15 @@ To handle differences in `Task` contents, Trinity-RFT provides a unified `Task`
 - **`raw_task`** (`Dict`): A record of raw data in `Dict` format. For highly customized workflows, you can directly use `raw_task` to initialize your `Workflow` instance without relying on the following fields.
 - **`format_args`** ({class}`trinity.common.config.FormatConfig`): Parameters to facilitate the construction of `Workflow` instances. For example, the `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. These settings come from the YAML configuration file and can be set in `buffer.explorer_input.task_set.format`.
 - **`rollout_args`** ({class}`trinity.common.config.GenerationConfig`): Parameters that control the rollout process, such as `temperature`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.rollout_args`.
+- **`workflow_args`** (`Dict`): A dictionary of parameters to facilitate the construction of `Workflow` instances. It provides more flexibility than `format_args` and `rollout_args`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.workflow_args`. Normally, you do not need to set this field.
+
+```{tip}
+`workflow`, `workflow_args` and `raw_task` provide different levels of customization:
+
+- `workflow` provides the global settings for all tasks that use the same workflow. (Global Level)
+- `workflow_args` can be set for each task dataset, allowing different task datasets that use the same workflow to behave differently. (Dataset Level)
+- `raw_task` allows customizing the behavior of each individual task, which is the most flexible. (Data Sample Level)
+```
 
 In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line contains JSON with `question` and `answer` fields representing the problem description and standard answer, respectively. For example:
 
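To make the three levels concrete, here is a minimal sketch (an illustration under assumptions, not code from this commit) of a workflow that consumes all three:

```python
# The class itself is shared by all tasks using this workflow (global level);
# workflow_args varies per task dataset (dataset level); raw_task varies per
# sample (data sample level). SketchWorkflow and the "style" key are hypothetical.
class SketchWorkflow(Workflow):
    def __init__(self, model, task: Task, auxiliary_models=None):
        super().__init__(model, task, auxiliary_models)
        self.question = task.raw_task["question"]              # sample level
        self.style = task.workflow_args.get("style", "plain")  # dataset level
```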
@@ -111,7 +120,7 @@ During initialization, `Workflow` receives the following parameters:
 You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
 ```
 
-Heres an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization.
+Here's an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization.
 
 ```python
 class ExampleWorkflow(Workflow):
@@ -188,6 +197,25 @@ class ExampleWorkflow(Workflow):
         pass
 ```
 
+For workflows intended to be contributed to the Trinity-RFT project, place the above code in the `trinity/common/workflows` folder, e.g., `trinity/common/workflows/example_workflow.py`, and add the following lines to `trinity/common/workflows/__init__.py`:
+
+```python
+# existing import lines
+from .example_workflow import ExampleWorkflow
+
+__all__ = [
+    # existing __all__ lines
+    "ExampleWorkflow",
+]
+```
+
+For workflows not intended to be contributed to the Trinity-RFT project, you can simply place the above code in `trinity/plugins`. Trinity-RFT will automatically detect and load all custom modules in this folder.
+
+```{tip}
+You can specify the directory where your custom modules are located by setting `--plugin-dir` when starting Trinity-RFT. If you don't specify `--plugin-dir`, Trinity-RFT will use `<Trinity_RFT_ROOT_DIR>/trinity/plugins` as the default directory.
+```
+
+
 #### Avoid Re-initialization
 
 For heavy workflows, re-initializing every time can incur extra computational costs.
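For reference, a plugin module dropped into `trinity/plugins` might look like the sketch below; the file name and the `register_module` decorator are assumptions inferred from the `WORKFLOWS.get("my_workflow")` lookups in `tests/utils/plugin_test.py` further down:

```python
# Hypothetical trinity/plugins/my_workflow.py; assumes the WORKFLOWS registry
# exposes a register_module decorator matching the lookups in the tests.
from trinity.common.workflows import WORKFLOWS
from trinity.common.workflows.workflow import Workflow


@WORKFLOWS.register_module("my_workflow")
class MyWorkflow(Workflow):
    def run(self):
        return ["Hello world"]
```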
@@ -286,6 +314,126 @@ trinity run --config <your_yaml_file>
 
 ---
 
+## Adding New Config Entries for the Config Generator (Advanced)
+
+### Step 0: Understanding Streamlit
+
+Before adding new parameters to the Config Generator page, it is essential to familiarize yourself with the relevant API and mechanisms of [Streamlit](https://docs.streamlit.io/develop/api-reference). This project primarily utilizes various input components from Streamlit and employs `st.session_state` to store user-input parameters.
+
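As a minimal sketch of the two Streamlit mechanisms involved (an input component bound to a key, and `st.session_state` persisting values across script reruns; the key name here is illustrative):

```python
import streamlit as st

# st.session_state survives reruns of the page script; seed a default once
if "trainer_gpu_num" not in st.session_state:
    st.session_state["trainer_gpu_num"] = 8

# a widget bound to the same key reads and writes st.session_state
st.number_input("Trainer GPU Num", min_value=1, key="trainer_gpu_num")
```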
+### Step 1: Implement New Config Entries
+
+To illustrate the process of creating a new parameter setting for the Config Generator page, we will use `train_batch_size` as an example.
+
+1. Determine the appropriate scope for the parameter. Currently, parameters are categorized into four files:
+   - `trinity/manager/config_registry/buffer_config_manager.py`
+   - `trinity/manager/config_registry/explorer_config_manager.py`
+   - `trinity/manager/config_registry/model_config_manager.py`
+   - `trinity/manager/config_registry/trainer_config_manager.py`
+
+   In this case, `train_batch_size` should be placed in `buffer_config_manager.py`.
+
+2. Create a parameter setting function using Streamlit. The function name must start with `set_`; the remainder of the name becomes the config name.
+
+3. Decorate the parameter setting function with the `CONFIG_GENERATORS.register_config` decorator. This decorator requires the following information:
+   - Default value of the parameter
+   - Visibility condition (if applicable)
+   - Additional config parameters (if needed)
+
+```{note}
+The `CONFIG_GENERATORS.register_config` decorator automatically passes `key=config_name` as an argument to the registered configuration function. Ensure that your function accepts this keyword argument.
+```
+
+For `train_batch_size`, we will use the following settings:
+- Default value: 96
+- Visibility condition: `lambda: st.session_state["trainer_gpu_num"] > 0`
+- Additional config: `{"_train_batch_size_per_gpu": 16}`
+
+Here's the complete code for the `train_batch_size` parameter:
+
+```python
+@CONFIG_GENERATORS.register_config(
+    default_value=96,
+    visible=lambda: st.session_state["trainer_gpu_num"] > 0,
+    other_configs={"_train_batch_size_per_gpu": 16},
+)
+def set_train_batch_size(**kwargs):
+    key = kwargs.get("key")
+    trainer_gpu_num = st.session_state["trainer_gpu_num"]
+    st.session_state[key] = (
+        st.session_state["_train_batch_size_per_gpu"] * st.session_state["trainer_gpu_num"]
+    )
+
+    def on_change():
+        st.session_state["_train_batch_size_per_gpu"] = max(
+            st.session_state[key] // st.session_state["trainer_gpu_num"], 1
+        )
+
+    st.number_input(
+        "Train Batch Size",
+        min_value=trainer_gpu_num,
+        step=trainer_gpu_num,
+        help=_str_for_train_batch_size(),
+        on_change=on_change,
+        **kwargs,
+    )
+```
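As a sanity check on these defaults: the setter initializes the widget to `_train_batch_size_per_gpu * trainer_gpu_num` (for example, 16 × 6 = 96 on six trainer GPUs, matching `default_value`), and `on_change` inverts the mapping by dividing the entered batch size by the GPU count.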
+
+If the parameter requires validation, create a check function. For `train_batch_size`, we need to ensure it is divisible by `trainer_gpu_num`. If not, a warning should be displayed, and the parameter should be added to `unfinished_fields`.
+
+Decorate the check function with the `CONFIG_GENERATORS.register_check` decorator:
+
+```python
+@CONFIG_GENERATORS.register_check()
+def check_train_batch_size(unfinished_fields: set, key: str):
+    if st.session_state[key] % st.session_state["trainer_gpu_num"] != 0:
+        unfinished_fields.add(key)
+        st.warning(_str_for_train_batch_size())
+```
+
+```{note}
+The `CONFIG_GENERATORS.register_check` decorator automatically receives `key=config_name` and `unfinished_fields=self.unfinished_fields` as arguments. Ensure your function accepts these keyword arguments.
+```
+
+### Step 2: Integrating New Parameters into `config_manager.py`
+
+To integrate new parameters into the `config_manager.py` file, adhere to the following procedure:
+
+1. Parameter Categorization:
+   Determine the appropriate section for the new parameter based on its functionality. The config generator page is structured into two primary modes:
+   - Beginner Mode: comprises the "Essential Configs" and "Important Configs" sections.
+   - Expert Mode: includes the "Model", "Buffer", "Explorer and Synchronizer", and "Trainer" sections.
+
+2. Parameter Addition:
+   Incorporate the new parameter into the relevant section using the `self.get_configs` method within the `ConfigManager` class.
+
+   Example:
+   ```python
+   class ConfigManager:
+       def _expert_buffer_part(self):
+           self.get_configs("total_epochs", "train_batch_size")
+   ```
+
+3. YAML File Integration:
+   Locate the appropriate position for the new parameter within the YAML file structure. This should be done in the `generate_config` function and its associated sub-functions.
+
+4. Parameter Value Assignment:
+   Use `st.session_state` to retrieve the parameter value from the config generator page and assign it to the corresponding field in the YAML.
+
+   Example:
+   ```python
+   class ConfigManager:
+       def _gen_buffer_config(self):
+           buffer_config = {
+               "batch_size": st.session_state["train_batch_size"],
+               # Additional configuration parameters
+           }
+   ```
+
+Following these steps ensures that new parameters are added to the Config Generator page and properly integrated into the configuration system.
+
+---
+
 ## Check Code Style
 
 Before submitting the code, make sure it passes the code style check. Follow these steps:

tests/buffer/sql_test.py

Lines changed: 4 additions & 0 deletions
@@ -1,6 +1,7 @@
 import os
 import unittest
 
+import ray
 import torch
 
 from trinity.buffer.reader.sql_reader import SQLReader
@@ -22,6 +23,7 @@ def test_create_sql_buffer(self) -> None:
             algorithm_type="ppo",
             path=f"sqlite:///{db_path}",
             storage_type=StorageType.SQL,
+            wrap_in_ray=True,
         )
         config = BufferConfig(
             max_retry_times=3,
@@ -45,3 +47,5 @@ def test_create_sql_buffer(self) -> None:
         for _ in range(total_num // read_batch_size):
            exps = sql_reader.read()
            self.assertEqual(len(exps), read_batch_size)
+        db_wrapper = ray.get_actor("sql-test_buffer")
+        self.assertIsNotNone(db_wrapper)
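(The new `wrap_in_ray=True` flag evidently wraps the SQL storage in a named Ray actor, which is what the added `ray.get_actor("sql-test_buffer")` assertion verifies.)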

tests/explorer/explorer_test.py

Lines changed: 1 addition & 2 deletions
@@ -12,7 +12,6 @@
     get_unittest_dataset_config,
 )
 from trinity.cli.launcher import explore
-from trinity.common.constants import MonitorType
 
 
 class BaseExplorerCase(RayUnittestBase):
@@ -23,7 +22,7 @@ def setUp(self):
         self.config.model.model_path = get_model_path()
         self.config.explorer.rollout_model.engine_type = "vllm_async"
         self.config.algorithm.repeat_times = 2
-        self.config.monitor.monitor_type = MonitorType.TENSORBOARD
+        self.config.monitor.monitor_type = "tensorboard"
         self.config.project = "Trinity-unittest"
         self.config.checkpoint_root_dir = get_checkpoint_path()
         self.config.synchronizer.sync_interval = 2

tests/explorer/runner_pool_test.py

Lines changed: 3 additions & 3 deletions
@@ -2,7 +2,7 @@
 import os
 import time
 import unittest
-from typing import List
+from typing import List, Tuple
 
 import ray
 import torch
@@ -87,8 +87,8 @@ def init_process_group(
     def has_api_server(self) -> bool:
         return True
 
-    def api_server_ready(self) -> str:
-        return "http://localhosts:12345"
+    def api_server_ready(self) -> Tuple[str, str]:
+        return "http://localhosts:12345", "placeholder"
 
 
 class RunnerPoolTest(unittest.TestCase):

tests/explorer/workflow_test.py

Lines changed: 43 additions & 1 deletion
@@ -5,7 +5,7 @@
 from unittest.mock import MagicMock
 
 from tests.tools import get_unittest_dataset_config
-from trinity.common.workflows import MathWorkflow
+from trinity.common.workflows import MathWorkflow, Workflow
 from trinity.common.workflows.workflow import Task
 
 
@@ -15,6 +15,33 @@ class MockResponse:
     reward: float = 0.0
 
 
+class DummyWorkflow(Workflow):
+    def __init__(self, model, task: Task, auxiliary_models=None):
+        super().__init__(model, task, auxiliary_models)
+        self.obj = task.raw_task
+        self.output_format = task.workflow_args["output_format"]
+
+    @property
+    def resettable(self):
+        return True
+
+    def reset(self, task: Task):
+        self.obj = task.raw_task
+        self.output_format = task.workflow_args["output_format"]
+
+    def run(self):
+        if self.output_format == "json":
+            import json
+
+            return [json.dumps(self.obj)]
+        elif self.output_format == "yaml":
+            import yaml
+
+            return [yaml.safe_dump(self.obj)]
+        else:
+            raise ValueError("Invalid output format")
+
+
 class WorkflowTest(unittest.TestCase):
     def test_math_workflow(self) -> None:
         model = MagicMock()
@@ -150,3 +177,18 @@ def test_gsm8k_workflow(self) -> None:
         self.assertEqual(experiences[1].reward, -0.1)
         self.assertEqual(experiences[2].reward, -0.1)
         self.assertEqual(experiences[3].reward, 1.1)
+
+    def test_workflow_resettable(self) -> None:
+        model = MagicMock()
+        json_task = Task(
+            workflow=DummyWorkflow, raw_task={"a": 1}, workflow_args={"output_format": "json"}
+        )
+        yaml_task = Task(
+            workflow=DummyWorkflow, raw_task={"a": 1}, workflow_args={"output_format": "yaml"}
+        )
+        workflow = json_task.to_workflow(model)
+        answer = workflow.run()
+        self.assertEqual(answer[0], '{"a": 1}')
+        workflow.reset(yaml_task)
+        answer = workflow.run()
+        self.assertEqual(answer[0], "a: 1\n")
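(`DummyWorkflow` exercises the `resettable`/`reset` path described in the programming guide's "Avoid Re-initialization" section: the same live instance is reused for the YAML task instead of constructing a new workflow.)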

tests/trainer/trainer_test.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
     get_unittest_dataset_config,
 )
 from trinity.cli.launcher import bench, both, train
-from trinity.common.constants import MonitorType, SyncMethod
+from trinity.common.constants import SyncMethod
 
 
 class BaseTrainerCase(RayUnittestBase):
@@ -30,7 +30,7 @@ def setUp(self):
         self.config.explorer.rollout_model.use_v1 = False
         self.config.project = "Trainer-unittest"
         self.config.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
-        self.config.monitor.monitor_type = MonitorType.TENSORBOARD
+        self.config.monitor.monitor_type = "tensorboard"
         self.config.checkpoint_root_dir = get_checkpoint_path()
         self.config.synchronizer.sync_interval = 2
         self.config.synchronizer.sync_method = SyncMethod.NCCL

tests/utils/__init__.py

Whitespace-only changes.

tests/utils/plugin_test.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+import unittest
+from pathlib import Path
+
+import ray
+
+from trinity.common.workflows import WORKFLOWS
+from trinity.utils.plugin_loader import load_plugins
+
+
+@ray.remote
+class PluginActor:
+    def run(self):
+        my_plugin_cls = WORKFLOWS.get("my_workflow")
+        return my_plugin_cls(None, None).run()
+
+
+class TestPluginLoader(unittest.TestCase):
+    def test_load_plugins(self):
+        ray.init(ignore_reinit_error=True)
+        my_plugin_cls = WORKFLOWS.get("my_workflow")
+        self.assertIsNone(my_plugin_cls)
+        load_plugins(Path(__file__).resolve().parent / "plugins")
+        my_plugin_cls = WORKFLOWS.get("my_workflow")
+        self.assertIsNotNone(my_plugin_cls)
+        my_plugin = my_plugin_cls(None, None, None)
+        self.assertTrue(my_plugin.__module__.startswith("trinity.plugins"))
+        res = my_plugin.run()
+        self.assertEqual(res[0], "Hello world")
+        self.assertEqual(res[1], "Hi")
+        remote_plugin = PluginActor.remote()
+        remote_res = ray.get(remote_plugin.run.remote())
+        self.assertEqual(remote_res[0], "Hello world")
+        self.assertEqual(remote_res[1], "Hi")
+        ray.shutdown(_exiting_interpreter=True)

tests/utils/plugins/__init__.py

Whitespace-only changes.
