diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md index 2e4daeab0b..4fbc1e468b 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md +++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md @@ -286,6 +286,126 @@ trinity run --config --- +## Adding New Config Entries for the Config Generator (Advanced) + +### Step 0: Understanding Streamlit + +Before adding new parameters to the Config Generator page, it is essential to familiarize yourself with the relevant API and mechanisms of [Streamlit](https://docs.streamlit.io/develop/api-reference). This project primarily utilizes various input components from Streamlit and employs `st.session_state` to store user-input parameters. + +### Step 1: Implement New Config Entries + +To illustrate the process of creating a new parameter setting for the Config Generator page, we will use `train_batch_size` as an example. + +1. Determine the appropriate scope for the parameter. Currently, parameters are categorized into four files: + - `trinity/manager/config_registry/buffer_config_manager.py` + - `trinity/manager/config_registry/explorer_config_manager.py` + - `trinity/manager/config_registry/model_config_manager.py` + - `trinity/manager/config_registry/trainer_config_manager.py` + + In this case, `train_batch_size` should be placed in the `buffer_config_manager.py` file. + +2. Create a parameter setting function using Streamlit. The function name must follow the convention of starting with 'set_', and the remainder of the name becomes the config name. + +3. Decorate the parameter setting function with the `CONFIG_GENERATORS.register_config` decorator. This decorator requires the following information: + - Default value of the parameter + - Visibility condition (if applicable) + - Additional config parameters (if needed) + +```{note} +The `CONFIG_GENERATORS.register_config` decorator automatically passes `key=config_name` as an argument to the registered configuration function. Ensure that your function accepts this keyword argument. +``` + +For `train_batch_size`, we will use the following settings: +- Default value: 96 +- Visibility condition: `lambda: st.session_state["trainer_gpu_num"] > 0` +- Additional config: `{"_train_batch_size_per_gpu": 16}` + + +Here's the complete code for the `train_batch_size` parameter: + +```python +@CONFIG_GENERATORS.register_config( + default_value=96, + visible=lambda: st.session_state["trainer_gpu_num"] > 0, + other_configs={"_train_batch_size_per_gpu": 16}, +) +def set_train_batch_size(**kwargs): + key = kwargs.get("key") + trainer_gpu_num = st.session_state["trainer_gpu_num"] + st.session_state[key] = ( + st.session_state["_train_batch_size_per_gpu"] * st.session_state["trainer_gpu_num"] + ) + + def on_change(): + st.session_state["_train_batch_size_per_gpu"] = max( + st.session_state[key] // st.session_state["trainer_gpu_num"], 1 + ) + + st.number_input( + "Train Batch Size", + min_value=trainer_gpu_num, + step=trainer_gpu_num, + help=_str_for_train_batch_size(), + on_change=on_change, + **kwargs, + ) +``` + +If the parameter requires validation, create a check function. For `train_batch_size`, we need to ensure it is divisible by `trainer_gpu_num`. If not, a warning should be displayed, and the parameter should be added to `unfinished_fields`. + +Decorate the check function with the `CONFIG_GENERATORS.register_check` decorator: + +```python +@CONFIG_GENERATORS.register_check() +def check_train_batch_size(unfinished_fields: set, key: str): + if st.session_state[key] % st.session_state["trainer_gpu_num"] != 0: + unfinished_fields.add(key) + st.warning(_str_for_train_batch_size()) +``` + +```{note} +The `CONFIG_GENERATORS.register_check` decorator automatically receives `key=config_name` and `unfinished_fields=self.unfinished_fields` as arguments. Ensure your function accepts these keyword arguments. +``` + +### Step 2: Integrating New Parameters into `config_manager.py` + +To successfully integrate new parameters into the `config_manager.py` file, please adhere to the following procedure: + +1. Parameter Categorization: + Determine the appropriate section for the new parameter based on its functionality. The config generator page is structured into two primary modes: + - Beginner Mode: Comprises "Essential Configs" and "Important Configs" sections. + - Expert Mode: Includes "Model", "Buffer", "Explorer and Synchronizer", and "Trainer" sections. + +2. Parameter Addition: + Incorporate the new parameter into the relevant section using the `self.get_configs` method within the `ConfigManager` class. + + Example: + ```python + class ConfigManager: + def _expert_buffer_part(self): + self.get_configs("total_epochs", "train_batch_size") + ``` + +3. YAML File Integration: + Locate the appropriate position for the new parameter within the YAML file structure. This should be done in the `generate_config` function and its associated sub-functions. + +4. Parameter Value Assignment: + Utilize `st.session_state` to retrieve the parameter value from the config generator page and assign it to the corresponding field in the YAML. + + Example: + ```python + class ConfigManager: + def _gen_buffer_config(self): + buffer_config = { + "batch_size": st.session_state["train_batch_size"], + # Additional configuration parameters + } + ``` + +By meticulously following these steps, you can ensure that new parameters are successfully added to the Config Generator page and properly integrated into the configuration system. This process maintains the integrity and functionality of the configuration management framework. + +--- + ## Check Code Style Before submitting the code, make sure it passes the code style check. Follow these steps: diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py index 9ac2d36f16..80b8992b3b 100644 --- a/trinity/manager/config_manager.py +++ b/trinity/manager/config_manager.py @@ -7,22 +7,15 @@ import streamlit as st import yaml -from trinity.common.constants import ( - AlgorithmType, - MonitorType, - PromptType, - StorageType, - SyncMethod, -) -from trinity.common.rewards import REWARD_FUNCTIONS -from trinity.common.workflows.workflow import WORKFLOWS -from trinity.trainer.verl.ray_trainer import AdvantageEstimator +from trinity.common.constants import AlgorithmType, StorageType +from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS +from trinity.manager.config_registry.trainer_config_manager import use_critic class ConfigManager: def __init__(self): - self._init_default_config() self.unfinished_fields = set() + CONFIG_GENERATORS.set_unfinished_fields(self.unfinished_fields) st.set_page_config(page_title="Trinity-RFT Config Generator", page_icon=":robot:") st.title("Trinity-RFT Config Generator") if "_init_config_manager" not in st.session_state: @@ -44,1319 +37,256 @@ def __init__(self): st.session_state.is_running = False self.generate_config() - def _init_default_config(self): - self.default_config = { - "_init_config_manager": True, - "mode": "both", - "project": "Trinity-RFT", - "exp_name": "qwen2.5-1.5B", - "checkpoint_root_dir": "", - "monitor_type": MonitorType.TENSORBOARD.value, - # Algorithm Configs - "algorithm_type": AlgorithmType.PPO.value, - "_grouped_adv_repeat_times": 2, - "_not_grouped_adv_repeat_times": 1, - "repeat_times": 1, - "gamma": 1.0, - "lam": 1.0, - # Model Configs - "model_path": "", - "critic_model_path": "", - "max_prompt_tokens": 1024, - "max_response_tokens": 1024, - # Cluster Config - "node_num": 1, - "gpu_per_node": 8, - "total_gpu_num": 8, - "trainer_gpu_num": 6, - # Buffer Configs - "total_epochs": 20, - "_train_batch_size_per_gpu": 16, - "train_batch_size": 96, - "buffer_max_retry_times": 3, - "max_retry_interval": 1, - # Taskset Configs - "taskset_path": "", - "taskset_subset_name": None, - "taskset_split": "train", - "taskset_prompt_key": "question", - "taskset_response_key": "answer", - "temperature": 1.0, - "top_p": 1.0, # TODO: to be used - "top_k": -1, # TODO: to be used - "logprobs": 0, - # Eval Taskset Configs - "_eval_tasksets_num": 0, - # Explorer Input Configs - "default_workflow_type": "math_workflow", - "default_reward_fn_type": "math_reward", - "system_prompt": None, - "reply_prefix": None, - # Experience Buffer / DPO Dataset Configs - "_dpo_storage_type": StorageType.FILE.value, - "_not_dpo_storage_type": StorageType.QUEUE.value, - "storage_type": StorageType.QUEUE.value, - "_dpo_experience_buffer_path": "", - "_not_dpo_experience_buffer_path": "", - "experience_buffer_path": "", - "dpo_dataset_train_split": "train", - "dpo_dataset_prompt_type": PromptType.MESSAGES.value, - "dpo_dataset_prompt_key": "prompt", - "dpo_dataset_chosen_key": "chosen", - "dpo_dataset_rejected_key": "rejected", - # SFT Warmup Dataset Configs - "sft_warmup_dataset_path": "", - "sft_warmup_train_split": "train", - "sft_warmup_prompt_type": PromptType.MESSAGES.value, - "sft_warmup_messages_key": "messages", - "sft_warmup_prompt_key": "prompt", - "sft_warmup_response_key": "response", - # TrainerInput Configs - # TODO: read_experience_strategy - "sft_warmup_steps": 0, - # Explorer and Sync Configs - "runner_num": 32, - "max_timeout": 900, - "explorer_max_retry_times": 2, - "eval_interval": 1000, - "eval_on_latest_checkpoint": True, - # Rollout Model Configs - "engine_type": "vllm_async", - "engine_num": 2, - "tensor_parallel_size": 1, - "use_v1": True, - "enforce_eager": True, - "enable_prefix_caching": False, - "enable_chunked_prefill": False, - "gpu_memory_utilization": 0.9, - "dtype": "bfloat16", - "seed": 42, - # TODO: max_prompt_tokens - # TODO: max_response_tokens - # TODO: chat_template - "enable_thinking": False, - "enable_openai_api": False, - # TODO: Auxiliary Models Configs - # Synchronizer Configs - "_not_dpo_sync_method": SyncMethod.NCCL.value, - "sync_method": SyncMethod.NCCL.value, - "sync_interval": 10, - "sync_timeout": 1200, - # Trainer Configs - "trainer_type": "verl", - "_nccl_save_interval": 100, - "save_interval": 100, - # TODO: enable_preview - "_not_dpo_actor_use_kl_loss": True, - "actor_use_kl_loss": True, - "actor_kl_loss_coef": 0.001, - "actor_entropy_coef": 0.001, - "actor_grad_clip": 1.0, - "actor_clip_ratio": 0.2, - # veRL Trainer Configs - "training_args": [ - "balance_batch", - "gradient_checkpointing", - "remove_padding", - "dynamic_bsz", - ], - "ppo_epochs": 1, - "training_strategy": "fsdp", - "param_offload": False, - "optimizer_offload": False, - "resume_mode": "auto", - "resume_from_path": "", - "critic_warmup": 0, - "total_training_steps": None, - "default_hdfs_dir": None, - "remove_previous_ckpt_in_save": False, - "del_local_ckpt_after_load": False, - "max_actor_ckpt_to_keep": None, - "max_critic_ckpt_to_keep": None, - "adv_estimator": "gae", - "norm_adv_by_std_in_grpo": True, - "use_kl_in_reward": False, - "kl_penalty": "low_var_kl", - "kl_ctrl_type": "fixed", - "kl_ctrl_coef": 0.001, - "horizon": 10000, - "target_kl": 0.1, - "actor_ppo_micro_batch_size_per_gpu": 4, - "ref_log_prob_micro_batch_size_per_gpu": 8, - "actor_ulysses_sequence_parallel_size": 1, - "actor_lr": 1e-6, - "actor_warmup_style": "constant", - "actor_lr_warmup_steps_ratio": 0.0, - "actor_tau": 0.0, - "actor_opmd_baseline": "mean", - "actor_use_uid": False, - "actor_kl_loss_type": "low_var_kl", - "actor_checkpoint": ["model", "hf_model", "optimizer", "extra"], - "critic_lr": 1e-6, - "critic_warmup_style": "constant", - "critic_lr_warmup_steps_ratio": 0.0, - "critic_grad_clip": 1.0, - "critic_cliprange_value": 0.5, - "critic_ppo_micro_batch_size_per_gpu": 8, - "critic_ulysses_sequence_parallel_size": 1, - "critic_checkpoint": ["model", "optimizer", "extra"], - } - def reset_session_state(self): - for key, value in self.default_config.items(): + st.session_state["_init_config_manager"] = True + for key, value in CONFIG_GENERATORS.default_config.items(): st.session_state[key] = value def maintain_session_state(self): - for key in self.default_config: + st.session_state["_init_config_manager"] = True + for key in CONFIG_GENERATORS.default_config: st.session_state[key] = st.session_state[key] - eavl_dataset_keys = ["name", "path", "subset_name", "split", "prompt_key", "response_key"] + + eval_dataset_keys = [ + "name", + "path", + "subset_name", + "split", + "prompt_key", + "response_key", + "temperature", + "logprobs", + "n", + ] + last_idx, del_num = 0, 0 for idx in range(st.session_state["_eval_tasksets_num"]): - for key in eavl_dataset_keys: + if st.session_state.get(f"eval_taskset_{idx}_del_flag", False): + del_num += 1 + continue + for key in eval_dataset_keys: full_key = f"eval_taskset_{idx}_{key}" - st.session_state[full_key] = st.session_state[full_key] - - def _set_project(self): - st.text_input("Project", key="project") - - def _set_exp_name(self): - st.text_input("Experiment Name", key="exp_name") - - def _set_monitor_type(self): - st.selectbox( - "Monitor Type", - options=[monitor_type.value for monitor_type in MonitorType], - key="monitor_type", - ) - - def _set_model_path(self): - st.text_input("Model Path", key="model_path") - if not st.session_state["model_path"].strip(): - self.unfinished_fields.add("model_path") - st.warning("Please input model path.") - - def _set_critic_model_path(self): - if st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value: - st.text_input( - "Critic Model Path (defaults to `model_path`)", - key="critic_model_path", - ) - - def _set_checkpoint_root_dir(self): - st.text_input("Checkpoint Root Dir", key="checkpoint_root_dir") - if not st.session_state["checkpoint_root_dir"].strip(): # TODO: may auto generate - self.unfinished_fields.add("checkpoint_root_dir") - st.warning("Please input checkpoint root dir.") - elif not os.path.isabs(st.session_state["checkpoint_root_dir"].strip()): - self.unfinished_fields.add("checkpoint_root_dir") - st.warning("Please input an absolute path.") - - def _set_node_num(self): - st.number_input("Node Num", key="node_num", min_value=1, on_change=self._set_total_gpu_num) - - def _set_gpu_per_node(self): - st.number_input( - "GPU Per Node", - key="gpu_per_node", - min_value=1, - max_value=8, - on_change=self._set_total_gpu_num, - ) - - def _set_total_gpu_num(self): - st.session_state["total_gpu_num"] = ( - st.session_state["gpu_per_node"] * st.session_state["node_num"] - ) - self._set_trainer_gpu_num() - - def _set_trainer_gpu_num(self): - if st.session_state["mode"] == "both": - st.session_state["trainer_gpu_num"] = ( - st.session_state["total_gpu_num"] - - st.session_state["engine_num"] * st.session_state["tensor_parallel_size"] - ) - else: # model == train - st.session_state["trainer_gpu_num"] = st.session_state["total_gpu_num"] - - def _set_max_prompt_tokens(self): - st.number_input("Max Prompt Tokens", key="max_prompt_tokens", min_value=1) - - def _set_max_response_tokens(self): - st.number_input("Max Response Tokens", key="max_response_tokens", min_value=1) - - def _set_total_epochs(self): - st.number_input("Total Epochs", key="total_epochs", min_value=1) - - @property - def _str_for_train_batch_size(self): - trainer_gpu_num_str = ( - "`gpu_per_node * node_num - engine_num * tensor_parallel_size`" - if st.session_state["mode"] == "both" - else "`gpu_per_node * node_num`" - ) - return ( - f"Please ensure that `train_batch_size` can be divided by " - f"{trainer_gpu_num_str} = {st.session_state['trainer_gpu_num']}." - ) - - def _set_train_batch_size(self): - trainer_gpu_num = st.session_state["trainer_gpu_num"] - st.session_state["train_batch_size"] = ( - st.session_state["_train_batch_size_per_gpu"] * st.session_state["trainer_gpu_num"] - ) - - def on_change(): - st.session_state["_train_batch_size_per_gpu"] = max( - st.session_state["train_batch_size"] // st.session_state["trainer_gpu_num"], 1 - ) - - st.number_input( - "Train Batch Size", - key="train_batch_size", - min_value=trainer_gpu_num, - step=trainer_gpu_num, - help=self._str_for_train_batch_size, - on_change=on_change, - ) - - def _check_train_batch_size(self): - if st.session_state["train_batch_size"] % st.session_state["trainer_gpu_num"] != 0: - self.unfinished_fields.add("train_batch_size") - st.warning(self._str_for_train_batch_size) - - def _set_taskset_path(self): - st.text_input("Taskset Path", key="taskset_path") - if not st.session_state["taskset_path"].strip(): - self.unfinished_fields.add("taskset_path") - st.warning("Please input taskset path.") - - def _set_system_prompt(self): - st.text_area( - "System Prompt", - key="system_prompt", - placeholder="System prompt is used to guide the model behavior.", - ) - - def _set_reply_prefix(self): - st.text_area( - "Assistant Reply Prefix", - key="reply_prefix", - placeholder="""Assistant reply prefix is used to specify the initial content of model reply, """ - """and a common setting is: \nLet me solve this step by step. """, - ) - - def _set_taskset_args(self): - if st.session_state["taskset_path"] and "://" not in st.session_state["taskset_path"]: - subset_name_col, split_col = st.columns(2) - subset_name_col.text_input( - "Subset Name :orange-badge[(Needs review)]", - key="taskset_subset_name", - help="The subset name used for `datasets.load_datasets`, see " - "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.", - ) - split_col.text_input("Train Split :orange-badge[(Needs review)]", key="taskset_split") - prompt_key_col, response_key_col = st.columns(2) - prompt_key_col.text_input( - "Prompt Key :orange-badge[(Needs review)]", key="taskset_prompt_key" - ) - response_key_col.text_input( - "Response Key :orange-badge[(Needs review)]", key="taskset_response_key" - ) - self._set_configs_with_st_columns(["temperature", "logprobs"]) - - def _set_eval_taskset_idx(self, idx): # TODO: add delete - st.text_input( - "Taskset Name", - key=f"eval_taskset_{idx}_name", - ) - st.text_input( - "Eval Taskset Path", - key=f"eval_taskset_{idx}_path", - ) - if not st.session_state[f"eval_taskset_{idx}_path"].strip(): - st.warning("Please input the taskset path, or it will be ignored.") - subset_name_col, split_col = st.columns(2) - subset_name_col.text_input( - "Subset Name :orange-badge[(Needs review)]", - key=f"eval_taskset_{idx}_subset_name", - help="The subset name used for `datasets.load_datasets`, see " - "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.", - ) - split_col.text_input( - "Eval Split :orange-badge[(Needs review)]", - key=f"eval_taskset_{idx}_split", - ) - prompt_key_col, response_key_col = st.columns(2) - prompt_key_col.text_input( - "Prompt Key :orange-badge[(Needs review)]", - key=f"eval_taskset_{idx}_prompt_key", - ) - response_key_col.text_input( - "Response Key :orange-badge[(Needs review)]", - key=f"eval_taskset_{idx}_response_key", - ) - - def _set_eval_tasksets(self): - if st.button("Add Eval Taskset"): - st.session_state["_eval_tasksets_num"] += 1 - if st.session_state["_eval_tasksets_num"] > 0: - tabs = st.tabs( - [f"Eval Taskset {i + 1}" for i in range(st.session_state["_eval_tasksets_num"])] - ) - for idx, tab in enumerate(tabs): - with tab: - self._set_eval_taskset_idx(idx) - - def _set_default_workflow_type(self): - st.selectbox( - "Default Workflow Type :orange-badge[(Needs review)]", - WORKFLOWS.modules.keys(), - key="default_workflow_type", - help=r"""`simple_workflow`: call 'model.chat()' to get responses. - -`math_workflow`: call 'model.chat()' with a pre-defined system prompt to get responses. - -Other workflows: conduct multi-turn task for the given dataset. -""", - ) - - def _set_default_reward_fn_type(self): - st.selectbox( - "Default Reward Fn Type :orange-badge[(Needs review)]", - REWARD_FUNCTIONS.modules.keys(), - key="default_reward_fn_type", - help=r"""`accuracy_reward`: check the accuracy for math problems. - -`format_reward`: check if the response matches the format (default: `** *`). - -`math_reward`: `accuracy_reward` (1 or 0) + `format_reward` (+0.1 or -0.1). -""", - ) - - def _set_storage_type(self): - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - st.session_state["storage_type"] = st.session_state["_dpo_storage_type"] - storage_candidates = [StorageType.FILE.value, StorageType.SQL.value] - else: - st.session_state["storage_type"] = st.session_state["_not_dpo_storage_type"] - storage_candidates = [StorageType.QUEUE.value, StorageType.SQL.value] - - def on_change(): - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - st.session_state["_dpo_storage_type"] = st.session_state["storage_type"] - else: - st.session_state["_not_dpo_storage_type"] = st.session_state["storage_type"] - - st.selectbox( - "Storage Type", - storage_candidates, - key="storage_type", - on_change=on_change, - ) - - def _set_experience_buffer_path(self): # TODO - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - st.session_state["experience_buffer_path"] = st.session_state[ - "_dpo_experience_buffer_path" - ] - title = "DPO Dataset Path" - help_msg = r"""This path to DPO dataset, - -if `storage_type == StorageType.FILE`, this should be a path to a file, - -if `storage_type == StorageType.SQL`, this should be a path to database.""" - else: - st.session_state["experience_buffer_path"] = st.session_state[ - "_not_dpo_experience_buffer_path" - ] - title = "Experience Buffer Path" - help_msg = r"""This path is used for `trainer`, - -if `storage_type == StorageType.QUEUE`, default to `None`, - -if `storage_type == StorageType.SQL`, default to `sqlite:///{os.path.join(checkpoint_root_dir, '.cache', project_name, experiment_name)}/data.db`.""" - - def on_change(): - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - st.session_state["_dpo_experience_buffer_path"] = st.session_state[ - "experience_buffer_path" - ] - else: - st.session_state["_not_dpo_experience_buffer_path"] = st.session_state[ - "experience_buffer_path" - ] - - st.text_input( - title, - key="experience_buffer_path", - help=help_msg, - on_change=on_change, - ) - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - if not st.session_state["experience_buffer_path"].strip(): - self.unfinished_fields.add("experience_buffer_path") - st.warning("Please input DPO dataset path.") - - def _set_buffer_max_retry_times(self): - st.number_input("Max Retry Times", key="buffer_max_retry_times", min_value=1) - - def _set_max_retry_interval(self): - st.number_input("Max Retry Interval", key="max_retry_interval", min_value=1) - - def _set_dpo_dataset_kwargs(self): - dpo_dataset_train_split_col, dpo_dataset_prompt_type_col = st.columns(2) - dpo_dataset_train_split_col.text_input( - "DPO Dataset Train Split :orange-badge[(Needs review)]", key="dpo_dataset_train_split" - ) - dpo_dataset_prompt_type_col.selectbox( - "DPO Dataset Prompt Type :orange-badge[(Needs review)]", - [prompt_type.value for prompt_type in PromptType], - key="dpo_dataset_prompt_type", - ) - - ( - dpo_dataset_prompt_key_col, - dpo_dataset_chosen_key_col, - dpo_dataset_rejected_key_col, - ) = st.columns(3) - dpo_dataset_prompt_key_col.text_input( - "DPO Dataset Prompt Key :orange-badge[(Needs review)]", key="dpo_dataset_prompt_key" - ) - dpo_dataset_chosen_key_col.text_input( - "DPO Dataset Chosen Key :orange-badge[(Needs review)]", key="dpo_dataset_chosen_key" - ) - dpo_dataset_rejected_key_col.text_input( - "DPO Dataset Rejected Key :orange-badge[(Needs review)]", - key="dpo_dataset_rejected_key", - ) - - def _check_sft_warmup_dataset_path(self): - if st.session_state["sft_warmup_steps"]: - if not st.session_state["sft_warmup_dataset_path"].strip(): - self.unfinished_fields.add("sft_warmup_dataset_path") - st.warning("Please input SFT warmup dataset path when `sft_warmup_steps` is not 0") - - def _set_sft_warmup_dataset_path(self): - st.text_input("SFT Warmup Dataset Path", key="sft_warmup_dataset_path") - self._check_sft_warmup_dataset_path() - - def _set_sft_warmup_dataset_args(self): - if ( - st.session_state["sft_warmup_dataset_path"] - and "://" not in st.session_state["sft_warmup_dataset_path"] - ): # TODO - ( - sft_warmup_train_split_col, - sft_warmup_prompt_type_col, - ) = st.columns(2) - sft_warmup_train_split_col.text_input( - "SFT Dataset Train Split :orange-badge[(Needs review)]", - key="sft_warmup_train_split", - ) - sft_warmup_prompt_type_col.selectbox( - "SFT Dataset Prompt Type :orange-badge[(Needs review)]", - [prompt_type.value for prompt_type in PromptType], - key="sft_warmup_prompt_type", - ) - ( - sft_warmup_messages_key_col, - sft_warmup_prompt_key_col, - sft_warmup_response_key_col, - ) = st.columns( - 3 - ) # TODO: select by prompt type - sft_warmup_messages_key_col.text_input( - "SFT Dataset Messages Key :orange-badge[(Needs review)]", - key="sft_warmup_messages_key", - ) - sft_warmup_prompt_key_col.text_input( - "SFT Dataset Prompt Key :orange-badge[(Needs review)]", key="sft_warmup_prompt_key" - ) - sft_warmup_response_key_col.text_input( - "SFT Dataset Response Key :orange-badge[(Needs review)]", - key="sft_warmup_response_key", - ) - - def _set_engine_type(self): - st.selectbox("Explorer Engine Type", ["vllm_async", "vllm"], key="engine_type") - - @property - def _str_for_engine_num_and_tp_size(self): - return r"""and it must meet the following constraints: -```python -assert engine_num * tensor_parallel_size < gpu_per_node * node_num -if node_num > 1: - assert gpu_per_node % tensor_parallel_size == 0 - assert engine_num * tensor_parallel_size % gpu_per_node == 0 -```""" - - def _set_engine_num(self): - total_gpu_num = st.session_state["total_gpu_num"] - max_engine_num = (total_gpu_num - 1) // st.session_state["tensor_parallel_size"] - if st.session_state["engine_num"] > max_engine_num: - st.session_state["engine_num"] = max_engine_num - self._set_trainer_gpu_num() - st.number_input( - "Engine Num", - key="engine_num", - min_value=1, - max_value=max_engine_num, - help=f"`engine_num` is used to set the quantity of inference engines, " - f"{self._str_for_engine_num_and_tp_size}", - on_change=self._set_trainer_gpu_num, - ) - - def _set_tensor_parallel_size(self): - total_gpu_num = st.session_state["total_gpu_num"] - max_tensor_parallel_size = (total_gpu_num - 1) // st.session_state["engine_num"] - if st.session_state["tensor_parallel_size"] > max_tensor_parallel_size: - st.session_state["tensor_parallel_size"] = max_tensor_parallel_size - self._set_trainer_gpu_num() - st.number_input( - "Tensor Parallel Size", - key="tensor_parallel_size", - min_value=1, - max_value=max_tensor_parallel_size, - help=f"`tensor_parallel_size` is used to set the tensor parallel size of inference engines, " - f"{self._str_for_engine_num_and_tp_size}", - on_change=self._set_trainer_gpu_num, - ) - - def _check_engine_num_and_tp_size(self): - node_num = st.session_state["node_num"] - gpu_per_node = st.session_state["gpu_per_node"] - engine_num = st.session_state["engine_num"] - tensor_parallel_size = st.session_state["tensor_parallel_size"] - if node_num > 1: - if gpu_per_node % tensor_parallel_size != 0: - self.unfinished_fields.add("tensor_parallel_size") - st.warning( - "Please ensure that `tensor_parallel_size` is a factor of `gpu_per_node` when `node_num > 1`." - ) - if engine_num * tensor_parallel_size % gpu_per_node != 0: - self.unfinished_fields.add("engine_num") - st.warning( - "Please ensure that `engine_num * tensor_parallel_size` can be divided by `gpu_per_node` when `node_num > 1`." - ) - - def _set_repeat_times(self): # TODO - grouped_adv_algorithms = [ - AlgorithmType.GRPO.value, - AlgorithmType.OPMD.value, # TODO: may add rloo + last_full_key = f"eval_taskset_{last_idx}_{key}" + st.session_state[last_full_key] = st.session_state[full_key] + last_idx += 1 + st.session_state["_eval_tasksets_num"] -= del_num + + auxiliary_model_keys = [ + "model_path", + "engine_type", + "engine_num", + "tensor_parallel_size", + "gpu_memory_utilization", + "dtype", + "seed", + "use_v1", + "enforce_eager", + "enable_prefix_caching", + "enable_chunked_prefill", + "enable_thinking", + "enable_openai_api", ] - if st.session_state["algorithm_type"] in grouped_adv_algorithms: - min_repeat_times = 2 - st.session_state["repeat_times"] = st.session_state["_grouped_adv_repeat_times"] - else: - min_repeat_times = 1 - st.session_state["repeat_times"] = st.session_state["_not_grouped_adv_repeat_times"] - - def on_change(): - if st.session_state["algorithm_type"] in grouped_adv_algorithms: - st.session_state["_grouped_adv_repeat_times"] = st.session_state["repeat_times"] - else: - st.session_state["_not_grouped_adv_repeat_times"] = st.session_state["repeat_times"] - - st.number_input( - "Repeat Times", - key="repeat_times", - min_value=min_repeat_times, - help="`repeat_times` is used to set how many experiences each task can generate, " - "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.", - on_change=on_change, - ) - - def _set_sync_method(self): - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - st.session_state["sync_method"] = SyncMethod.CHECKPOINT.value - disabled = True - else: - st.session_state["sync_method"] = st.session_state["_not_dpo_sync_method"] - disabled = False - - def on_change(): - if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: - st.session_state["_not_dpo_sync_method"] = st.session_state["sync_method"] - - st.selectbox( - "Sync Method", - [sync_method.value for sync_method in SyncMethod], - key="sync_method", - help="""`nccl`: the explorer and trainer sync model weights once every `sync_interval` steps. - -`checkpoint`: the trainer saves the model checkpoint, and the explorer loads it at `sync_interval`.""", - disabled=disabled, - on_change=on_change, - ) - - def _set_sync_interval(self): - st.number_input( - "Sync Interval", - key="sync_interval", - min_value=1, - help="""The step interval at which the `explorer` and `trainer` synchronize model weight.""", - ) - - def _set_sync_timeout(self): - st.number_input( - "Sync Timeout", - key="sync_timeout", - min_value=1, - help="The timeout value for the synchronization operation.", - ) - - def _set_runner_num(self): - st.number_input("Runner Num", key="runner_num", min_value=1) - - def _set_dtype(self): - st.selectbox("Dtype", ["float16", "bfloat16", "float32"], key="dtype") - - def _set_temperature(self): - st.number_input("Temperature", key="temperature", min_value=0.0, max_value=2.0) - - def _set_top_p(self): - st.number_input("Top-p", key="top_p", min_value=0.0, max_value=1.0) - - def _set_top_k(self): - st.number_input( - "Top-k", - key="top_k", - min_value=-1, - max_value=512, - help="Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.", - ) - - def _set_seed(self): - st.number_input("Seed", key="seed", step=1) - - def _set_logprobs(self): - st.number_input("Logprobs", key="logprobs", min_value=0, max_value=20) - - def _set_use_v1(self): - st.checkbox("Use V1 Engine", key="use_v1") - - def _set_enable_prefix_caching(self): - st.checkbox("Prefix Caching", key="enable_prefix_caching") - - def _set_enforce_eager(self): - st.checkbox("Enforce Eager", key="enforce_eager") - - def _set_gpu_memory_utilization(self): - st.number_input( - "GPU Memory Utilization", key="gpu_memory_utilization", min_value=0.0, max_value=1.0 - ) - - def _set_enable_chunked_prefill(self): - st.checkbox("Chunked Prefill", key="enable_chunked_prefill") - - def _set_enable_thinking(self): - st.checkbox("Enable Thinking For Qwen3", key="enable_thinking") - - def _set_enable_openai_api(self): - st.checkbox("Enable OpenAI API", key="enable_openai_api") - - def _set_max_timeout(self): - st.number_input("Max Timeout", key="max_timeout", min_value=0) - - def _set_explorer_max_retry_times(self): - st.number_input("Explorer Max Retry Times", key="explorer_max_retry_times", min_value=0) - - def _set_trainer_type(self): - st.selectbox("Trainer Type", ["verl"], key="trainer_type") - - def _set_algorithm_type(self): - def on_change(): - if st.session_state["algorithm_type"] == AlgorithmType.PPO.value: - st.session_state["mode"] = "both" - st.session_state["adv_estimator"] = AdvantageEstimator.GAE.value - elif st.session_state["algorithm_type"] == AlgorithmType.GRPO.value: - st.session_state["mode"] = "both" - st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value - elif st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - st.session_state["mode"] = "train" - st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value - elif st.session_state["algorithm_type"] == AlgorithmType.OPMD.value: - st.session_state["mode"] = "both" - st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value - else: # TODO: add more algorithms - pass - self._set_trainer_gpu_num() - - st.selectbox( - "Algorithm Type", - [ - AlgorithmType.PPO.value, - AlgorithmType.GRPO.value, - AlgorithmType.DPO.value, - AlgorithmType.OPMD.value, - ], - key="algorithm_type", - on_change=on_change, - ) - - def _set_sft_warmup_steps(self): - st.number_input("SFT Warmup Steps", key="sft_warmup_steps", min_value=0) - - def _set_eval_interval(self): - st.number_input("Eval Interval", key="eval_interval", min_value=1) - - def _set_eval_on_latest_checkpoint(self): - st.checkbox("Eval on Latest Checkpoint", key="eval_on_latest_ckp") - - def _set_training_args(self): - st.multiselect( - "Training Args", - [ - "balance_batch", - "gradient_checkpointing", - "remove_padding", - "dynamic_bsz", - ], - key="training_args", - ) - - def _set_save_interval(self): - if ( - st.session_state["algorithm_type"] == AlgorithmType.DPO.value - or st.session_state["sync_method"] == SyncMethod.NCCL.value - ): - st.session_state["save_interval"] = st.session_state["_nccl_save_interval"] - freeze_save_interval = False - else: - st.session_state["save_interval"] = st.session_state["sync_interval"] - freeze_save_interval = True - - def on_change(): - if ( - st.session_state["algorithm_type"] == AlgorithmType.DPO.value - or st.session_state["sync_method"] == SyncMethod.NCCL.value - ): - st.session_state["_nccl_save_interval"] = st.session_state["save_interval"] - - st.number_input( - "Save Interval", - key="save_interval", - min_value=1, - help="Set to `sync_interval` when `algorithm_type != DPO && sync_method == checkpoint`", - disabled=freeze_save_interval, - on_change=on_change, - ) - - def _set_ppo_epochs(self): - st.number_input("PPO Epochs", key="ppo_epochs", min_value=1) - - def _set_training_strategy(self): - st.selectbox( - "Training Strategy", - ["fsdp", "megatron"], - key="training_strategy", - help="megatron is not tested", - ) - - def _set_param_offload(self): - st.checkbox("FSDP Param Offload", key="param_offload") - - def _set_optimizer_offload(self): - st.checkbox("FSDP Optimizer Offload", key="optimizer_offload") - - def _set_resume_mode(self): - st.selectbox("Resume Mode", ["disable", "auto", "resume_path"], key="resume_mode") - - def _set_resume_from_path(self): - if st.session_state["resume_mode"] == "resume_path": - st.text_input("Resume Path", key="resume_from_path") - if ( - not st.session_state["resume_from_path"].strip() - or "global_step_" not in st.session_state["resume_from_path"] - ): - self.unfinished_fields.add("resume_from_path") - st.warning("Please input a valid resume path when `resume_mode == resume_path`") - - def _set_critic_warmup(self): - st.number_input("Critic Warmup Steps", key="critic_warmup", min_value=0) - - def _set_total_training_steps(self): - st.number_input("Total Training Steps", key="total_training_steps", min_value=1) - - def _set_default_hdfs_dir(self): - st.text_input("Default HDFS Dir", key="default_hdfs_dir") - - def _set_remove_previous_ckpt_in_save(self): - st.checkbox("Remove Previous Checkpoint in Save", key="remove_previous_ckpt_in_save") - - def _set_del_local_ckpt_after_load(self): - st.checkbox("Delete Local Checkpoint After Load", key="del_local_ckpt_after_load") - - def _set_max_actor_ckpt_to_keep(self): - st.number_input("Max Actor Checkpoint to Keep", key="max_actor_ckpt_to_keep", min_value=1) - - def _set_max_critic_ckpt_to_keep(self): - st.number_input("Max Critic Checkpoint to Keep", key="max_critic_ckpt_to_keep", min_value=1) - - def _set_gamma(self): - st.number_input(r"Gamma :blue-badge[$\gamma$]", key="gamma") - - def _set_lam(self): - st.number_input(r"Lambda :blue-badge[$\lambda$]", key="lam") - - def _set_norm_adv_by_std_in_grpo(self): - st.checkbox("Norm Adv by Std in GRPO", key="norm_adv_by_std_in_grpo") - - def _set_use_kl_in_reward(self): - st.checkbox("Use KL in Reward", key="use_kl_in_reward") - - def _set_kl_penalty(self): - st.selectbox("KL Penalty", ["kl", "abs", "mse", "low_var_kl"], key="kl_penalty") - - def _set_kl_ctrl_type(self): - st.selectbox("KL Ctrl Type", ["fixed", "adaptive"], key="kl_ctrl_type") - - def _set_kl_ctrl_coef(self): - st.number_input("KL Ctrl Coef", key="kl_ctrl_coef", format="%.1e") - - def _set_horizon(self): - st.number_input("Horizon", key="horizon", min_value=1.0) - - def _set_target_kl(self): - st.number_input("Target KL", key="target_kl", format="%.1e") - - def _set_actor_ppo_micro_batch_size_per_gpu(self): - st.session_state["actor_ppo_micro_batch_size_per_gpu"] = min( - st.session_state["actor_ppo_micro_batch_size_per_gpu"], - st.session_state["_train_batch_size_per_gpu"], - ) - st.number_input( - "Micro Batch Size Per GPU :blue-badge[(Actor)]", - key="actor_ppo_micro_batch_size_per_gpu", - min_value=1, - max_value=st.session_state["_train_batch_size_per_gpu"], - ) - - def _set_ref_log_prob_micro_batch_size_per_gpu(self): - st.session_state["ref_log_prob_micro_batch_size_per_gpu"] = min( - st.session_state["ref_log_prob_micro_batch_size_per_gpu"], - st.session_state["_train_batch_size_per_gpu"], - ) - st.number_input( - "Micro Batch Size Per GPU :blue-badge[(Ref)]", - key="ref_log_prob_micro_batch_size_per_gpu", - min_value=1, - max_value=st.session_state["_train_batch_size_per_gpu"], - ) - - def _set_actor_ulysses_sequence_parallel_size(self): - st.number_input( - "Ulysses Sequence Parallel Size", - key="actor_ulysses_sequence_parallel_size", - min_value=1, - max_value=8, - ) - - def _set_actor_lr(self): - st.number_input( - "Learning Rate :blue-badge[(Actor)]", - key="actor_lr", - min_value=1e-7, - max_value=1e-3, - format="%.1e", - ) - - def _set_actor_warmup_style(self): - st.selectbox( - "LR Warmup Style :blue-badge[(Actor)]", - ["constant", "cosine"], - key="actor_warmup_style", - ) - - def _set_actor_lr_warmup_steps_ratio(self): - st.number_input( - "LR Warmup Steps Ratio :blue-badge[(Actor)]", - key="actor_lr_warmup_steps_ratio", - min_value=0.0, - max_value=1.0, - ) - - def _set_actor_grad_clip(self): - st.number_input( - "Grad Clip :blue-badge[(Actor)]", - key="actor_grad_clip", - min_value=0.0, - max_value=1.0, - help="Clipping by Norm", - ) - - def _set_actor_clip_ratio(self): - st.number_input( - r"Clip Ratio :blue-badge[$\epsilon$]", - key="actor_clip_ratio", - min_value=0.0, - max_value=1.0, - ) - - def _set_actor_entropy_coef(self): - st.number_input( - "Entropy Coeff", - key="actor_entropy_coef", - min_value=0.0, - max_value=1.0, - format="%.1e", - ) - - def _set_actor_use_kl_loss(self): - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - st.session_state["actor_use_kl_loss"] = True - else: - st.session_state["actor_use_kl_loss"] = st.session_state["_not_dpo_actor_use_kl_loss"] - - def on_change(): - st.session_state["_not_dpo_actor_use_kl_loss"] = st.session_state[ - "actor_use_kl_loss" - ] - - st.checkbox("Use KL Loss", key="actor_use_kl_loss", on_change=on_change) - - def _set_actor_kl_loss_coef(self): - st.number_input( - r"KL Loss Coef :blue-badge[$\beta$]", - key="actor_kl_loss_coef", - min_value=0.0, - max_value=1.0, - format="%.1e", - ) - - def _set_actor_kl_loss_type(self): - st.selectbox( - "KL Loss Type", - ["kl", "abs", "mse", "low_var_kl"], - key="actor_kl_loss_type", - ) - - def _set_actor_tau(self): - st.number_input( - "Tau for OPMD", - key="actor_tau", - min_value=0.0, - format="%.1e", - ) - - def _set_actor_opmd_baseline(self): - st.selectbox( - "OPMD Baseline", - ["mean", "logavgexp"], - key="actor_opmd_baseline", - ) - - def _set_actor_use_uid(self): - st.checkbox("Use UID for OPMD", key="actor_use_uid") - - def _set_actor_checkpoint(self): - st.multiselect( - "Checkpoint", - ["model", "hf_model", "optimizer", "extra"], - key="actor_checkpoint", - ) - - def _set_critic_ppo_micro_batch_size_per_gpu(self): - st.session_state["critic_ppo_micro_batch_size_per_gpu"] = min( - st.session_state["critic_ppo_micro_batch_size_per_gpu"], - st.session_state["_train_batch_size_per_gpu"], - ) - st.number_input( - "Micro Batch Size Per GPU :blue-badge[(Critic)]", - key="critic_ppo_micro_batch_size_per_gpu", - min_value=1, - max_value=st.session_state["_train_batch_size_per_gpu"], - ) - - def _set_critic_ulysses_sequence_parallel_size(self): - st.number_input( - "Ulysses Sequence Parallel Size", - key="critic_ulysses_sequence_parallel_size", - min_value=1, - max_value=8, - ) - - def _set_critic_lr(self): - st.number_input( - "Learning Rate :blue-badge[(Critic)]", - key="critic_lr", - min_value=1e-7, - max_value=1e-3, - format="%.1e", - ) - - def _set_critic_warmup_style(self): - st.selectbox( - "LR Warmup Style :blue-badge[(Critic)]", - ["constant", "cosine"], - key="critic_warmup_style", - ) - - def _set_critic_lr_warmup_steps_ratio(self): - st.number_input( - "LR Warmup Steps Ratio :blue-badge[(Critic)]", - key="critic_lr_warmup_steps_ratio", - min_value=0.0, - max_value=1.0, - ) - - def _set_critic_grad_clip(self): - st.number_input( - "Grad Clip :blue-badge[(Critic)]", - key="critic_grad_clip", - min_value=0.0, - max_value=1.0, - help="Clipping by Norm", - ) - - def _set_critic_cliprange_value(self): - st.number_input( - "Cliprange Value", - key="critic_cliprange_value", - min_value=0.0, - max_value=1.0, - ) - - def _set_critic_checkpoint(self): - st.multiselect( - "Checkpoint", - ["model", "hf_model", "optimizer", "extra"], - key="critic_checkpoint", - ) - - def _set_configs_with_st_columns( - self, config_names: List[str], columns_config: List[int] = None - ): - if columns_config is None: - columns_config = len(config_names) - columns = st.columns(columns_config) - for col, config_name in zip(columns, config_names): - with col: - getattr(self, f"_set_{config_name}")() + last_idx, del_num = 0, 0 + for idx in range(st.session_state["_auxiliary_models_num"]): + if st.session_state.get(f"auxiliary_model_{idx}_del_flag", False): + del_num += 1 + continue + for key in auxiliary_model_keys: + full_key = f"auxiliary_model_{idx}_{key}" + last_full_key = f"auxiliary_model_{last_idx}_{key}" + st.session_state[last_full_key] = st.session_state[full_key] + last_idx += 1 + st.session_state["_auxiliary_models_num"] -= del_num + + def get_configs(self, *config_names: str, columns_spec: List[int] = None): + CONFIG_GENERATORS.get_configs(*config_names, columns_spec=columns_spec) def beginner_mode(self): st.header("Essential Configs") - self._set_configs_with_st_columns(["project", "exp_name"], columns_config=[1, 3]) + self.get_configs("project", "exp_name", columns_spec=[1, 2]) - self._set_model_path() + self.get_configs("model_path") - self._set_checkpoint_root_dir() + self.get_configs("checkpoint_root_dir") - self._set_taskset_path() + if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: + self.get_configs("taskset_path") + else: + self.get_configs("experience_buffer_path") - self._set_configs_with_st_columns(["algorithm_type", "sft_warmup_steps", "monitor_type"]) + self.get_configs("algorithm_type", "sft_warmup_steps", "monitor_type") if st.session_state["sft_warmup_steps"] > 0: - self._set_sft_warmup_dataset_path() + self.get_configs("sft_warmup_dataset_path") st.header("Important Configs") - self._set_configs_with_st_columns( - ["node_num", "gpu_per_node", "engine_num", "tensor_parallel_size"] - if st.session_state["mode"] == "both" - else ["node_num", "gpu_per_node"] - ) - self._check_engine_num_and_tp_size() + self.get_configs("node_num", "gpu_per_node", "engine_num", "tensor_parallel_size") - self._set_configs_with_st_columns( - ["total_epochs", "train_batch_size", "ppo_epochs", "repeat_times"] - if st.session_state["mode"] == "both" - else ["total_epochs", "train_batch_size", "ppo_epochs"] - ) - self._check_train_batch_size() + self.get_configs("total_epochs", "train_batch_size", "ppo_epochs", "repeat_times") - self._set_configs_with_st_columns(["max_prompt_tokens", "max_response_tokens"]) + self.get_configs("storage_type", "max_prompt_tokens", "max_response_tokens") - self._set_configs_with_st_columns( - ["sync_interval", "eval_interval", "save_interval"] - if st.session_state["mode"] == "both" - else ["eval_interval", "save_interval"] - ) + self.get_configs("sync_interval", "eval_interval", "save_interval") if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: - self._set_taskset_args() + self.get_configs("taskset_args") else: - self._set_dpo_dataset_kwargs() + self.get_configs("dpo_dataset_kwargs") if st.session_state["sft_warmup_steps"] > 0: - self._set_sft_warmup_dataset_args() + self.get_configs("sft_warmup_dataset_args") - self._set_configs_with_st_columns(["default_workflow_type", "default_reward_fn_type"]) + self.get_configs("default_workflow_type", "default_reward_fn_type") - self._set_actor_use_kl_loss() - if st.session_state["actor_use_kl_loss"]: - self._set_configs_with_st_columns(["actor_kl_loss_coef", "actor_kl_loss_type"]) + self.get_configs("actor_use_kl_loss") + self.get_configs("actor_kl_loss_coef", "actor_kl_loss_type") - self._set_configs_with_st_columns( - [ - "actor_ppo_micro_batch_size_per_gpu", - "actor_lr", - "ref_log_prob_micro_batch_size_per_gpu", - ] + self.get_configs( + "actor_ppo_micro_batch_size_per_gpu", + "actor_lr", + "ref_log_prob_micro_batch_size_per_gpu", ) - use_critic = ( - st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value - ) # TODO: may apply to expert mode - if use_critic: - self._set_configs_with_st_columns(["critic_ppo_micro_batch_size_per_gpu", "critic_lr"]) + self.get_configs("critic_ppo_micro_batch_size_per_gpu", "critic_lr") def _expert_model_part(self): - self._set_configs_with_st_columns(["project", "exp_name"], columns_config=[1, 3]) + self.get_configs("project", "exp_name", columns_spec=[1, 2]) - self._set_model_path() - self._set_critic_model_path() + self.get_configs("model_path") + self.get_configs("critic_model_path") - self._set_checkpoint_root_dir() + self.get_configs("checkpoint_root_dir") - self._set_configs_with_st_columns(["monitor_type", "node_num", "gpu_per_node"]) - self._set_configs_with_st_columns(["max_prompt_tokens", "max_response_tokens"]) + self.get_configs("monitor_type", "node_num", "gpu_per_node") + self.get_configs("max_prompt_tokens", "max_response_tokens") def _expert_buffer_part(self): - self._set_configs_with_st_columns(["total_epochs", "train_batch_size"]) - self._check_train_batch_size() + self.get_configs("total_epochs", "train_batch_size") - self._set_configs_with_st_columns(["default_workflow_type", "default_reward_fn_type"]) - self._set_system_prompt() - self._set_reply_prefix() + self.get_configs("default_workflow_type", "default_reward_fn_type") + self.get_configs("system_prompt") + self.get_configs("reply_prefix") if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: with st.expander("Taskset Configs", expanded=True): - self._set_taskset_path() - self._set_taskset_args() + self.get_configs("taskset_path") + self.get_configs("taskset_args") else: with st.expander("DPO Dataset Configs", expanded=True): - self._set_experience_buffer_path() - self._set_dpo_dataset_kwargs() + self.get_configs("experience_buffer_path") + self.get_configs("storage_type") + self.get_configs("dpo_dataset_kwargs") with st.expander("Eval Tasksets Configs", expanded=True): - self._set_eval_tasksets() + self.get_configs("eval_tasksets") with st.expander("SFT Dataset Configs"): - self._set_sft_warmup_dataset_path() - self._set_sft_warmup_dataset_args() + self.get_configs("sft_warmup_dataset_path") + self.get_configs("sft_warmup_dataset_args") if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: with st.expander("Experiences Buffer Configs", expanded=True): - self._set_storage_type() - self._set_experience_buffer_path() + self.get_configs("storage_type") + self.get_configs("experience_buffer_path") self.buffer_advanced_tab = st.expander("Advanced Config") with self.buffer_advanced_tab: - self._set_configs_with_st_columns(["buffer_max_retry_times", "max_retry_interval"]) + self.get_configs("buffer_max_retry_times", "max_retry_interval") def _expert_explorer_part(self): - self._set_configs_with_st_columns(["sync_method", "sync_interval", "sync_timeout"]) - - self._set_configs_with_st_columns( - [ - "runner_num", - "max_timeout", - "explorer_max_retry_times", - ] - ) + self.get_configs("sync_method", "sync_interval", "sync_timeout") - self._set_configs_with_st_columns(["eval_interval", "eval_on_latest_checkpoint"]) + self.get_configs("runner_num", "max_timeout", "explorer_max_retry_times", "eval_interval") + + self.get_configs("eval_on_latest_checkpoint") with st.expander("Rollout Model Config", expanded=True): - self._set_configs_with_st_columns(["engine_type", "engine_num", "tensor_parallel_size"]) - self._check_engine_num_and_tp_size() + self.get_configs("engine_type", "engine_num", "tensor_parallel_size") - self._set_configs_with_st_columns(["gpu_memory_utilization", "dtype", "seed"]) + self.get_configs("gpu_memory_utilization", "dtype", "seed") - self._set_configs_with_st_columns( - ["use_v1", "enforce_eager", "enable_prefix_caching", "enable_chunked_prefill"] + self.get_configs( + "use_v1", "enforce_eager", "enable_prefix_caching", "enable_chunked_prefill" ) - self._set_configs_with_st_columns(["enable_thinking", "enable_openai_api"]) + self.get_configs("enable_thinking", "enable_openai_api") - with st.expander("Auxiliary Models", expanded=True): # TODO - pass + with st.expander("Auxiliary Models", expanded=True): + self.get_configs("auxiliary_models") def _expert_trainer_part(self): - self._set_configs_with_st_columns(["algorithm_type", "gamma", "lam"]) - self._set_configs_with_st_columns(["repeat_times", "save_interval"]) - self._check_sft_warmup_dataset_path() + self.get_configs("algorithm_type", "gamma", "lam") + self.get_configs("repeat_times", "save_interval") + self.get_configs("enable_preview") if st.session_state["trainer_type"] == "verl": self._expert_verl_trainer_part() - def _expert_verl_trainer_part(self): - rl_training_tab, rl_algorithm_tab, actor_ref_tab, critic_tab = st.tabs( - [ - "RL Training Config", - "RL Algorithm Config", - "Actor and Ref Config", - "Critic Config", - ] - ) - with rl_training_tab: - st.subheader("RL Training Config") - self._set_training_args() + def _expert_verl_training_part(self): + st.subheader("RL Training Config") + self.get_configs("training_args") - self._set_configs_with_st_columns(["ppo_epochs", "training_strategy", "resume_mode"]) + self.get_configs("ppo_epochs", "training_strategy", "resume_mode") - if st.session_state["training_strategy"] == "fsdp": - self._set_configs_with_st_columns(["param_offload", "optimizer_offload"]) - self._set_resume_from_path() + self.get_configs("param_offload", "optimizer_offload") + self.get_configs("resume_from_path") - with st.expander("Advanced Config"): - self._set_configs_with_st_columns(["critic_warmup", "total_training_steps"]) + with st.expander("Advanced Config"): + self.get_configs("critic_warmup", "total_training_steps") - self._set_default_hdfs_dir() + self.get_configs("default_hdfs_dir") - self._set_configs_with_st_columns( - ["remove_previous_ckpt_in_save", "del_local_ckpt_after_load"] - ) + self.get_configs("remove_previous_ckpt_in_save", "del_local_ckpt_after_load") - self._set_configs_with_st_columns( - ["max_actor_ckpt_to_keep", "max_critic_ckpt_to_keep"] - ) + self.get_configs("max_actor_ckpt_to_keep", "max_critic_ckpt_to_keep") - with rl_algorithm_tab: - st.subheader("RL Algorithm Config") - self._set_configs_with_st_columns(["norm_adv_by_std_in_grpo", "use_kl_in_reward"]) - self._set_configs_with_st_columns(["kl_penalty", "kl_ctrl_type", "kl_ctrl_coef"]) - self._set_configs_with_st_columns(["horizon", "target_kl"]) + def _expert_verl_algorithm_part(self): + st.subheader("RL Algorithm Config") + self.get_configs("norm_adv_by_std_in_grpo", "use_kl_in_reward") + self.get_configs("kl_penalty", "kl_ctrl_type", "kl_ctrl_coef") + self.get_configs("horizon", "target_kl") - with actor_ref_tab: - st.subheader("Actor Model Config") - self._set_configs_with_st_columns( - [ - "actor_ppo_micro_batch_size_per_gpu", - "ref_log_prob_micro_batch_size_per_gpu", - "actor_ulysses_sequence_parallel_size", - ] - ) + def _expert_verl_actor_part(self): + st.subheader("Actor Model Config") + self.get_configs( + "actor_ppo_micro_batch_size_per_gpu", + "ref_log_prob_micro_batch_size_per_gpu", + "actor_ulysses_sequence_parallel_size", + ) - self._set_configs_with_st_columns( - ["actor_lr", "actor_warmup_style", "actor_lr_warmup_steps_ratio"] - ) + self.get_configs("actor_lr", "actor_warmup_style", "actor_lr_warmup_steps_ratio") - self._set_configs_with_st_columns( - ["actor_grad_clip", "actor_clip_ratio", "actor_entropy_coef"] - ) + self.get_configs("actor_grad_clip", "actor_clip_ratio", "actor_entropy_coef") - self._set_actor_use_kl_loss() - if st.session_state["actor_use_kl_loss"]: - self._set_configs_with_st_columns(["actor_kl_loss_coef", "actor_kl_loss_type"]) + self.get_configs("actor_use_kl_loss", "actor_use_uid") + self.get_configs("actor_kl_loss_coef", "actor_kl_loss_type") - if st.session_state["algorithm_type"] == "opmd": - self._set_configs_with_st_columns( - ["actor_tau", "actor_opmd_baseline", "actor_use_uid"] - ) + self.get_configs("actor_tau", "actor_opmd_baseline") - self._set_actor_checkpoint() + self.get_configs("actor_checkpoint") - with critic_tab: - st.subheader("Critic Model Config") - self._set_configs_with_st_columns( - ["critic_ppo_micro_batch_size_per_gpu", "critic_ulysses_sequence_parallel_size"] - ) + def _expert_verl_critic_part(self): + st.subheader("Critic Model Config") + self.get_configs( + "critic_ppo_micro_batch_size_per_gpu", "critic_ulysses_sequence_parallel_size" + ) - self._set_configs_with_st_columns( - ["critic_lr", "critic_warmup_style", "critic_lr_warmup_steps_ratio"] - ) + self.get_configs("critic_lr", "critic_warmup_style", "critic_lr_warmup_steps_ratio") + + self.get_configs("critic_grad_clip", "critic_cliprange_value") + self.get_configs("critic_checkpoint") + + def _expert_verl_trainer_part(self): + name2func = { + "RL Training Config": self._expert_verl_training_part, + "RL Algorithm Config": self._expert_verl_algorithm_part, + "Actor and Ref Config": self._expert_verl_actor_part, + } + if use_critic(): + name2func["Critic Config"] = self._expert_verl_critic_part - self._set_configs_with_st_columns(["critic_grad_clip", "critic_cliprange_value"]) - self._set_critic_checkpoint() + tabs = st.tabs([name for name in name2func]) + for tab, func in zip(tabs, name2func.values()): + with tab: + func() def expert_mode(self): tab2func = { @@ -1455,7 +385,6 @@ def _generate_verl_config(self): }, "trainer": { "balance_batch": balance_batch, - "logger": ["tensorboard"], "resume_mode": st.session_state["resume_mode"], "resume_from_path": st.session_state["resume_from_path"], "default_hdfs_dir": st.session_state["default_hdfs_dir"], @@ -1467,7 +396,7 @@ def _generate_verl_config(self): }, } - if st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value: + if use_critic(): trainer_config["trainer"]["critic_warmup"] = st.session_state["critic_warmup"] trainer_config["critic"] = { "strategy": st.session_state["training_strategy"], @@ -1510,8 +439,8 @@ def _generate_verl_config(self): return trainer_config def _gen_buffer_config(self): + experience_buffer_path = st.session_state["experience_buffer_path"].strip() if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: - experience_buffer_path = st.session_state["experience_buffer_path"].strip() if ( not experience_buffer_path and st.session_state["storage_type"] == StorageType.SQL.value @@ -1527,7 +456,20 @@ def _gen_buffer_config(self): buffer_config = { "batch_size": st.session_state["train_batch_size"], "total_epochs": st.session_state["total_epochs"], - "explorer_input": { + "trainer_input": { + "experience_buffer": { + "name": "experience_buffer", + "storage_type": st.session_state["storage_type"], + "path": experience_buffer_path, + }, + "sft_warmup_steps": st.session_state["sft_warmup_steps"], + }, + "max_retry_times": st.session_state["buffer_max_retry_times"], + "max_retry_interval": st.session_state["max_retry_interval"], + } + + if st.session_state["mode"] != "train": + buffer_config["explorer_input"] = { "taskset": { "name": "taskset", "storage_type": StorageType.FILE.value, @@ -1548,31 +490,19 @@ def _gen_buffer_config(self): "default_reward_fn_type": st.session_state["default_reward_fn_type"], "system_prompt": st.session_state["system_prompt"], "reply_prefix": st.session_state["reply_prefix"], - }, - "trainer_input": { - "experience_buffer": { - "name": "experience_buffer", - "storage_type": st.session_state["storage_type"], - "path": experience_buffer_path, - }, - "sft_warmup_steps": st.session_state["sft_warmup_steps"], - }, - "max_retry_times": st.session_state["buffer_max_retry_times"], - "max_retry_interval": st.session_state["max_retry_interval"], - } - - for idx in range(st.session_state["_eval_tasksets_num"]): - if st.session_state[f"eval_taskset_{idx}_path"].strip(): - buffer_config["explorer_input"]["eval_tasksets"].append( - { - "name": st.session_state[f"eval_taskset_{idx}_name"], - "path": st.session_state[f"eval_taskset_{idx}_path"], - "subset_name": st.session_state[f"eval_taskset_{idx}_subset_name"], - "split": st.session_state[f"eval_taskset_{idx}_split"], - "prompt_key": st.session_state[f"eval_taskset_{idx}_prompt_key"], - "response_key": st.session_state[f"eval_taskset_{idx}_response_key"], - } - ) + } + for idx in range(st.session_state["_eval_tasksets_num"]): + if st.session_state[f"eval_taskset_{idx}_path"].strip(): + buffer_config["explorer_input"]["eval_tasksets"].append( + { + "name": st.session_state[f"eval_taskset_{idx}_name"], + "path": st.session_state[f"eval_taskset_{idx}_path"], + "subset_name": st.session_state[f"eval_taskset_{idx}_subset_name"], + "split": st.session_state[f"eval_taskset_{idx}_split"], + "prompt_key": st.session_state[f"eval_taskset_{idx}_prompt_key"], + "response_key": st.session_state[f"eval_taskset_{idx}_response_key"], + } + ) if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: experience_buffer = buffer_config["trainer_input"]["experience_buffer"] experience_buffer["split"] = st.session_state["dpo_dataset_train_split"] @@ -1676,7 +606,7 @@ def generate_config(self): "trainer": { "trainer_type": st.session_state["trainer_type"], "save_interval": st.session_state["save_interval"], - "enable_preview": True, # TODO + "enable_preview": st.session_state["enable_preview"], "actor_use_kl_loss": st.session_state["actor_use_kl_loss"], "actor_kl_loss_coef": st.session_state["actor_kl_loss_coef"], "actor_entropy_coef": st.session_state["actor_entropy_coef"], @@ -1694,7 +624,7 @@ def generate_config(self): }, } - if st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value: + if use_critic(): config["model"]["critic_model_path"] = ( st.session_state["critic_model_path"].strip() if st.session_state["critic_model_path"].strip() diff --git a/trinity/manager/config_registry/__init__.py b/trinity/manager/config_registry/__init__.py new file mode 100644 index 0000000000..e62c565fb4 --- /dev/null +++ b/trinity/manager/config_registry/__init__.py @@ -0,0 +1,13 @@ +import trinity.manager.config_registry.buffer_config_manager as buffer_config_manager +import trinity.manager.config_registry.explorer_config_manager as explorer_config_manager +import trinity.manager.config_registry.model_config_manager as model_config_manager +import trinity.manager.config_registry.trainer_config_manager as trainer_config_manager +from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS + +__all__ = [ + "CONFIG_GENERATORS", + "buffer_config_manager", + "explorer_config_manager", + "model_config_manager", + "trainer_config_manager", +] diff --git a/trinity/manager/config_registry/buffer_config_manager.py b/trinity/manager/config_registry/buffer_config_manager.py new file mode 100644 index 0000000000..044f982e94 --- /dev/null +++ b/trinity/manager/config_registry/buffer_config_manager.py @@ -0,0 +1,433 @@ +import streamlit as st + +from trinity.common.constants import AlgorithmType, PromptType, StorageType +from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS +from trinity.common.workflows.workflow import WORKFLOWS +from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS + + +@CONFIG_GENERATORS.register_config(default_value=20) +def set_total_epochs(**kwargs): + st.number_input("Total Epochs", min_value=1, **kwargs) + + +def _str_for_train_batch_size(): + trainer_gpu_num_str = ( + "`gpu_per_node * node_num - engine_num * tensor_parallel_size`" + if st.session_state["mode"] == "both" + else "`gpu_per_node * node_num`" + ) + return ( + f"Please ensure that `train_batch_size` can be divided by " + f"{trainer_gpu_num_str} = {st.session_state['trainer_gpu_num']}." + ) + + +@CONFIG_GENERATORS.register_config( + default_value=96, + visible=lambda: st.session_state["trainer_gpu_num"] > 0, + other_configs={"_train_batch_size_per_gpu": 16}, +) +def set_train_batch_size(**kwargs): + key = kwargs.get("key") + trainer_gpu_num = st.session_state["trainer_gpu_num"] + st.session_state[key] = ( + st.session_state["_train_batch_size_per_gpu"] * st.session_state["trainer_gpu_num"] + ) + + def on_change(): + st.session_state["_train_batch_size_per_gpu"] = max( + st.session_state[key] // st.session_state["trainer_gpu_num"], 1 + ) + + st.number_input( + "Train Batch Size", + min_value=trainer_gpu_num, + step=trainer_gpu_num, + help=_str_for_train_batch_size(), + on_change=on_change, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_check() +def check_train_batch_size(unfinished_fields: set, key: str): + if st.session_state[key] % st.session_state["trainer_gpu_num"] != 0: + unfinished_fields.add(key) + st.warning(_str_for_train_batch_size()) + + +@CONFIG_GENERATORS.register_config(default_value=3) +def set_buffer_max_retry_times(**kwargs): + st.number_input("Max Retry Times", min_value=1, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=1) +def set_max_retry_interval(**kwargs): + st.number_input("Max Retry Interval", min_value=1, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value="") +def set_taskset_path(**kwargs): + st.text_input("Taskset Path", **kwargs) + + +@CONFIG_GENERATORS.register_check() +def check_taskset_path(unfinished_fields: set, key: str): + if not st.session_state[key].strip(): + unfinished_fields.add(key) + st.warning("Please input taskset path.") + + +# def _set_temperature(self): +# st.number_input("Temperature", key="temperature", min_value=0.0, max_value=2.0) + +# def _set_top_p(self): +# st.number_input("Top-p", key="top_p", min_value=0.0, max_value=1.0) + +# def _set_top_k(self): +# st.number_input( +# "Top-k", +# key="top_k", +# min_value=-1, +# max_value=512, +# help="Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.", +# ) + +# def _set_logprobs(self): +# st.number_input("Logprobs", key="logprobs", min_value=0, max_value=20) + + +@CONFIG_GENERATORS.register_config( + visible=lambda: st.session_state["taskset_path"] + and "://" not in st.session_state["taskset_path"], + other_configs={ + "taskset_subset_name": None, + "taskset_split": "train", + "taskset_prompt_key": "question", + "taskset_response_key": "answer", + "temperature": 1.0, + "top_p": 1.0, # TODO: to be used + "top_k": -1, # TODO: to be used + "logprobs": 0, + }, +) +def set_taskset_args(**kwargs): + subset_name_col, split_col = st.columns(2) + subset_name_col.text_input( + "Subset Name :orange-badge[(Needs review)]", + key="taskset_subset_name", + help="The subset name used for `datasets.load_datasets`, see " + "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.", + ) + split_col.text_input("Train Split :orange-badge[(Needs review)]", key="taskset_split") + prompt_key_col, response_key_col = st.columns(2) + prompt_key_col.text_input("Prompt Key :orange-badge[(Needs review)]", key="taskset_prompt_key") + response_key_col.text_input( + "Response Key :orange-badge[(Needs review)]", key="taskset_response_key" + ) + # self._set_configs_with_st_columns(["temperature", "logprobs"]) + temperature_col, logprobs_col = st.columns(2) + temperature_col.number_input("Temperature", key="temperature", min_value=0.0, max_value=2.0) + logprobs_col.number_input("Logprobs", key="logprobs", min_value=0, max_value=20) + + +def _set_eval_taskset_idx(idx): + col1, col2 = st.columns([9, 1]) + col1.text_input( + "Taskset Name", + key=f"eval_taskset_{idx}_name", + ) + if col2.button("✖️", key=f"eval_taskset_{idx}_del_flag", type="primary"): + st.rerun() + st.text_input( + "Eval Taskset Path", + key=f"eval_taskset_{idx}_path", + ) + if not st.session_state[f"eval_taskset_{idx}_path"].strip(): + st.warning("Please input the taskset path, or it will be ignored.") + subset_name_col, split_col = st.columns(2) + subset_name_col.text_input( + "Subset Name :orange-badge[(Needs review)]", + key=f"eval_taskset_{idx}_subset_name", + help="The subset name used for `datasets.load_datasets`, see " + "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.", + ) + split_col.text_input( + "Eval Split :orange-badge[(Needs review)]", + key=f"eval_taskset_{idx}_split", + ) + prompt_key_col, response_key_col = st.columns(2) + prompt_key_col.text_input( + "Prompt Key :orange-badge[(Needs review)]", + key=f"eval_taskset_{idx}_prompt_key", + ) + response_key_col.text_input( + "Response Key :orange-badge[(Needs review)]", + key=f"eval_taskset_{idx}_response_key", + ) + + temperature_col, logprobs_col, n_col = st.columns(3) + temperature_col.number_input( + "Temperature", + key=f"eval_taskset_{idx}_temperature", + min_value=0.0, + max_value=1.0, + ) + logprobs_col.number_input( + "Logprobs", + key=f"eval_taskset_{idx}_logprobs", + min_value=0, + max_value=20, + ) + n_col.number_input( + "Eval repeat times", + key=f"eval_taskset_{idx}_n", + min_value=1, + max_value=20, + ) + + +@CONFIG_GENERATORS.register_config(other_configs={"_eval_tasksets_num": 0}) +def set_eval_tasksets(**kwargs): + if st.button("Add Eval Taskset"): + idx = st.session_state["_eval_tasksets_num"] + st.session_state[f"eval_taskset_{idx}_split"] = "test" + st.session_state[f"eval_taskset_{idx}_prompt_key"] = "prompt" + st.session_state[f"eval_taskset_{idx}_response_key"] = "response" + st.session_state[f"eval_taskset_{idx}_temperature"] = 0.1 + st.session_state["_eval_tasksets_num"] += 1 + if st.session_state["_eval_tasksets_num"] > 0: + tabs = st.tabs( + [f"Eval Taskset {i + 1}" for i in range(st.session_state["_eval_tasksets_num"])] + ) + for idx, tab in enumerate(tabs): + with tab: + _set_eval_taskset_idx(idx) + + +@CONFIG_GENERATORS.register_config(default_value="math_workflow") +def set_default_workflow_type(**kwargs): + st.selectbox( + "Default Workflow Type :orange-badge[(Needs review)]", + WORKFLOWS.modules.keys(), + help=r"""`simple_workflow`: call 'model.chat()' to get responses. + +`math_workflow`: call 'model.chat()' with a pre-defined system prompt to get responses. + +Other workflows: conduct multi-turn task for the given dataset. +""", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value="math_reward") +def set_default_reward_fn_type(**kwargs): + st.selectbox( + "Default Reward Fn Type :orange-badge[(Needs review)]", + REWARD_FUNCTIONS.modules.keys(), + help=r"""`accuracy_reward`: check the accuracy for math problems. + +`format_reward`: check if the response matches the format (default: `** *`). + +`math_reward`: `accuracy_reward` (1 or 0) + `format_reward` (+0.1 or -0.1). +""", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=None) +def set_system_prompt(**kwargs): + st.text_area( + "System Prompt", + placeholder="System prompt is used to guide the model behavior.", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=None) +def set_reply_prefix(**kwargs): + st.text_area( + "Assistant Reply Prefix", + placeholder="""Assistant reply prefix is used to specify the initial content of model reply, """ + """and a common setting is: \nLet me solve this step by step. """, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=StorageType.QUEUE.value, + other_configs={ + "_dpo_storage_type": StorageType.FILE.value, + "_not_dpo_storage_type": StorageType.QUEUE.value, + }, +) +def set_storage_type(**kwargs): + key = kwargs.get("key") + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state[key] = st.session_state["_dpo_storage_type"] + storage_candidates = [StorageType.FILE.value, StorageType.SQL.value] + else: + st.session_state[key] = st.session_state["_not_dpo_storage_type"] + storage_candidates = [StorageType.QUEUE.value, StorageType.SQL.value] + + def on_change(): + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state["_dpo_storage_type"] = st.session_state[key] + else: + st.session_state["_not_dpo_storage_type"] = st.session_state[key] + + st.selectbox( + "Storage Type", + storage_candidates, + on_change=on_change, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value="", + other_configs={ + "_dpo_experience_buffer_path": "", + "_not_dpo_experience_buffer_path": "", + }, +) +def set_experience_buffer_path(**kwargs): # TODO + key = kwargs.get("key") + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + if st.session_state["taskset_path"] and not st.session_state["_dpo_experience_buffer_path"]: + st.session_state["_dpo_experience_buffer_path"] = st.session_state["taskset_path"] + st.session_state[key] = st.session_state["_dpo_experience_buffer_path"] + title = "DPO Dataset Path" + help_msg = r"""This path to DPO dataset, + +if `storage_type == StorageType.FILE`, this should be a path to a file, + +if `storage_type == StorageType.SQL`, this should be a path to database.""" + else: + st.session_state[key] = st.session_state["_not_dpo_experience_buffer_path"] + title = "Experience Buffer Path" + help_msg = r"""This path is used for `trainer`, + +if `storage_type == StorageType.QUEUE`, default to `None`, + +if `storage_type == StorageType.SQL`, default to `sqlite:///{os.path.join(checkpoint_root_dir, '.cache', project_name, experiment_name)}/data.db`.""" + + def on_change(): + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state["_dpo_experience_buffer_path"] = st.session_state[key] + else: + st.session_state["_not_dpo_experience_buffer_path"] = st.session_state[key] + + st.text_input(title, help=help_msg, on_change=on_change, **kwargs) + + +@CONFIG_GENERATORS.register_check() +def check_experience_buffer_path(unfinished_fields: set, key: str): + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + if not st.session_state[key].strip(): + unfinished_fields.add(key) + st.warning("Please input DPO dataset path.") + + +@CONFIG_GENERATORS.register_config( + other_configs={ + "dpo_dataset_train_split": "train", + "dpo_dataset_prompt_type": PromptType.MESSAGES.value, + "dpo_dataset_prompt_key": "prompt", + "dpo_dataset_chosen_key": "chosen", + "dpo_dataset_rejected_key": "rejected", + } +) +def set_dpo_dataset_kwargs(**kwargs): + dpo_dataset_train_split_col, dpo_dataset_prompt_type_col = st.columns(2) + dpo_dataset_train_split_col.text_input( + "DPO Dataset Train Split :orange-badge[(Needs review)]", key="dpo_dataset_train_split" + ) + dpo_dataset_prompt_type_col.selectbox( + "DPO Dataset Prompt Type :orange-badge[(Needs review)]", + [prompt_type.value for prompt_type in PromptType], + key="dpo_dataset_prompt_type", + ) + + ( + dpo_dataset_prompt_key_col, + dpo_dataset_chosen_key_col, + dpo_dataset_rejected_key_col, + ) = st.columns(3) + dpo_dataset_prompt_key_col.text_input( + "DPO Dataset Prompt Key :orange-badge[(Needs review)]", key="dpo_dataset_prompt_key" + ) + dpo_dataset_chosen_key_col.text_input( + "DPO Dataset Chosen Key :orange-badge[(Needs review)]", key="dpo_dataset_chosen_key" + ) + dpo_dataset_rejected_key_col.text_input( + "DPO Dataset Rejected Key :orange-badge[(Needs review)]", + key="dpo_dataset_rejected_key", + ) + + +@CONFIG_GENERATORS.register_config(default_value="") +def set_sft_warmup_dataset_path(**kwargs): + st.text_input("SFT Warmup Dataset Path", **kwargs) + + +@CONFIG_GENERATORS.register_check() +def check_sft_warmup_dataset_path(unfinished_fields: set, key: str): + if st.session_state["sft_warmup_steps"]: + if not st.session_state[key].strip(): + unfinished_fields.add(key) + st.warning("Please input SFT warmup dataset path when `sft_warmup_steps` is not 0") + + +@CONFIG_GENERATORS.register_config( + visible=lambda: st.session_state["sft_warmup_dataset_path"] + and "://" not in st.session_state["sft_warmup_dataset_path"], + other_configs={ + "sft_warmup_train_split": "train", + "sft_warmup_prompt_type": PromptType.MESSAGES.value, + "sft_warmup_messages_key": "messages", + "sft_warmup_prompt_key": "prompt", + "sft_warmup_response_key": "response", + }, +) +def set_sft_warmup_dataset_args(**kwargs): + ( + sft_warmup_train_split_col, + sft_warmup_prompt_type_col, + ) = st.columns(2) + sft_warmup_train_split_col.text_input( + "SFT Dataset Train Split :orange-badge[(Needs review)]", + key="sft_warmup_train_split", + ) + sft_warmup_prompt_type_col.selectbox( + "SFT Dataset Prompt Type :orange-badge[(Needs review)]", + [prompt_type.value for prompt_type in PromptType], + key="sft_warmup_prompt_type", + ) + ( + sft_warmup_messages_key_col, + sft_warmup_prompt_key_col, + sft_warmup_response_key_col, + ) = st.columns( + 3 + ) # TODO: select by prompt type + sft_warmup_messages_key_col.text_input( + "SFT Dataset Messages Key :orange-badge[(Needs review)]", + key="sft_warmup_messages_key", + ) + sft_warmup_prompt_key_col.text_input( + "SFT Dataset Prompt Key :orange-badge[(Needs review)]", key="sft_warmup_prompt_key" + ) + sft_warmup_response_key_col.text_input( + "SFT Dataset Response Key :orange-badge[(Needs review)]", + key="sft_warmup_response_key", + ) + + +# TODO: read_experience_strategy + + +@CONFIG_GENERATORS.register_config(default_value=0) +def set_sft_warmup_steps(**kwargs): + st.number_input("SFT Warmup Steps", min_value=0, **kwargs) diff --git a/trinity/manager/config_registry/config_registry.py b/trinity/manager/config_registry/config_registry.py new file mode 100644 index 0000000000..3b621a2de2 --- /dev/null +++ b/trinity/manager/config_registry/config_registry.py @@ -0,0 +1,209 @@ +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Set + +import streamlit as st + +from trinity.utils.registry import Registry + + +class ConfigRegistry(Registry): + """ + A registry for managing configuration settings and their associated functions. + """ + + def __init__(self, name: str): + super().__init__(name) + self._default_config = {} # Stores default values for configs + self._config_visibles = {} # Stores visibles for config visibility + self.unfinished_fields = set() + + def set_unfinished_fields(self, unfinished_fields: set): + """ + Set the unfinished fields to track incomplete configurations. + + Args: + unfinished_fields (set): Set of field names that are not yet configured. + """ + self.unfinished_fields = unfinished_fields + + @property + def default_config(self) -> dict: + """ + Get the dictionary of default configuration values. + """ + return self._default_config + + def get(self, config_name: str): + """ + Retrieve a configuration function if its visible is met (if any). + + Args: + config_name (str): Name of the configuration to retrieve. + + Returns: + The configuration function if visibles are met, else None. + """ + if config_name in self._config_visibles: + if not self._config_visibles[config_name](): + return None + return super().get(config_name) + + def get_check_func(self, config_name: str): + """ + Get the check function associated with a configuration. + + Args: + config_name (str): Name of the configuration. + + Returns: + The check function for the specified configuration. + """ + check_func_name = f"check_{config_name}" + return super().get(check_func_name) + + def get_configs(self, *config_names: str, columns_spec: List[int] = None): + """ + Retrieve and display multiple configurations in Streamlit columns. + + Args: + *config_names (str): Names of configurations to retrieve. + columns_spec (List[int], optional): Configuration for Streamlit columns. + """ + config_pair = [] + for config_name in config_names: + config_func = self.get(config_name) + if config_func is not None: + config_pair.append((config_name, config_func)) + if len(config_pair) == 0: + return + + if columns_spec is None: + columns_spec = len(config_pair) + columns = st.columns(columns_spec) + for col, (_, config_func) in zip(columns, config_pair): + with col: + config_func() + for config_name, _ in config_pair: + check_func = self.get_check_func(config_name) + if check_func is not None: + check_func(unfinished_fields=self.unfinished_fields) + + def _register_config( + self, + config_name: str, + config_func: Callable[[None], None], + default_value: Optional[Any] = None, + visible: Optional[Callable[[], bool]] = None, + other_configs: Optional[Dict[str, Any]] = None, + ): + """ + Internal method to register a configuration and its associated function. + + Args: + config_name (str): Name of the configuration. + config_func (Callable): Function to set the configuration. + default_value (Any, optional): Default value for the configuration. + visible (Callable, optional): visible for when the config should be visible/applicable. + other_configs (Dict[str, Any], optional): Additional configurations to register. + """ + assert config_name not in self._default_config, f"{config_name} already exists." + self._default_config[config_name] = default_value + if visible is not None: + self._config_visibles[config_name] = visible + if other_configs is not None: + for name, value in other_configs.items(): + assert name not in self._default_config, f"{name} already exists." + self._default_config[name] = value + super()._register_module(module_name=config_name, module_cls=config_func) + + def register_config( + self, + default_value: Optional[Any] = None, + config_func: Optional[Callable[[None], None]] = None, + visible: Optional[Callable[[], bool]] = None, + other_configs: Optional[Dict[str, Any]] = None, + ): + """ + Decorator to register a configuration function. + + The function name must start with 'set_', and the part after 'set_' becomes the config name. + + Note: This function will automatically pass `key=config_name` as an argument to the + registered configuration function. Ensure your function accepts this keyword argument. + + Args: + default_value (Any, optional): Default value for the configuration. + config_func (Callable, optional): The configuration function to register. + visible (Callable, optional): visible for when the config should be visible. + other_configs (Dict[str, Any], optional): Additional configurations to register. + + Returns: + A decorator function if config_func is None, else the registered config function. + """ + + # if config_func is None, should return a decorator function + def _register(config_func: Callable[[None], None]): + config_name = config_func.__name__ + prefix = "set_" + assert config_name.startswith( + prefix + ), f"Config function name should start with `{prefix}`, got {config_name}" + config_name = config_name[len(prefix) :] + config_func = partial(config_func, key=config_name) + self._register_config( + config_name=config_name, + config_func=config_func, + default_value=default_value, + visible=visible, + other_configs=other_configs, + ) + return config_func + + if config_func is not None: + return _register(config_func) + return _register + + def _register_check(self, config_name: str, check_func: Callable[[Set, str], None]): + """ + Internal method to register a check function for a configuration. + + Args: + config_name (str): Name of the configuration to check. + check_func (Callable): Function to check the configuration. + """ + assert config_name in self._default_config, f"`{config_name}` is not registered." + super()._register_module(module_name=f"check_{config_name}", module_cls=check_func) + + def register_check(self, check_func: Callable[[Set, str], None] = None): + """ + Decorator to register a check function for a configuration. + + The function name must start with 'check_', and the part after 'check_' should match a config name. + + Note: This function will automatically pass `key=config_name` and `unfinished_fields=self.unfinished_fields` as an argument to the registered check function. Ensure your function accepts these keyword arguments. + + Args: + check_func (Callable, optional): The check function to register. + + Returns: + A decorator function if check_func is None, else the registered check function. + """ + + def _register(check_func: Callable[[Set, str], None]): + config_name = check_func.__name__ + prefix = "check_" + assert config_name.startswith( + prefix + ), f"Check function name must start with `{prefix}`, got {config_name}" + config_name = config_name[len(prefix) :] + check_func = partial(check_func, key=config_name) + self._register_check(config_name, check_func) + return check_func + + if check_func is not None: + return _register(check_func) + return _register + + +# Global registry for configuration generators +CONFIG_GENERATORS = ConfigRegistry("config_generators") diff --git a/trinity/manager/config_registry/explorer_config_manager.py b/trinity/manager/config_registry/explorer_config_manager.py new file mode 100644 index 0000000000..9393187f60 --- /dev/null +++ b/trinity/manager/config_registry/explorer_config_manager.py @@ -0,0 +1,298 @@ +import streamlit as st + +from trinity.common.constants import AlgorithmType, SyncMethod +from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS +from trinity.manager.config_registry.model_config_manager import set_trainer_gpu_num + + +def explorer_visible() -> bool: + return st.session_state["mode"] == "both" + + +@CONFIG_GENERATORS.register_config(default_value=32, visible=explorer_visible) +def set_runner_num(**kwargs): + st.number_input("Runner Num", min_value=1, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=900, visible=explorer_visible) +def set_max_timeout(**kwargs): + st.number_input("Max Timeout", min_value=0, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=2, visible=explorer_visible) +def set_explorer_max_retry_times(**kwargs): + st.number_input("Explorer Max Retry Times", min_value=0, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=1000, visible=explorer_visible) +def set_eval_interval(**kwargs): + st.number_input("Eval Interval", min_value=1, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=True, visible=explorer_visible) +def set_eval_on_latest_checkpoint(**kwargs): + st.checkbox("Eval on Latest Checkpoint", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value="vllm_async", visible=explorer_visible) +def set_engine_type(**kwargs): + st.selectbox("Engine Type", ["vllm_async", "vllm"], **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=2, visible=explorer_visible) +def set_engine_num(**kwargs): + key = kwargs.get("key") + total_gpu_num = st.session_state["total_gpu_num"] + max_engine_num = (total_gpu_num - 1) // st.session_state["tensor_parallel_size"] + if st.session_state[key] > max_engine_num: + st.session_state[key] = max_engine_num + set_trainer_gpu_num() + st.number_input( + "Engine Num", + min_value=1, + max_value=max_engine_num, + on_change=set_trainer_gpu_num, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=1, visible=explorer_visible) +def set_tensor_parallel_size(**kwargs): + key = kwargs.get("key") + total_gpu_num = st.session_state["total_gpu_num"] + max_tensor_parallel_size = (total_gpu_num - 1) // st.session_state["engine_num"] + if st.session_state[key] > max_tensor_parallel_size: + st.session_state[key] = max_tensor_parallel_size + set_trainer_gpu_num() + st.number_input( + "Tensor Parallel Size", + min_value=1, + max_value=max_tensor_parallel_size, + on_change=set_trainer_gpu_num, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_check() +def check_tensor_parallel_size(unfinished_fields: set, key: str): + if st.session_state["trainer_gpu_num"] <= 0: + unfinished_fields.add("engine_num") + unfinished_fields.add("tensor_parallel_size") + st.warning( + "Please check the settings of each `engine_num` and `tensor_marallel_size` to ensure that at least one GPU is reserved for the `trainer`." + ) + elif ( + st.session_state["node_num"] > 1 + and st.session_state["trainer_gpu_num"] % st.session_state["gpu_per_node"] != 0 + ): + unfinished_fields.add("engine_num") + unfinished_fields.add("tensor_parallel_size") + st.warning( + "When `node_num > 1`, please check the settings of each `engine_num` and `tensor_marallel_size` to ensure that the number of GPUs reserved for the `trainer` is divisible by `gpu_per_node`" + ) + + +@CONFIG_GENERATORS.register_config(default_value=True, visible=explorer_visible) +def set_use_v1(**kwargs): + st.checkbox("Use V1 Engine", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=True, visible=explorer_visible) +def set_enforce_eager(**kwargs): + st.checkbox("Enforce Eager", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=False, visible=explorer_visible) +def set_enable_prefix_caching(**kwargs): + st.checkbox("Prefix Caching", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=False, visible=explorer_visible) +def set_enable_chunked_prefill(**kwargs): + st.checkbox("Chunked Prefill", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=0.9, visible=explorer_visible) +def set_gpu_memory_utilization(**kwargs): + st.number_input("GPU Memory Utilization", min_value=0.0, max_value=1.0, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value="bfloat16", visible=explorer_visible) +def set_dtype(**kwargs): + st.selectbox("Dtype", ["bfloat16", "float16", "float32"], **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=42, visible=explorer_visible) +def set_seed(**kwargs): + st.number_input("Seed", step=1, **kwargs) + + +# TODO: max_prompt_tokens +# TODO: max_response_tokens +# TODO: chat_template + + +@CONFIG_GENERATORS.register_config(default_value=False, visible=explorer_visible) +def set_enable_thinking(**kwargs): + st.checkbox("Enable Thinking For Qwen3", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=False, visible=explorer_visible) +def set_enable_openai_api(**kwargs): + st.checkbox("Enable OpenAI API", **kwargs) + + +def _set_auxiliary_model_idx(idx): + col1, col2 = st.columns([9, 1]) + col1.text_input( + "Model Path", + key=f"auxiliary_model_{idx}_model_path", + ) + if col2.button("✖️", key=f"auxiliary_model_{idx}_del_flag", type="primary"): + st.rerun() + + engine_type_col, engine_num_col, tensor_parallel_size_col = st.columns(3) + total_gpu_num = st.session_state["total_gpu_num"] + engine_type_col.selectbox( + "Engine Type", ["vllm_async"], key=f"auxiliary_model_{idx}_engine_type" + ) + engine_num_col.number_input( + "Engine Num", + min_value=1, + max_value=total_gpu_num - 1, + on_change=set_trainer_gpu_num, + key=f"auxiliary_model_{idx}_engine_num", + ) + tensor_parallel_size_col.number_input( + "Tensor Parallel Size", + min_value=1, + max_value=8, + on_change=set_trainer_gpu_num, + key=f"auxiliary_model_{idx}_tensor_parallel_size", + ) + + gpu_memory_utilization_col, dtype_col, seed_col = st.columns(3) + gpu_memory_utilization_col.number_input( + "GPU Memory Utilization", + min_value=0.0, + max_value=1.0, + key=f"auxiliary_model_{idx}_gpu_memory_utilization", + ) + dtype_col.selectbox( + "Dtype", ["bfloat16", "float16", "float32"], key=f"auxiliary_model_{idx}_dtype" + ) + seed_col.number_input("Seed", step=1, key=f"auxiliary_model_{idx}_seed") + + ( + use_v1_col, + enforce_eager_col, + enable_prefix_caching_col, + enable_chunked_prefill_col, + ) = st.columns(4) + use_v1_col.checkbox("Use V1 Engine", key=f"auxiliary_model_{idx}_use_v1") + enforce_eager_col.checkbox("Enforce Eager", key=f"auxiliary_model_{idx}_enforce_eager") + enable_prefix_caching_col.checkbox( + "Prefix Caching", key=f"auxiliary_model_{idx}_enable_prefix_caching" + ) + enable_chunked_prefill_col.checkbox( + "Chunked Prefill", key=f"auxiliary_model_{idx}_enable_chunked_prefill" + ) + + enable_thinking_col, enable_openai_api = st.columns(2) + enable_thinking_col.checkbox( + "Enable Thinking For Qwen3", key=f"auxiliary_model_{idx}_enable_thinking" + ) + enable_openai_api.checkbox("Enable OpenAI API", key=f"auxiliary_model_{idx}_enable_openai_api") + + +@CONFIG_GENERATORS.register_config(other_configs={"_auxiliary_models_num": 0}) +def set_auxiliary_models(**kwargs): + if st.button("Add Auxiliary Models"): + idx = st.session_state["_auxiliary_models_num"] + st.session_state[f"auxiliary_model_{idx}_engine_num"] = 1 + st.session_state[f"auxiliary_model_{idx}_tensor_parallel_size"] = 1 + st.session_state[f"auxiliary_model_{idx}_gpu_memory_utilization"] = 0.9 + st.session_state[f"auxiliary_model_{idx}_seed"] = 42 + st.session_state[f"auxiliary_model_{idx}_use_v1"] = True + st.session_state[f"auxiliary_model_{idx}_enforce_eager"] = True + st.session_state["_auxiliary_models_num"] += 1 + set_trainer_gpu_num() + if st.session_state["_auxiliary_models_num"] > 0: + tabs = st.tabs( + [f"Auxiliary Model {i + 1}" for i in range(st.session_state["_auxiliary_models_num"])] + ) + for idx, tab in enumerate(tabs): + with tab: + _set_auxiliary_model_idx(idx) + + +@CONFIG_GENERATORS.register_check() +def check_auxiliary_models(unfinished_fields: set, key: str): + if st.session_state["trainer_gpu_num"] <= 0: + unfinished_fields.add("engine_num") + unfinished_fields.add("tensor_parallel_size") + st.warning( + "Please check the settings of each `engine_num` and `tensor_marallel_size` to ensure that at least one GPU is reserved for the `trainer`." + ) + elif ( + st.session_state["node_num"] > 1 + and st.session_state["trainer_gpu_num"] % st.session_state["gpu_per_node"] != 0 + ): + unfinished_fields.add("engine_num") + unfinished_fields.add("tensor_parallel_size") + st.warning( + "When `node_num > 1`, please check the settings of each `engine_num` and `tensor_marallel_size` to ensure that the number of GPUs reserved for the `trainer` is divisible by `gpu_per_node`" + ) + + +# Synchronizer Configs + + +@CONFIG_GENERATORS.register_config( + default_value=SyncMethod.NCCL.value, + visible=explorer_visible, + other_configs={"_not_dpo_sync_method": SyncMethod.NCCL.value}, +) +def set_sync_method(**kwargs): + key = kwargs.get("key") + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state[key] = SyncMethod.CHECKPOINT.value + disabled = True + else: + st.session_state[key] = st.session_state["_not_dpo_sync_method"] + disabled = False + + def on_change(): + if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: + st.session_state["_not_dpo_sync_method"] = st.session_state[key] + + st.selectbox( + "Sync Method", + [sync_method.value for sync_method in SyncMethod], + help="""`nccl`: the explorer and trainer sync model weights once every `sync_interval` steps. + +`checkpoint`: the trainer saves the model checkpoint, and the explorer loads it at `sync_interval`.""", + disabled=disabled, + on_change=on_change, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=10, visible=explorer_visible) +def set_sync_interval(**kwargs): + st.number_input( + "Sync Interval", + min_value=1, + help="""The step interval at which the `explorer` and `trainer` synchronize model weight.""", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=1200, visible=explorer_visible) +def set_sync_timeout(**kwargs): + st.number_input( + "Sync Timeout", + min_value=1, + help="The timeout value for the synchronization operation.", + **kwargs, + ) diff --git a/trinity/manager/config_registry/model_config_manager.py b/trinity/manager/config_registry/model_config_manager.py new file mode 100644 index 0000000000..837bf27679 --- /dev/null +++ b/trinity/manager/config_registry/model_config_manager.py @@ -0,0 +1,206 @@ +import os + +import streamlit as st + +from trinity.common.constants import AlgorithmType, MonitorType +from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS +from trinity.manager.config_registry.trainer_config_manager import use_critic +from trinity.trainer.verl.ray_trainer import AdvantageEstimator + + +def set_total_gpu_num(): + st.session_state["total_gpu_num"] = ( + st.session_state["gpu_per_node"] * st.session_state["node_num"] + ) + set_trainer_gpu_num() + + +def set_trainer_gpu_num(): + if st.session_state["mode"] == "both": + trainer_gpu_num = ( + st.session_state["total_gpu_num"] + - st.session_state["engine_num"] * st.session_state["tensor_parallel_size"] + ) + for idx in range(st.session_state["_auxiliary_models_num"]): + engine_num = st.session_state[f"auxiliary_model_{idx}_engine_num"] + tensor_parallel_size = st.session_state[f"auxiliary_model_{idx}_tensor_parallel_size"] + trainer_gpu_num -= engine_num * tensor_parallel_size + st.session_state["trainer_gpu_num"] = trainer_gpu_num + else: # model == train + st.session_state["trainer_gpu_num"] = st.session_state["total_gpu_num"] + + +@CONFIG_GENERATORS.register_config(default_value="Trinity-RFT") +def set_project(**kwargs): + st.text_input("Project", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value="qwen2.5-1.5B") +def set_exp_name(**kwargs): + st.text_input("Experiment Name", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value="") +def set_checkpoint_root_dir(**kwargs): + st.text_input("Checkpoint Root Dir", **kwargs) + + +@CONFIG_GENERATORS.register_check() +def check_checkpoint_root_dir(unfinished_fields: set, key: str): + if not st.session_state[key].strip(): # TODO: may auto generate + unfinished_fields.add(key) + st.warning("Please input checkpoint root dir.") + elif not os.path.isabs(st.session_state[key].strip()): + unfinished_fields.add("checkpoint_root_dir") + st.warning("Please input an absolute path.") + + +@CONFIG_GENERATORS.register_config(default_value=MonitorType.TENSORBOARD.value) +def set_monitor_type(**kwargs): + st.selectbox( + "Monitor Type", + options=[monitor_type.value for monitor_type in MonitorType], + **kwargs, + ) + + +# Algorithm Configs + + +@CONFIG_GENERATORS.register_config( + default_value=AlgorithmType.PPO.value, + other_configs={"mode": "both", "adv_estimator": AdvantageEstimator.GAE.value}, +) +def set_algorithm_type(**kwargs): + def on_change(): + if st.session_state["algorithm_type"] == AlgorithmType.PPO.value: + st.session_state["mode"] = "both" + st.session_state["adv_estimator"] = AdvantageEstimator.GAE.value + elif st.session_state["algorithm_type"] == AlgorithmType.GRPO.value: + st.session_state["mode"] = "both" + st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value + elif st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state["mode"] = "train" + st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value + elif st.session_state["algorithm_type"] == AlgorithmType.OPMD.value: + st.session_state["mode"] = "both" + st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value + else: # TODO: add more algorithms + pass + set_trainer_gpu_num() + + st.selectbox( + "Algorithm Type", + [ + AlgorithmType.PPO.value, + AlgorithmType.GRPO.value, + AlgorithmType.DPO.value, + AlgorithmType.OPMD.value, + ], + key="algorithm_type", + on_change=on_change, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=1, + visible=lambda: st.session_state["mode"] == "both", + other_configs={ + "_grouped_adv_repeat_times": 2, + "_not_grouped_adv_repeat_times": 1, + }, +) +def set_repeat_times(**kwargs): # TODO + key = kwargs.get("key") + grouped_adv_algorithms = [ + AlgorithmType.GRPO.value, + AlgorithmType.OPMD.value, # TODO: may add rloo + ] + if st.session_state["algorithm_type"] in grouped_adv_algorithms: + min_repeat_times = 2 + st.session_state[key] = st.session_state["_grouped_adv_repeat_times"] + else: + min_repeat_times = 1 + st.session_state[key] = st.session_state["_not_grouped_adv_repeat_times"] + + def on_change(): + if st.session_state["algorithm_type"] in grouped_adv_algorithms: + st.session_state["_grouped_adv_repeat_times"] = st.session_state[key] + else: + st.session_state["_not_grouped_adv_repeat_times"] = st.session_state[key] + + st.number_input( + "Repeat Times", + min_value=min_repeat_times, + help="`repeat_times` is used to set how many experiences each task can generate, " + "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.", + on_change=on_change, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=1.0) +def set_gamma(**kwargs): + st.number_input(r"Gamma :blue-badge[$\gamma$]", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=1.0) +def set_lam(**kwargs): + st.number_input(r"Lambda :blue-badge[$\lambda$]", **kwargs) + + +# Model Configs + + +@CONFIG_GENERATORS.register_config(default_value="") +def set_model_path(**kwargs): + st.text_input("Model Path", **kwargs) + + +@CONFIG_GENERATORS.register_check() +def check_model_path(unfinished_fields: set, key: str): + if not st.session_state[key].strip(): + unfinished_fields.add(key) + st.warning("Please input model path.") + + +@CONFIG_GENERATORS.register_config( + default_value="", + visible=use_critic, +) +def set_critic_model_path(**kwargs): + st.text_input( + "Critic Model Path (defaults to `model_path`)", + key="critic_model_path", + ) + + +@CONFIG_GENERATORS.register_config(default_value=1024) +def set_max_prompt_tokens(**kwargs): + st.number_input("Max Prompt Tokens", min_value=1, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=1024) +def set_max_response_tokens(**kwargs): + st.number_input("Max Response Tokens", min_value=1, **kwargs) + + +# Cluster Config + + +@CONFIG_GENERATORS.register_config(default_value=1) +def set_node_num(**kwargs): + st.number_input("Node Num", min_value=1, on_change=set_total_gpu_num, **kwargs) + + +@CONFIG_GENERATORS.register_config( + default_value=8, other_configs={"total_gpu_num": 8, "trainer_gpu_num": 6} +) +def set_gpu_per_node(**kwargs): + st.number_input( + "GPU Per Node", + min_value=1, + max_value=8, + on_change=set_total_gpu_num, + **kwargs, + ) diff --git a/trinity/manager/config_registry/trainer_config_manager.py b/trinity/manager/config_registry/trainer_config_manager.py new file mode 100644 index 0000000000..d0f5d26897 --- /dev/null +++ b/trinity/manager/config_registry/trainer_config_manager.py @@ -0,0 +1,450 @@ +import streamlit as st + +from trinity.common.constants import AlgorithmType, SyncMethod +from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS +from trinity.trainer.verl.ray_trainer import AdvantageEstimator + + +def use_critic(): + return st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value + + +@CONFIG_GENERATORS.register_config(default_value="verl") +def set_trainer_type(**kwargs): + st.selectbox("Trainer Type", ["verl"], **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=100, other_configs={"_nccl_save_interval": 100}) +def set_save_interval(**kwargs): + key = kwargs.get("key") + if ( + st.session_state["algorithm_type"] == AlgorithmType.DPO.value + or st.session_state["sync_method"] == SyncMethod.NCCL.value + ): + st.session_state[key] = st.session_state["_nccl_save_interval"] + freeze_save_interval = False + else: + st.session_state[key] = st.session_state["sync_interval"] + freeze_save_interval = True + + def on_change(): + if ( + st.session_state["algorithm_type"] == AlgorithmType.DPO.value + or st.session_state["sync_method"] == SyncMethod.NCCL.value + ): + st.session_state["_nccl_save_interval"] = st.session_state[key] + + st.number_input( + "Save Interval", + min_value=1, + help="Set to `sync_interval` when `algorithm_type != DPO && sync_method == checkpoint`", + disabled=freeze_save_interval, + on_change=on_change, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=True) +def set_enable_preview(**kwargs): + st.checkbox("Enable Preview", **kwargs) + + +def _actor_use_kl_loss_visible(): + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state["actor_use_kl_loss"] = True + return False + return True + + +@CONFIG_GENERATORS.register_config( + default_value=True, + visible=_actor_use_kl_loss_visible, + other_configs={"_not_dpo_actor_use_kl_loss": True}, +) +def set_actor_use_kl_loss(**kwargs): + key = kwargs.get("key") + st.session_state[key] = st.session_state["_not_dpo_actor_use_kl_loss"] + + def on_change(): + st.session_state["_not_dpo_actor_use_kl_loss"] = st.session_state[key] + + st.checkbox("Use KL Loss", on_change=on_change, **kwargs) + + +@CONFIG_GENERATORS.register_config( + default_value=0.001, visible=lambda: st.session_state["actor_use_kl_loss"] +) +def set_actor_kl_loss_coef(**kwargs): + st.number_input( + r"KL Loss Coef :blue-badge[$\beta$]", + min_value=0.0, + max_value=1.0, + format="%.1e", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=0.001, visible=lambda: st.session_state["actor_use_kl_loss"] +) +def set_actor_entropy_coef(**kwargs): + st.number_input( + "Entropy Coeff", + min_value=0.0, + max_value=1.0, + format="%.1e", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=1.0) +def set_actor_grad_clip(**kwargs): + st.number_input( + "Grad Clip :blue-badge[(Actor)]", + min_value=0.0, + max_value=1.0, + help="Clipping by Norm", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=0.2) +def set_actor_clip_ratio(**kwargs): + st.number_input( + r"Clip Ratio :blue-badge[$\epsilon$]", + min_value=0.0, + max_value=1.0, + **kwargs, + ) + + +# veRL Trainer Configs + + +@CONFIG_GENERATORS.register_config( + default_value=[ + "balance_batch", + "gradient_checkpointing", + "remove_padding", + "dynamic_bsz", + ] +) +def set_training_args(**kwargs): + st.multiselect( + "Training Args", + [ + "balance_batch", + "gradient_checkpointing", + "remove_padding", + "dynamic_bsz", + ], + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=1) +def set_ppo_epochs(**kwargs): + st.number_input("PPO Epochs", min_value=1, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value="fsdp") +def set_training_strategy(**kwargs): + st.selectbox( + "Training Strategy", + ["fsdp", "megatron"], + help="megatron is not tested", + **kwargs, + ) + + +def use_fsdp(): + return st.session_state["training_strategy"] == "fsdp" + + +@CONFIG_GENERATORS.register_config(default_value=False, visible=use_fsdp) +def set_param_offload(**kwargs): + st.checkbox("FSDP Param Offload", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=False, visible=use_fsdp) +def set_optimizer_offload(**kwargs): + st.checkbox("FSDP Optimizer Offload", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value="auto") +def set_resume_mode(**kwargs): + st.selectbox("Resume Mode", ["disable", "auto", "resume_path"], **kwargs) + + +@CONFIG_GENERATORS.register_config( + default_value="", visible=lambda: st.session_state["resume_mode"] == "resume_path" +) +def set_resume_from_path(**kwargs): + st.text_input("Resume Path", **kwargs) + + +@CONFIG_GENERATORS.register_check() +def check_resume_from_path(unfinished_fields: set, key: str): + if st.session_state["resume_mode"] == "resume_path" and ( + not st.session_state[key].strip() or "global_step_" not in st.session_state[key] + ): + unfinished_fields.add(key) + st.warning("Please input a valid resume path when `resume_mode == resume_path`") + + +@CONFIG_GENERATORS.register_config(default_value=0) +def set_critic_warmup(**kwargs): + st.number_input("Critic Warmup Steps", min_value=0, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=None) +def set_total_training_steps(**kwargs): + st.number_input("Total Training Steps", min_value=1, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=None) +def set_default_hdfs_dir(**kwargs): + st.text_input("Default HDFS Dir", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=False) +def set_remove_previous_ckpt_in_save(**kwargs): + st.checkbox("Remove Previous Checkpoint in Save", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=False) +def set_del_local_ckpt_after_load(**kwargs): + st.checkbox("Delete Local Checkpoint After Load", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=None) +def set_max_actor_ckpt_to_keep(**kwargs): + st.number_input("Max Actor Checkpoint to Keep", min_value=1, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=None) +def set_max_critic_ckpt_to_keep(**kwargs): + st.number_input("Max Critic Checkpoint to Keep", min_value=1, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=True) +def set_norm_adv_by_std_in_grpo(**kwargs): + st.checkbox("Norm Adv by Std in GRPO", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=False) +def set_use_kl_in_reward(**kwargs): + st.checkbox("Use KL in Reward", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value="low_var_kl") +def set_kl_penalty(**kwargs): + st.selectbox("KL Penalty", ["kl", "abs", "mse", "low_var_kl"], **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value="fixed") +def set_kl_ctrl_type(**kwargs): + st.selectbox("KL Ctrl Type", ["fixed", "adaptive"], **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=0.001) +def set_kl_ctrl_coef(**kwargs): + st.number_input("KL Ctrl Coef", format="%.1e", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=10000) +def set_horizon(**kwargs): + st.number_input("Horizon", min_value=1.0, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=0.1) +def set_target_kl(**kwargs): + st.number_input("Target KL", format="%.1e", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=4) +def set_actor_ppo_micro_batch_size_per_gpu(**kwargs): + key = kwargs.get("key") + max_value = st.session_state["_train_batch_size_per_gpu"] + st.session_state[key] = min(st.session_state[key], max_value) + st.number_input( + "Micro Batch Size Per GPU :blue-badge[(Actor)]", min_value=1, max_value=max_value, **kwargs + ) + + +@CONFIG_GENERATORS.register_config(default_value=8) +def set_ref_log_prob_micro_batch_size_per_gpu(**kwargs): + key = kwargs.get("key") + max_value = st.session_state["_train_batch_size_per_gpu"] + st.session_state[key] = min(st.session_state[key], max_value) + st.number_input( + "Micro Batch Size Per GPU :blue-badge[(Ref)]", min_value=1, max_value=max_value, **kwargs + ) + + +@CONFIG_GENERATORS.register_config(default_value=1) +def set_actor_ulysses_sequence_parallel_size(**kwargs): + st.number_input( + "Ulysses Sequence Parallel Size", + min_value=1, + max_value=8, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=1e-6) +def set_actor_lr(**kwargs): + st.number_input( + "Learning Rate :blue-badge[(Actor)]", + min_value=1e-7, + max_value=1e-3, + format="%.1e", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value="constant") +def set_actor_warmup_style(**kwargs): + st.selectbox( + "LR Warmup Style :blue-badge[(Actor)]", + ["constant", "cosine"], + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=0.0) +def set_actor_lr_warmup_steps_ratio(**kwargs): + st.number_input( + "LR Warmup Steps Ratio :blue-badge[(Actor)]", + min_value=0.0, + max_value=1.0, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=0.0, visible=lambda: st.session_state["algorithm_type"] == "opmd" +) +def set_actor_tau(**kwargs): + st.number_input("Tau for OPMD", min_value=0.0, format="%.1e", **kwargs) + + +@CONFIG_GENERATORS.register_config( + default_value="mean", visible=lambda: st.session_state["algorithm_type"] == "opmd" +) +def set_actor_opmd_baseline(**kwargs): + st.selectbox( + "OPMD Baseline", + ["mean", "logavgexp"], + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=False, visible=lambda: st.session_state["algorithm_type"] == "opmd" +) +def set_actor_use_uid(**kwargs): + st.checkbox("Use UID for OPMD", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value="low_var_kl") +def set_actor_kl_loss_type(**kwargs): + st.selectbox( + "KL Loss Type", + ["kl", "abs", "mse", "low_var_kl"], + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=["model", "hf_model", "optimizer", "extra"]) +def set_actor_checkpoint(**kwargs): + st.multiselect( + "Checkpoint", + ["model", "hf_model", "optimizer", "extra"], + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=1e-6, visible=use_critic) +def set_critic_lr(**kwargs): + st.number_input( + "Learning Rate :blue-badge[(Critic)]", + min_value=1e-7, + max_value=1e-3, + format="%.1e", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value="constant", visible=use_critic) +def set_critic_warmup_style(**kwargs): + st.selectbox( + "LR Warmup Style :blue-badge[(Critic)]", + ["constant", "cosine"], + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=0.0, visible=use_critic) +def set_critic_lr_warmup_steps_ratio(**kwargs): + st.number_input( + "LR Warmup Steps Ratio :blue-badge[(Critic)]", + min_value=0.0, + max_value=1.0, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=1.0, visible=use_critic) +def set_critic_grad_clip(**kwargs): + st.number_input( + "Grad Clip :blue-badge[(Critic)]", + min_value=0.0, + max_value=1.0, + help="Clipping by Norm", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=0.5, visible=use_critic) +def set_critic_cliprange_value(**kwargs): + st.number_input( + "Cliprange Value", + min_value=0.0, + max_value=1.0, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=8, visible=use_critic) +def set_critic_ppo_micro_batch_size_per_gpu(**kwargs): + key = kwargs.get("key") + max_value = st.session_state["_train_batch_size_per_gpu"] + st.session_state[key] = min(st.session_state[key], max_value) + st.number_input( + "Micro Batch Size Per GPU :blue-badge[(Critic)]", + min_value=1, + max_value=max_value, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=1, visible=use_critic) +def set_critic_ulysses_sequence_parallel_size(**kwargs): + st.number_input( + "Ulysses Sequence Parallel Size", + min_value=1, + max_value=8, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=["model", "optimizer", "extra"], visible=use_critic +) +def set_critic_checkpoint(**kwargs): + st.multiselect( + "Checkpoint", + ["model", "hf_model", "optimizer", "extra"], + **kwargs, + )