Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 17 additions & 23 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,18 +200,22 @@ For more details about dataset downloading, please refer to [Huggingface](https:
### Step 3: configurations


You may customize the configurations in [`examples`](examples/). For example, the model and dataset are specified as:
For convenience, Trinity-RFT provides a web interface for configuring your RFT process.

```yaml
model:
model_path: $MODEL_PATH/{model_name}
> [!NOTE]
> This is a experimental feature. We will continue to improve it and make it more user-friendly.

data:
dataset_path: $DATASET_PATH/{dataset_name}
```bash
trinity studio --port 8080
```

Please refer to [`examples`](examples/) for more details.

Then you can configure your RFT process in the web page and generate a config file.
You can save the config for later use or run it directly as described in the following section.


For advanced users, you can also manually configure your RFT process by editing the config file.
We provide a set of example config files in [`examples`](examples/).


### Step 4: run the RFT process
Expand All @@ -227,35 +231,25 @@ ray start --head
ray start --address=<master_address>
```



Optionally, we can login into [wandb](https://docs.wandb.ai/quickstart/) to better monitor the RFT process:

```shell
export WANDB_API_KEY=<your_api_key>
wandb login
```



Then, run the RFT process with the following command:
Then, for command-line users, run the RFT process with the following command:

```shell
trinity run --config <config_path>
```

> For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:
> ```shell
> trinity run --config examples/grpo_gsm8k/gsm8k.yaml
> ```


For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:

```shell
trinity run --config examples/grpo_gsm8k/gsm8k.yaml
```



More example config files can be found in `examples`.

For studio users, just click the "Run" button in the web page.


For more detailed examples about how to use Trinity-RFT, please refer to the following tutorials:
Expand Down
23 changes: 22 additions & 1 deletion trinity/cli/launcher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Launch the trainer"""

import argparse
import sys

import ray

Expand Down Expand Up @@ -153,6 +153,21 @@ def run(config_path: str):
both(config)


def studio(port: int = 8501):
from streamlit.web import cli as stcli

sys.argv = [
"streamlit",
"run",
"trinity/manager/config_manager.py",
"--server.port",
str(port),
"--server.fileWatcherType",
"none",
]
sys.exit(stcli.main())


def main() -> None:
"""The main entrypoint."""
parser = argparse.ArgumentParser()
Expand All @@ -162,12 +177,18 @@ def main() -> None:
run_parser = subparsers.add_parser("run", help="Run RFT process.")
run_parser.add_argument("--config", type=str, required=True, help="config file path.")

# studio command
studio_parser = subparsers.add_parser("studio", help="Run studio.")
studio_parser.add_argument("--port", type=int, default=8501, help="studio port.")

# TODO: add more commands like `monitor`, `label`

args = parser.parse_args()
if args.command == "run":
# TODO: support parse all args from command line
run(args.config)
elif args.command == "studio":
studio(args.port)


if __name__ == "__main__":
Expand Down
112 changes: 101 additions & 11 deletions trinity/manager/config_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import copy
import os
import subprocess
import tempfile
from typing import List

import streamlit as st
Expand All @@ -21,8 +23,8 @@ class ConfigManager:
def __init__(self):
self._init_default_config()
self.unfinished_fields = set()
st.set_page_config(page_title="Trainer Config Generator", page_icon=":robot:")
st.title("Trainer Config Generator")
st.set_page_config(page_title="Trinity-RFT Config Generator", page_icon=":robot:")
st.title("Trinity-RFT Config Generator")
if "_init_config_manager" not in st.session_state:
self.reset_session_state()
self.maintain_session_state()
Expand All @@ -36,6 +38,10 @@ def __init__(self):
self.beginner_mode()
else:
self.expert_mode()
if "config_generated" not in st.session_state:
st.session_state.config_generated = False
if "is_running" not in st.session_state:
st.session_state.is_running = False
self.generate_config()

def _init_default_config(self):
Expand All @@ -44,7 +50,7 @@ def _init_default_config(self):
"mode": "both",
"project": "Trinity-RFT",
"exp_name": "qwen2.5-1.5B",
"monitor_type": MonitorType.WANDB.value,
"monitor_type": MonitorType.TENSORBOARD.value,
# Model Configs
"model_path": "",
"critic_model_path": "",
Expand Down Expand Up @@ -1316,9 +1322,11 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node
"lr": st.session_state["actor_lr"],
"lr_warmup_steps_ratio": st.session_state["actor_lr_warmup_steps_ratio"],
"warmup_style": st.session_state["actor_warmup_style"],
"total_training_steps": -1
if st.session_state["total_training_steps"] is None
else st.session_state["total_training_steps"],
"total_training_steps": (
-1
if st.session_state["total_training_steps"] is None
else st.session_state["total_training_steps"]
),
},
"fsdp_config": copy.deepcopy(fsdp_config),
"tau": st.session_state["actor_tau"],
Expand Down Expand Up @@ -1369,9 +1377,11 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node
"lr": st.session_state["critic_lr"],
"lr_warmup_steps_ratio": st.session_state["critic_lr_warmup_steps_ratio"],
"warmup_style": st.session_state["critic_warmup_style"],
"total_training_steps": -1
if st.session_state["total_training_steps"] is None
else st.session_state["total_training_steps"],
"total_training_steps": (
-1
if st.session_state["total_training_steps"] is None
else st.session_state["total_training_steps"]
),
},
"model": {
"path": critic_model_path,
Expand Down Expand Up @@ -1436,7 +1446,7 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node
"total_epochs": st.session_state["total_epochs"],
"project_name": st.session_state["project"],
"experiment_name": st.session_state["exp_name"],
"logger": ["wandb"],
"logger": ["tensorboard"],
"val_generations_to_log_to_wandb": 0,
"nnodes": trainer_nnodes,
"n_gpus_per_node": trainer_n_gpus_per_node,
Expand Down Expand Up @@ -1516,7 +1526,12 @@ def generate_config(self):
"Generate Config",
disabled=disable_generate,
help=help_messages,
use_container_width=True,
icon=":material/create_new_folder:",
):
st.session_state.config_generated = True
st.session_state.is_running = False
if st.session_state.config_generated:
config = {
"mode": st.session_state["mode"],
"data": {
Expand Down Expand Up @@ -1618,11 +1633,86 @@ def generate_config(self):
"dpo_dataset_chosen_key": st.session_state["dpo_dataset_chosen_key"],
"dpo_dataset_rejected_key": st.session_state["dpo_dataset_rejected_key"],
}
st.session_state.config_generated = True
st.header("Generated Config File")
st.subheader("Config File")
buttons = st.container()
save_btn, run_btn = buttons.columns(2, vertical_alignment="bottom")
yaml_config = yaml.dump(config, allow_unicode=True, sort_keys=False)
save_btn.download_button(
"Save",
data=yaml_config,
file_name=f"{config['monitor']['project']}-{config['monitor']['name']}.yaml",
mime="text/plain",
icon=":material/download:",
use_container_width=True,
)
run_btn.button(
"Run",
on_click=self.run_config,
args=(
buttons,
yaml_config,
),
icon=":material/terminal:",
use_container_width=True,
disabled=st.session_state.is_running,
)
st.code(yaml_config, language="yaml")

def run_config(self, parent, yaml_config: str) -> None:
st.session_state.is_running = True

import ray

# first check if ray is running
ray_status = subprocess.run(
["ray", "status"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)

if ray_status.returncode != 0:
parent.warning(
"Ray cluster is not running. Please start Ray first using `ray start --head`."
)
return
context = ray.init(ignore_reinit_error=True)
dashboard_url = context.dashboard_url
# save config to temp file
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as tmpfile:
tmpfile.write(yaml_config)
tmpfile_path = tmpfile.name

# submit ray job
try:
subprocess.run(
[
"ray",
"job",
"submit",
"--no-wait",
"--",
"python",
"-m",
"trinity.cli.launcher",
"run",
"--config",
tmpfile_path,
],
text=True,
capture_output=True,
check=True,
)
parent.success(
f"Job submitted successfully!\n\n"
f"View progress in the Ray Dashboard: http://{dashboard_url}",
icon="✅",
)
except subprocess.CalledProcessError as e:
parent.error(f"Failed to submit job:\n\n{e.stderr}", icon="❌")
st.session_state.is_running = False


if __name__ == "__main__":
config_manager = ConfigManager()