Skip to content

Commit b986f3e

Browse files
authored
Add CLI for Trinity Studio (#32)
1 parent a182b1f commit b986f3e

File tree

3 files changed

+140
-35
lines changed

3 files changed

+140
-35
lines changed

README.md

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -200,18 +200,22 @@ For more details about dataset downloading, please refer to [Huggingface](https:
200200
### Step 3: configurations
201201

202202

203-
You may customize the configurations in [`examples`](examples/). For example, the model and dataset are specified as:
203+
For convenience, Trinity-RFT provides a web interface for configuring your RFT process.
204204

205-
```yaml
206-
model:
207-
model_path: $MODEL_PATH/{model_name}
205+
> [!NOTE]
206+
> This is an experimental feature. We will continue to improve it and make it more user-friendly.
208207
209-
data:
210-
dataset_path: $DATASET_PATH/{dataset_name}
208+
```bash
209+
trinity studio --port 8080
211210
```
212211

213-
Please refer to [`examples`](examples/) for more details.
214212

213+
Then you can configure your RFT process in the web page and generate a config file.
214+
You can save the config for later use or run it directly as described in the following section.
215+
216+
217+
For advanced users, you can also manually configure your RFT process by editing the config file.
218+
We provide a set of example config files in [`examples`](examples/).
215219

216220

217221
### Step 4: run the RFT process
@@ -227,35 +231,25 @@ ray start --head
227231
ray start --address=<master_address>
228232
```
229233

230-
231-
232234
Optionally, we can log in to [wandb](https://docs.wandb.ai/quickstart/) to better monitor the RFT process:
233235

234236
```shell
235237
export WANDB_API_KEY=<your_api_key>
236238
wandb login
237239
```
238240

239-
240-
241-
Then, run the RFT process with the following command:
241+
Then, for command-line users, run the RFT process with the following command:
242242

243243
```shell
244244
trinity run --config <config_path>
245245
```
246246

247+
> For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:
248+
> ```shell
249+
> trinity run --config examples/grpo_gsm8k/gsm8k.yaml
250+
> ```
247251
248-
249-
For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:
250-
251-
```shell
252-
trinity run --config examples/grpo_gsm8k/gsm8k.yaml
253-
```
254-
255-
256-
257-
More example config files can be found in `examples`.
258-
252+
For studio users, just click the "Run" button in the web page.
259253
260254
261255
For more detailed examples about how to use Trinity-RFT, please refer to the following tutorials:

trinity/cli/launcher.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Launch the trainer"""
2-
32
import argparse
3+
import sys
44

55
import ray
66

@@ -153,6 +153,21 @@ def run(config_path: str):
153153
both(config)
154154

155155

156+
def studio(port: int = 8501):
    """Launch the Trinity Studio web UI through Streamlit.

    Rewrites ``sys.argv`` into a ``streamlit run`` invocation of the config
    manager script and hands control to Streamlit's CLI entry point; the
    process exits with whatever that entry point returns.

    Args:
        port: Port passed to Streamlit as ``--server.port``.
    """
    from pathlib import Path

    from streamlit.web import cli as stcli

    # Resolve the script relative to this module (trinity/cli/launcher.py)
    # rather than the current working directory, so `trinity studio` works
    # no matter where it is invoked from.
    config_manager_path = (
        Path(__file__).resolve().parent.parent / "manager" / "config_manager.py"
    )

    sys.argv = [
        "streamlit",
        "run",
        str(config_manager_path),
        "--server.port",
        str(port),
        "--server.fileWatcherType",
        "none",
    ]
    sys.exit(stcli.main())
169+
170+
156171
def main() -> None:
157172
"""The main entrypoint."""
158173
parser = argparse.ArgumentParser()
@@ -162,12 +177,18 @@ def main() -> None:
162177
run_parser = subparsers.add_parser("run", help="Run RFT process.")
163178
run_parser.add_argument("--config", type=str, required=True, help="config file path.")
164179

180+
# studio command
181+
studio_parser = subparsers.add_parser("studio", help="Run studio.")
182+
studio_parser.add_argument("--port", type=int, default=8501, help="studio port.")
183+
165184
# TODO: add more commands like `monitor`, `label`
166185

167186
args = parser.parse_args()
168187
if args.command == "run":
169188
# TODO: support parse all args from command line
170189
run(args.config)
190+
elif args.command == "studio":
191+
studio(args.port)
171192

172193

173194
if __name__ == "__main__":

trinity/manager/config_manager.py

Lines changed: 101 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import copy
22
import os
3+
import subprocess
4+
import tempfile
35
from typing import List
46

57
import streamlit as st
@@ -21,8 +23,8 @@ class ConfigManager:
2123
def __init__(self):
2224
self._init_default_config()
2325
self.unfinished_fields = set()
24-
st.set_page_config(page_title="Trainer Config Generator", page_icon=":robot:")
25-
st.title("Trainer Config Generator")
26+
st.set_page_config(page_title="Trinity-RFT Config Generator", page_icon=":robot:")
27+
st.title("Trinity-RFT Config Generator")
2628
if "_init_config_manager" not in st.session_state:
2729
self.reset_session_state()
2830
self.maintain_session_state()
@@ -36,6 +38,10 @@ def __init__(self):
3638
self.beginner_mode()
3739
else:
3840
self.expert_mode()
41+
if "config_generated" not in st.session_state:
42+
st.session_state.config_generated = False
43+
if "is_running" not in st.session_state:
44+
st.session_state.is_running = False
3945
self.generate_config()
4046

4147
def _init_default_config(self):
@@ -44,7 +50,7 @@ def _init_default_config(self):
4450
"mode": "both",
4551
"project": "Trinity-RFT",
4652
"exp_name": "qwen2.5-1.5B",
47-
"monitor_type": MonitorType.WANDB.value,
53+
"monitor_type": MonitorType.TENSORBOARD.value,
4854
# Model Configs
4955
"model_path": "",
5056
"critic_model_path": "",
@@ -1316,9 +1322,11 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node
13161322
"lr": st.session_state["actor_lr"],
13171323
"lr_warmup_steps_ratio": st.session_state["actor_lr_warmup_steps_ratio"],
13181324
"warmup_style": st.session_state["actor_warmup_style"],
1319-
"total_training_steps": -1
1320-
if st.session_state["total_training_steps"] is None
1321-
else st.session_state["total_training_steps"],
1325+
"total_training_steps": (
1326+
-1
1327+
if st.session_state["total_training_steps"] is None
1328+
else st.session_state["total_training_steps"]
1329+
),
13221330
},
13231331
"fsdp_config": copy.deepcopy(fsdp_config),
13241332
"tau": st.session_state["actor_tau"],
@@ -1369,9 +1377,11 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node
13691377
"lr": st.session_state["critic_lr"],
13701378
"lr_warmup_steps_ratio": st.session_state["critic_lr_warmup_steps_ratio"],
13711379
"warmup_style": st.session_state["critic_warmup_style"],
1372-
"total_training_steps": -1
1373-
if st.session_state["total_training_steps"] is None
1374-
else st.session_state["total_training_steps"],
1380+
"total_training_steps": (
1381+
-1
1382+
if st.session_state["total_training_steps"] is None
1383+
else st.session_state["total_training_steps"]
1384+
),
13751385
},
13761386
"model": {
13771387
"path": critic_model_path,
@@ -1436,7 +1446,7 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node
14361446
"total_epochs": st.session_state["total_epochs"],
14371447
"project_name": st.session_state["project"],
14381448
"experiment_name": st.session_state["exp_name"],
1439-
"logger": ["wandb"],
1449+
"logger": ["tensorboard"],
14401450
"val_generations_to_log_to_wandb": 0,
14411451
"nnodes": trainer_nnodes,
14421452
"n_gpus_per_node": trainer_n_gpus_per_node,
@@ -1516,7 +1526,12 @@ def generate_config(self):
15161526
"Generate Config",
15171527
disabled=disable_generate,
15181528
help=help_messages,
1529+
use_container_width=True,
1530+
icon=":material/create_new_folder:",
15191531
):
1532+
st.session_state.config_generated = True
1533+
st.session_state.is_running = False
1534+
if st.session_state.config_generated:
15201535
config = {
15211536
"mode": st.session_state["mode"],
15221537
"data": {
@@ -1618,11 +1633,86 @@ def generate_config(self):
16181633
"dpo_dataset_chosen_key": st.session_state["dpo_dataset_chosen_key"],
16191634
"dpo_dataset_rejected_key": st.session_state["dpo_dataset_rejected_key"],
16201635
}
1636+
st.session_state.config_generated = True
16211637
st.header("Generated Config File")
1622-
st.subheader("Config File")
1638+
buttons = st.container()
1639+
save_btn, run_btn = buttons.columns(2, vertical_alignment="bottom")
16231640
yaml_config = yaml.dump(config, allow_unicode=True, sort_keys=False)
1641+
save_btn.download_button(
1642+
"Save",
1643+
data=yaml_config,
1644+
file_name=f"{config['monitor']['project']}-{config['monitor']['name']}.yaml",
1645+
mime="text/plain",
1646+
icon=":material/download:",
1647+
use_container_width=True,
1648+
)
1649+
run_btn.button(
1650+
"Run",
1651+
on_click=self.run_config,
1652+
args=(
1653+
buttons,
1654+
yaml_config,
1655+
),
1656+
icon=":material/terminal:",
1657+
use_container_width=True,
1658+
disabled=st.session_state.is_running,
1659+
)
16241660
st.code(yaml_config, language="yaml")
16251661

1662+
def run_config(self, parent, yaml_config: str) -> None:
    """Submit the generated config as a Ray job and report the outcome.

    Used as the ``on_click`` callback of the "Run" button. Sets
    ``st.session_state.is_running`` while a submission is in flight; the flag
    is cleared again whenever no job was actually submitted.

    Args:
        parent: Streamlit container in which status messages are shown.
        yaml_config: YAML text of the generated config to run.
    """
    st.session_state.is_running = True

    import ray

    # First check whether a Ray cluster is reachable.
    ray_status = subprocess.run(
        ["ray", "status"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )

    if ray_status.returncode != 0:
        parent.warning(
            "Ray cluster is not running. Please start Ray first using `ray start --head`."
        )
        # Nothing was submitted: clear the flag so the Run button is not
        # left disabled after the user starts Ray and tries again.
        st.session_state.is_running = False
        return
    context = ray.init(ignore_reinit_error=True)
    dashboard_url = context.dashboard_url
    # Save the config to a temp file. delete=False keeps it on disk after the
    # `with` block, because the job is submitted with --no-wait and reads the
    # file after this method has returned.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as tmpfile:
        tmpfile.write(yaml_config)
        tmpfile_path = tmpfile.name

    # Submit the Ray job.
    try:
        subprocess.run(
            [
                "ray",
                "job",
                "submit",
                "--no-wait",
                "--",
                "python",
                "-m",
                "trinity.cli.launcher",
                "run",
                "--config",
                tmpfile_path,
            ],
            text=True,
            capture_output=True,
            check=True,
        )
        parent.success(
            f"Job submitted successfully!\n\n"
            f"View progress in the Ray Dashboard: http://{dashboard_url}",
            icon="✅",
        )
    except subprocess.CalledProcessError as e:
        parent.error(f"Failed to submit job:\n\n{e.stderr}", icon="❌")
        # Submission failed: allow the user to retry.
        st.session_state.is_running = False
1715+
16261716

16271717
if __name__ == "__main__":
16281718
config_manager = ConfigManager()

0 commit comments

Comments
 (0)