From e791d4a17938f41b922eab3da42f07fee27cc55d Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 9 May 2025 17:13:52 +0800 Subject: [PATCH 1/2] add studio cli --- README.md | 40 +++++------- trinity/cli/launcher.py | 23 ++++++- trinity/manager/config_manager.py | 103 +++++++++++++++++++++++++++--- 3 files changed, 132 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 9dc284bce4..95fa9084c6 100644 --- a/README.md +++ b/README.md @@ -200,18 +200,22 @@ For more details about dataset downloading, please refer to [Huggingface](https: ### Step 3: configurations -You may customize the configurations in [`examples`](examples/). For example, the model and dataset are specified as: +For convenience, Trinity-RFT provides a web interface for configuring your RFT process. -```yaml -model: - model_path: $MODEL_PATH/{model_name} +> [!NOTE] +> This is an experimental feature. We will continue to improve it and make it more user-friendly. -data: - dataset_path: $DATASET_PATH/{dataset_name} +```bash +trinity studio --port 8080 ``` -Please refer to [`examples`](examples/) for more details. +Then you can configure your RFT process in the web page and generate a config file. +You can save the config for later use or run it directly as described in the following section. + + +For advanced users, you can also manually configure your RFT process by editing the config file. +We provide a set of example config files in [`examples`](examples/). 
### Step 4: run the RFT process @@ -227,8 +231,6 @@ ray start --head ray start --address= ``` - - Optionally, we can login into [wandb](https://docs.wandb.ai/quickstart/) to better monitor the RFT process: ```shell @@ -236,26 +238,18 @@ export WANDB_API_KEY= wandb login ``` - - -Then, run the RFT process with the following command: +Then, for command-line users, run the RFT process with the following command: ```shell trinity run --config ``` +> For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm: +> ```shell +> trinity run --config examples/grpo_gsm8k/gsm8k.yaml +> ``` - -For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm: - -```shell -trinity run --config examples/grpo_gsm8k/gsm8k.yaml -``` - - - -More example config files can be found in `examples`. - +For studio users, just click the "Run" button in the web page. For more detailed examples about how to use Trinity-RFT, please refer to the following tutorials: diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 2cf706b661..0dd846d8ec 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -1,6 +1,6 @@ """Launch the trainer""" - import argparse +import sys import ray @@ -153,6 +153,21 @@ def run(config_path: str): both(config) +def studio(port: int = 8501): + from streamlit.web import cli as stcli + + sys.argv = [ + "streamlit", + "run", + "trinity/manager/config_manager.py", + "--server.port", + str(port), + "--server.fileWatcherType", + "none", + ] + sys.exit(stcli.main()) + + def main() -> None: """The main entrypoint.""" parser = argparse.ArgumentParser() @@ -162,12 +177,18 @@ def main() -> None: run_parser = subparsers.add_parser("run", help="Run RFT process.") run_parser.add_argument("--config", type=str, required=True, help="config file path.") + # studio command + studio_parser = subparsers.add_parser("studio", help="Run studio.") + 
studio_parser.add_argument("--port", type=int, default=8501, help="studio port.") + # TODO: add more commands like `monitor`, `label` args = parser.parse_args() if args.command == "run": # TODO: support parse all args from command line run(args.config) + elif args.command == "studio": + studio(args.port) if __name__ == "__main__": diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py index 86232bbc50..2f0385584c 100644 --- a/trinity/manager/config_manager.py +++ b/trinity/manager/config_manager.py @@ -1,5 +1,7 @@ import copy import os +import subprocess +import tempfile from typing import List import streamlit as st @@ -21,8 +23,8 @@ class ConfigManager: def __init__(self): self._init_default_config() self.unfinished_fields = set() - st.set_page_config(page_title="Trainer Config Generator", page_icon=":robot:") - st.title("Trainer Config Generator") + st.set_page_config(page_title="Trinity-RFT Config Generator", page_icon=":robot:") + st.title("Trinity-RFT Config Generator") if "_init_config_manager" not in st.session_state: self.reset_session_state() self.maintain_session_state() @@ -36,6 +38,8 @@ def __init__(self): self.beginner_mode() else: self.expert_mode() + if "config_generated" not in st.session_state: + st.session_state.config_generated = False self.generate_config() def _init_default_config(self): @@ -1316,9 +1320,11 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node "lr": st.session_state["actor_lr"], "lr_warmup_steps_ratio": st.session_state["actor_lr_warmup_steps_ratio"], "warmup_style": st.session_state["actor_warmup_style"], - "total_training_steps": -1 - if st.session_state["total_training_steps"] is None - else st.session_state["total_training_steps"], + "total_training_steps": ( + -1 + if st.session_state["total_training_steps"] is None + else st.session_state["total_training_steps"] + ), }, "fsdp_config": copy.deepcopy(fsdp_config), "tau": st.session_state["actor_tau"], @@ -1369,9 
+1375,11 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node "lr": st.session_state["critic_lr"], "lr_warmup_steps_ratio": st.session_state["critic_lr_warmup_steps_ratio"], "warmup_style": st.session_state["critic_warmup_style"], - "total_training_steps": -1 - if st.session_state["total_training_steps"] is None - else st.session_state["total_training_steps"], + "total_training_steps": ( + -1 + if st.session_state["total_training_steps"] is None + else st.session_state["total_training_steps"] + ), }, "model": { "path": critic_model_path, @@ -1436,7 +1444,7 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node "total_epochs": st.session_state["total_epochs"], "project_name": st.session_state["project"], "experiment_name": st.session_state["exp_name"], - "logger": ["wandb"], + "logger": ["tensorboard"], "val_generations_to_log_to_wandb": 0, "nnodes": trainer_nnodes, "n_gpus_per_node": trainer_n_gpus_per_node, @@ -1516,7 +1524,11 @@ def generate_config(self): "Generate Config", disabled=disable_generate, help=help_messages, + use_container_width=True, + icon=":material/create_new_folder:", ): + st.session_state.config_generated = True + if st.session_state.config_generated: config = { "mode": st.session_state["mode"], "data": { @@ -1618,11 +1630,82 @@ def generate_config(self): "dpo_dataset_chosen_key": st.session_state["dpo_dataset_chosen_key"], "dpo_dataset_rejected_key": st.session_state["dpo_dataset_rejected_key"], } + st.session_state.config_generated = True st.header("Generated Config File") - st.subheader("Config File") + buttons = st.container() + save_btn, run_btn = buttons.columns(2, vertical_alignment="bottom") yaml_config = yaml.dump(config, allow_unicode=True, sort_keys=False) + save_btn.download_button( + "Save", + data=yaml_config, + file_name=f"{config['monitor']['project']}-{config['monitor']['name']}.yaml", + mime="text/plain", + icon=":material/download:", + use_container_width=True, + ) + 
run_btn.button( + "Run", + on_click=self.run_config, + args=( + buttons, + yaml_config, + ), + icon=":material/terminal:", + use_container_width=True, + ) st.code(yaml_config, language="yaml") + def run_config(self, parent, yaml_config: str) -> None: + import ray + + # first check if ray is running + ray_status = subprocess.run( + ["ray", "status"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + if ray_status.returncode != 0: + parent.warning( + "Ray cluster is not running. Please start Ray first using `ray start --head`." + ) + return + context = ray.init(ignore_reinit_error=True) + dashboard_url = context.dashboard_url + # save config to temp file + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as tmpfile: + tmpfile.write(yaml_config) + tmpfile_path = tmpfile.name + + # submit ray job + try: + subprocess.run( + [ + "ray", + "job", + "submit", + "--no-wait", + "--", + "python", + "-m", + "trinity.cli.launcher", + "run", + "--config", + tmpfile_path, + ], + text=True, + capture_output=True, + check=True, + ) + parent.success( + f"Job submitted successfully!\n\n" + f"View progress in the Ray Dashboard: {dashboard_url}", + icon="✅", + ) + except subprocess.CalledProcessError as e: + parent.error(f"❌ Failed to submit job:\n\n{e.stderr}") + if __name__ == "__main__": config_manager = ConfigManager() From f8de7fc11f8b923a2241bd360fc7a38a28391c4e Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 9 May 2025 18:01:35 +0800 Subject: [PATCH 2/2] disable run button when run success --- trinity/manager/config_manager.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py index 2f0385584c..0045deac25 100644 --- a/trinity/manager/config_manager.py +++ b/trinity/manager/config_manager.py @@ -40,6 +40,8 @@ def __init__(self): self.expert_mode() if "config_generated" not in st.session_state: st.session_state.config_generated = False + 
if "is_running" not in st.session_state: + st.session_state.is_running = False self.generate_config() def _init_default_config(self): @@ -48,7 +50,7 @@ def _init_default_config(self): "mode": "both", "project": "Trinity-RFT", "exp_name": "qwen2.5-1.5B", - "monitor_type": MonitorType.WANDB.value, + "monitor_type": MonitorType.TENSORBOARD.value, # Model Configs "model_path": "", "critic_model_path": "", @@ -1528,6 +1530,7 @@ def generate_config(self): icon=":material/create_new_folder:", ): st.session_state.config_generated = True + st.session_state.is_running = False if st.session_state.config_generated: config = { "mode": st.session_state["mode"], @@ -1652,10 +1655,13 @@ def generate_config(self): ), icon=":material/terminal:", use_container_width=True, + disabled=st.session_state.is_running, ) st.code(yaml_config, language="yaml") def run_config(self, parent, yaml_config: str) -> None: + st.session_state.is_running = True + import ray # first check if ray is running @@ -1700,11 +1706,12 @@ def run_config(self, parent, yaml_config: str) -> None: ) parent.success( f"Job submitted successfully!\n\n" - f"View progress in the Ray Dashboard: {dashboard_url}", + f"View progress in the Ray Dashboard: http://{dashboard_url}", icon="✅", ) except subprocess.CalledProcessError as e: - parent.error(f"❌ Failed to submit job:\n\n{e.stderr}") + parent.error(f"Failed to submit job:\n\n{e.stderr}", icon="❌") + st.session_state.is_running = False if __name__ == "__main__":