11import copy
22import os
3+ import subprocess
4+ import tempfile
35from typing import List
46
57import streamlit as st
@@ -21,8 +23,8 @@ class ConfigManager:
2123 def __init__ (self ):
2224 self ._init_default_config ()
2325 self .unfinished_fields = set ()
24- st .set_page_config (page_title = "Trainer Config Generator" , page_icon = ":robot:" )
25- st .title ("Trainer Config Generator" )
26+ st .set_page_config (page_title = "Trinity-RFT Config Generator" , page_icon = ":robot:" )
27+ st .title ("Trinity-RFT Config Generator" )
2628 if "_init_config_manager" not in st .session_state :
2729 self .reset_session_state ()
2830 self .maintain_session_state ()
@@ -36,6 +38,10 @@ def __init__(self):
3638 self .beginner_mode ()
3739 else :
3840 self .expert_mode ()
41+ if "config_generated" not in st .session_state :
42+ st .session_state .config_generated = False
43+ if "is_running" not in st .session_state :
44+ st .session_state .is_running = False
3945 self .generate_config ()
4046
4147 def _init_default_config (self ):
@@ -44,7 +50,7 @@ def _init_default_config(self):
4450 "mode" : "both" ,
4551 "project" : "Trinity-RFT" ,
4652 "exp_name" : "qwen2.5-1.5B" ,
47- "monitor_type" : MonitorType .WANDB .value ,
53+ "monitor_type" : MonitorType .TENSORBOARD .value ,
4854 # Model Configs
4955 "model_path" : "" ,
5056 "critic_model_path" : "" ,
@@ -1316,9 +1322,11 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node
13161322 "lr" : st .session_state ["actor_lr" ],
13171323 "lr_warmup_steps_ratio" : st .session_state ["actor_lr_warmup_steps_ratio" ],
13181324 "warmup_style" : st .session_state ["actor_warmup_style" ],
1319- "total_training_steps" : - 1
1320- if st .session_state ["total_training_steps" ] is None
1321- else st .session_state ["total_training_steps" ],
1325+ "total_training_steps" : (
1326+ - 1
1327+ if st .session_state ["total_training_steps" ] is None
1328+ else st .session_state ["total_training_steps" ]
1329+ ),
13221330 },
13231331 "fsdp_config" : copy .deepcopy (fsdp_config ),
13241332 "tau" : st .session_state ["actor_tau" ],
@@ -1369,9 +1377,11 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node
13691377 "lr" : st .session_state ["critic_lr" ],
13701378 "lr_warmup_steps_ratio" : st .session_state ["critic_lr_warmup_steps_ratio" ],
13711379 "warmup_style" : st .session_state ["critic_warmup_style" ],
1372- "total_training_steps" : - 1
1373- if st .session_state ["total_training_steps" ] is None
1374- else st .session_state ["total_training_steps" ],
1380+ "total_training_steps" : (
1381+ - 1
1382+ if st .session_state ["total_training_steps" ] is None
1383+ else st .session_state ["total_training_steps" ]
1384+ ),
13751385 },
13761386 "model" : {
13771387 "path" : critic_model_path ,
@@ -1436,7 +1446,7 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node
14361446 "total_epochs" : st .session_state ["total_epochs" ],
14371447 "project_name" : st .session_state ["project" ],
14381448 "experiment_name" : st .session_state ["exp_name" ],
1439- "logger" : ["wandb " ],
1449+ "logger" : ["tensorboard " ],
14401450 "val_generations_to_log_to_wandb" : 0 ,
14411451 "nnodes" : trainer_nnodes ,
14421452 "n_gpus_per_node" : trainer_n_gpus_per_node ,
@@ -1516,7 +1526,12 @@ def generate_config(self):
15161526 "Generate Config" ,
15171527 disabled = disable_generate ,
15181528 help = help_messages ,
1529+ use_container_width = True ,
1530+ icon = ":material/create_new_folder:" ,
15191531 ):
1532+ st .session_state .config_generated = True
1533+ st .session_state .is_running = False
1534+ if st .session_state .config_generated :
15201535 config = {
15211536 "mode" : st .session_state ["mode" ],
15221537 "data" : {
@@ -1618,11 +1633,86 @@ def generate_config(self):
16181633 "dpo_dataset_chosen_key" : st .session_state ["dpo_dataset_chosen_key" ],
16191634 "dpo_dataset_rejected_key" : st .session_state ["dpo_dataset_rejected_key" ],
16201635 }
1636+ st .session_state .config_generated = True
16211637 st .header ("Generated Config File" )
1622- st .subheader ("Config File" )
1638+ buttons = st .container ()
1639+ save_btn , run_btn = buttons .columns (2 , vertical_alignment = "bottom" )
16231640 yaml_config = yaml .dump (config , allow_unicode = True , sort_keys = False )
1641+ save_btn .download_button (
1642+ "Save" ,
1643+ data = yaml_config ,
1644+ file_name = f"{ config ['monitor' ]['project' ]} -{ config ['monitor' ]['name' ]} .yaml" ,
1645+ mime = "text/plain" ,
1646+ icon = ":material/download:" ,
1647+ use_container_width = True ,
1648+ )
1649+ run_btn .button (
1650+ "Run" ,
1651+ on_click = self .run_config ,
1652+ args = (
1653+ buttons ,
1654+ yaml_config ,
1655+ ),
1656+ icon = ":material/terminal:" ,
1657+ use_container_width = True ,
1658+ disabled = st .session_state .is_running ,
1659+ )
16241660 st .code (yaml_config , language = "yaml" )
16251661
def run_config(self, parent, yaml_config: str) -> None:
    """Submit the generated YAML config to a running Ray cluster as a job.

    Used as the ``on_click`` callback of the "Run" button. The method marks
    the session as running, verifies that a Ray cluster is reachable via
    ``ray status``, writes ``yaml_config`` to a temporary ``.yaml`` file and
    submits ``python -m trinity.cli.launcher run --config <file>`` through
    ``ray job submit --no-wait``.

    Args:
        parent: Streamlit container used to display warning/success/error
            messages.
        yaml_config: The YAML text of the generated configuration.

    Side effects:
        ``st.session_state.is_running`` is set to ``True`` on entry and is
        reset to ``False`` on every failure path so the "Run" button becomes
        usable again. It is left ``True`` after a successful submission.
    """
    st.session_state.is_running = True

    import ray

    # first check if ray is running
    try:
        ray_status = subprocess.run(
            ["ray", "status"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        ray_available = ray_status.returncode == 0
    except FileNotFoundError:
        # The `ray` executable is not on PATH, so no cluster can be queried.
        ray_available = False

    if not ray_available:
        parent.warning(
            "Ray cluster is not running. Please start Ray first using `ray start --head`."
        )
        # Nothing was submitted: release the flag, otherwise the Run button
        # (disabled while is_running is True) would stay locked.
        st.session_state.is_running = False
        return

    context = ray.init(ignore_reinit_error=True)
    dashboard_url = context.dashboard_url

    # save config to temp file
    # Explicit UTF-8: the YAML may contain non-ASCII characters and the
    # locale's default encoding is not guaranteed to handle them.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".yaml", delete=False, encoding="utf-8"
    ) as tmpfile:
        tmpfile.write(yaml_config)
        tmpfile_path = tmpfile.name

    # submit ray job
    try:
        subprocess.run(
            [
                "ray",
                "job",
                "submit",
                "--no-wait",
                "--",
                "python",
                "-m",
                "trinity.cli.launcher",
                "run",
                "--config",
                tmpfile_path,
            ],
            text=True,
            capture_output=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        parent.error(f"Failed to submit job:\n \n {e.stderr} ", icon="❌")
        st.session_state.is_running = False
        # The submission failed, so nothing will read the temp file any more.
        try:
            os.unlink(tmpfile_path)
        except OSError:
            # Best-effort cleanup; a leftover temp file is harmless.
            pass
    else:
        # NOTE(review): the temp file is kept on success because the job is
        # submitted with --no-wait and presumably reads the config after this
        # call returns - confirm before adding cleanup here.
        parent.success(
            f"Job submitted successfully!\n \n "
            f"View progress in the Ray Dashboard: http://{dashboard_url} ",
            icon="✅",
        )
1715+
16261716
16271717if __name__ == "__main__" :
16281718 config_manager = ConfigManager ()
0 commit comments