
Commit 375ad7d

Refactor Storage (agentscope-ai#227)
1 parent: c5e6dd3 · commit: 375ad7d


77 files changed: +1029 additions, -977 deletions


benchmark/config/countdown-template.yaml

Lines changed: 0 additions & 2 deletions
@@ -46,8 +46,6 @@ buffer:
 priority_fn: linear_decay
 decay: 0.1
 sft_warmup_steps: 0
-max_retry_times: 3
-max_retry_interval: 1
 explorer:
 runner_num: 32
 max_timeout: 900

benchmark/config/gsm8k-template.yaml

Lines changed: 0 additions & 2 deletions
@@ -51,8 +51,6 @@ buffer:
 priority_fn: linear_decay
 decay: 0.1
 sft_warmup_steps: 0
-max_retry_times: 3
-max_retry_interval: 1
 explorer:
 runner_per_model: 8
 max_timeout: 900

docs/sphinx_doc/source/_templates/versions.html

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@
 <div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
 <span class="rst-current-version" data-toggle="rst-current-version">
 <span class="fa fa-book"> Other Versions</span>
-v: {{ current_version.name }}
+<b>{{ current_version.name }}</b>
 <span class="fa fa-caret-down"></span>
 </span>
 <div class="rst-other-versions">
@@ -18,7 +18,7 @@
 <dl>
 <dt>Branches</dt>
 {%- for item in versions.branches %}
-<dd><a href="{{ item.url }}">{{ item.name }}</a> <b>(latest)</b></dd>
+<dd><b><a href="{{ item.url }}">{{ item.name }}</a> (latest)</b></dd>
 {%- endfor %}
 </dl>
 {%- endif %}

docs/sphinx_doc/source/tutorial/example_mix_algo.md

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ class MIXAlgorithm(AlgorithmType):
 use_reference: bool = True
 compute_advantage_in_trainer: bool = False
 can_balance_batch: bool = True
-schema: type = ExperienceModel
+schema: str = "experience"

 @classmethod
 def default_config(cls) -> Dict:

docs/sphinx_doc/source/tutorial/example_step_wise.md

Lines changed: 0 additions & 2 deletions
@@ -107,8 +107,6 @@ buffer:
 total_epochs: 20
 batch_size: 16
 train_batch_size: 7680 # here: batch_size * repeat_times * max_env_steps
-max_retry_times: 3
-max_retry_interval: 1
 explorer_input:
 taskset:
 name: alfworld

docs/sphinx_doc/source/tutorial/faq.md

Lines changed: 1 addition & 2 deletions
@@ -120,7 +120,7 @@ from sqlalchemy import create_engine
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.pool import NullPool
-from trinity.common.schema import ExperienceModel
+from trinity.common.schema.sql_schema import ExperienceModel

 engine = create_engine(buffer.trainer_input.experience_buffer.path)
 session = sessionmaker(bind=engine)
@@ -129,7 +129,6 @@ sess = session()
 MAX_EXPERIENCES = 4
 experiences = (
 sess.query(ExperienceModel)
-.with_for_update()
 .limit(MAX_EXPERIENCES)
 .all()
 )
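
The updated FAQ snippet above boils down to the following standalone sketch: it imports `ExperienceModel` from the new `trinity.common.schema.sql_schema` module and reads a few experiences without the dropped `.with_for_update()` row lock. The SQLite path below is a placeholder assumption; in practice it comes from `buffer.trainer_input.experience_buffer.path`.

```python
# Minimal sketch based on the updated faq.md snippet in this commit.
# The SQLite path is a placeholder; substitute the configured experience buffer path.
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from trinity.common.schema.sql_schema import ExperienceModel  # new import path

engine = create_engine("sqlite:///countdown_buffer.db")
session = sessionmaker(bind=engine)
sess = session()

MAX_EXPERIENCES = 4
# Read a handful of experiences; the row lock (.with_for_update()) was removed in this commit.
experiences = (
    sess.query(ExperienceModel)
    .limit(MAX_EXPERIENCES)
    .all()
)
```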

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 35 additions & 48 deletions
@@ -35,7 +35,7 @@ synchronizer:
 # Model weight synchronization settings
 ...
 monitor:
-# Monitoring configurations (e.g., WandB or TensorBoard)
+# Monitoring configurations (e.g., WandB, TensorBoard or MLFlow)
 ...
 service:
 # Services to use
@@ -48,10 +48,12 @@ log:
 ...
 ```

-Each of these sections will be explained in detail below.
+Each of these sections will be explained in detail below. For additional details about specific parameters not covered here, please refer to the [source code](https://github.com/modelscope/Trinity-RFT/blob/main/trinity/common/config.py).

-```{note}
-For additional details about specific parameters not covered here, please refer to the [source code](https://github.com/modelscope/Trinity-RFT/blob/main/trinity/common/config.py).
+```{tip}
+Trinity-RFT uses [OmegaConf](https://omegaconf.readthedocs.io/en/latest/) to load YAML configuration files.
+It supports some advanced features like [variable interpolation](https://omegaconf.readthedocs.io/en/latest/usage.html#variable-interpolation) and [environment variable substitution](https://omegaconf.readthedocs.io/en/latest/custom_resolvers.html#oc-env).
+Users can use these features to simplify configuration.
 ```

 ---
@@ -64,7 +66,7 @@ These are general settings that apply to the entire experiment.
 project: Trinity-RFT
 name: example
 mode: both
-checkpoint_root_dir: /PATH/TO/CHECKPOINT
+checkpoint_root_dir: ${oc.env:CHECKPOINT_ROOT_DIR} # CHECKPOINT_ROOT_DIR is an environment variable set in advance
 ```

 - `project`: The name of the project.
@@ -115,13 +117,25 @@ Used to log training metrics during execution.
 ```yaml
 monitor:
 monitor_type: wandb
+monitor_args:
+base_url: http://localhost:8080
+api_key: your_api_key
 enable_ray_timeline: False
 ```

 - `monitor_type`: Type of monitoring system. Options:
 - `wandb`: Logs to [Weights & Biases](https://docs.wandb.ai/quickstart/). Requires logging in and setting `WANDB_API_KEY`. Project and run names match the `project` and `name` fields in global configs.
 - `tensorboard`: Logs to [TensorBoard](https://www.tensorflow.org/tensorboard). Files are saved under `<checkpoint_root_dir>/<project>/<name>/monitor/tensorboard`.
-- `enable_ray_timeline`: Whether to export the ray timeline. If set to `True`, a `timeline.json` file will be exported to `<checkpoint_root_dir>/<project>/<name>/monitor`. You can view the timeline file in Chrome at [chrome://tracing](chrome://tracing).
+- `mlflow`: Logs to [MLFlow](https://mlflow.org/). If [MLFlow authentication](https://mlflow.org/docs/latest/ml/auth/) is setup, set `MLFLOW_TRACKING_USERNAME` and `MLFLOW_TRACKING_PASSWORD` as environment variables before running.
+- `monitor_args`: Dictionary of arguments for monitor initialization.
+- For `wandb`:
+- `base_url`: Overrides `WANDB_BASE_URL` if set.
+- `api_key`: Overrides `WANDB_API_KEY` if set.
+- For `mlflow`:
+- `uri`: The URI of your MLFlow instance. Strongly recommended to set; defaults to `http://localhost:5000`.
+- `username`: Overrides `MLFLOW_TRACKING_USERNAME` if set.
+- `password`: Overrides `MLFLOW_TRACKING_PASSWORD` if set.
+- `enable_ray_timeline`: If `True`, exports a `timeline.json` file to `<checkpoint_root_dir>/<project>/<name>/monitor`. Viewable in Chrome at [chrome://tracing](chrome://tracing).

 ---

@@ -131,8 +145,8 @@ Defines the model paths and token limits.

 ```yaml
 model:
-model_path: /PATH/TO/MODEL/
-critic_model_path: ''
+model_path: ${oc.env:MODEL_PATH} # MODEL_PATH is an environment variable set in advance
+critic_model_path: ${model.model_path} # use the value of model.model_path
 max_response_tokens: 16384
 max_model_len: 20480
 ```
@@ -174,10 +188,6 @@ buffer:
 ...
 eval_tasksets:
 ...
-
-explorer_output:
-...
-
 trainer_input:
 experience_buffer:
 ...
@@ -255,41 +265,6 @@ The configuration for each task dataset is defined as follows:
 - `default_reward_fn_type`: Reward function used during exploration. If not specified, the `buffer.default_reward_fn_type` is used.
 - `workflow_args`: A dictionary of arguments used to supplement dataset-level parameters.

-
-### Explorer Output
-
-In [`explore` mode](#global-configuration), since there is no trainer, users can configure an experience buffer via `buffer.explorer_output`, rather than using `buffer.trainer_input`, which will be introduced in the next section.
-
-```{note}
-For `both` and `train` modes, users should use `buffer.trainer_input.experience_buffer` instead of `buffer.explorer_output`.
-```
-
-```yaml
-buffer:
-...
-explorer_output:
-name: countdown_buffer
-storage_type: queue
-path: sqlite:///countdown_buffer.db
-wrap_in_ray: True
-max_read_timeout: 1800
-```
-
-- `name`: The name of the experience buffer. This name will be used as the Ray actor's name, so it must be unique.
-- `storage_type`: The storage type for the experience buffer.
-- `queue`: Experience data is stored in a queue. This storage type is recommended for most use cases.
-- `sql`: Experience data is stored in a SQL database. If your database only supports local access (e.g., SQLite), set `wrap_in_ray` to `True` to wrap the database in a Ray actor, enabling remote access from other nodes.
-- `file`: Experience data is stored in a JSON file. This storage type should be used only for debugging purposes in `explore` mode.
-- `path`: The path to the experience buffer.
-- For `queue` storage type, this field is optional. You can specify a SQLite database or JSON file path here to back up the queue data.
-- For `file` storage type, the path points to the directory containing the dataset files.
-- For `sql` storage type, the path points to the SQLite database file.
-- `wrap_in_ray`: Whether to wrap the experience buffer in a Ray actor. Only take effect when `storage_type` is `sql` or `file`. The `queue` storage always uses a Ray actor.
-- `max_read_timeout`: The maximum waiting time (in seconds) to read new experience data. If exceeded, an incomplete batch will be returned directly. Only take effect when `storage_type` is `queue`. Default is 1800 seconds (30 minutes).
-- `use_priority_queue`: Only take effect when `storage_type` is `queue`. If set to `True`, the queue will be a priority queue, which allows for prioritizing certain experiences over others. Default is `False`.
-- `reuse_cooldown_time`: Only take effect when `storage_type` is `queue` and `use_priority_queue` is `True`. If set, it specifies the cooldown time (in seconds) for reusing experiences. If not specified, the default value is `None`, meaning experiences can not be reused.
-
-
 ### Trainer Input

 Defines the experience buffer and optional SFT warm-up dataset.
@@ -314,7 +289,19 @@ buffer:
 sft_warmup_steps: 0
 ```

-- `experience_buffer`: Experience buffer used by the trainer, which is logically equivalent to `buffer.explorer_output`.
+- `experience_buffer`: It is the input of Trainer and also the output of Explorer. This field is required even in explore mode.
+- `name`: The name of the experience buffer. This name will be used as the Ray actor's name, so it must be unique.
+- `storage_type`: The storage type for the experience buffer.
+- `queue`: Experience data is stored in a queue. This storage type is recommended for most use cases.
+- `sql`: Experience data is stored in a SQL database.
+- `file`: Experience data is stored in a JSON file. This storage type should be used only for debugging purposes in `explore` mode.
+- `path`: The path to the experience buffer.
+- For `queue` storage type, this field is optional. You can specify a SQLite database or JSON file path here to back up the queue data.
+- For `file` storage type, the path points to the directory containing the dataset files.
+- For `sql` storage type, the path points to the SQLite database file.
+- `max_read_timeout`: The maximum waiting time (in seconds) to read new experience data. If exceeded, an incomplete batch will be returned directly. Only take effect when `storage_type` is `queue`. Default is 1800 seconds (30 minutes).
+- `use_priority_queue`: Only take effect when `storage_type` is `queue`. If set to `True`, the queue will be a priority queue, which allows for prioritizing certain experiences over others. Default is `False`.
+- `reuse_cooldown_time`: Only take effect when `storage_type` is `queue` and `use_priority_queue` is `True`. If set, it specifies the cooldown time (in seconds) for reusing experiences. If not specified, the default value is `None`, meaning experiences can not be reused.
 - `sft_warmup_dataset`: Optional dataset used for pre-training (SFT warmup).
 - `sft_warmup_steps`: Number of steps to use SFT warm-up before RL begins.
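
The updated tutorial leans on OmegaConf interpolation (`${oc.env:CHECKPOINT_ROOT_DIR}`, `${model.model_path}`). The sketch below is illustrative only and not part of this commit; it shows how such placeholders resolve, with the environment variable values assumed for the example.

```python
# Illustrative sketch (not from this commit): resolving the interpolated values
# used in the updated trinity_configs.md examples with OmegaConf.
import os

from omegaconf import OmegaConf

# Assumed values purely for demonstration.
os.environ.setdefault("CHECKPOINT_ROOT_DIR", "/tmp/checkpoints")
os.environ.setdefault("MODEL_PATH", "/tmp/models/qwen")

yaml_text = """
checkpoint_root_dir: ${oc.env:CHECKPOINT_ROOT_DIR}
model:
  model_path: ${oc.env:MODEL_PATH}
  critic_model_path: ${model.model_path}  # reuses model.model_path
"""

cfg = OmegaConf.create(yaml_text)
print(cfg.checkpoint_root_dir)      # /tmp/checkpoints (from the environment variable)
print(cfg.model.critic_model_path)  # /tmp/models/qwen (via variable interpolation)
```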

docs/sphinx_doc/source/tutorial/trinity_programming_guide.md

Lines changed: 2 additions & 2 deletions
@@ -447,7 +447,7 @@

 The above steps implement the components needed for the algorithm, but these components are scattered and need to be configured in multiple places to take effect.

-To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in {object}`trinity.algorithm.ALGORITHM_TYPE`, enabling one-click configuration.
+To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in {class}`trinity.algorithm.ALGORITHM_TYPE`, enabling one-click configuration.

 The `AlgorithmType` class includes the following attributes and methods:

@@ -473,7 +473,7 @@ class OPMDAlgorithm(AlgorithmType):
 use_reference: bool = True
 compute_advantage_in_trainer: bool = False
 can_balance_batch: bool = True
-schema: type = ExperienceModel
+schema: str = "experience"

 @classmethod
 def default_config(cls) -> Dict:
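
With this refactor, the `schema` attribute of an algorithm definition becomes a registry key string ("experience") instead of the `ExperienceModel` class. Below is a minimal sketch of a custom algorithm definition after the change, based on the fragments above; the class name and the `default_config` body are assumptions, and registration in `ALGORITHM_TYPE` is only noted in a comment.

```python
# Sketch of an AlgorithmType subclass after this commit; the class name and
# default_config values are assumptions, only the attribute names and types
# come from the diffs above.
from typing import Dict

from trinity.algorithm import AlgorithmType  # documented location of AlgorithmType


class MyAlgorithm(AlgorithmType):
    use_reference: bool = True
    compute_advantage_in_trainer: bool = False
    can_balance_batch: bool = True
    schema: str = "experience"  # previously: schema: type = ExperienceModel

    @classmethod
    def default_config(cls) -> Dict:
        # Placeholder; real algorithms return their hyperparameter defaults here.
        return {}


# To take effect, the class still needs to be registered in
# trinity.algorithm.ALGORITHM_TYPE (see the programming guide).
```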

examples/RAFT_alfworld/RAFT_alfworld_7B.yaml

Lines changed: 0 additions & 2 deletions
@@ -15,8 +15,6 @@ cluster:
 buffer:
 total_epochs: 30
 batch_size: 80
-max_retry_times: 1
-max_retry_interval: 1
 explorer_input:
 taskset:
 name: alfworld-train

examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml

Lines changed: 0 additions & 2 deletions
@@ -15,8 +15,6 @@ cluster:
 buffer:
 total_epochs: 30
 batch_size: 80
-max_retry_times: 1
-max_retry_interval: 1
 explorer_input:
 taskset:
 name: alfworld-train
