diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 09377e1f66..8cb8856fbc 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -223,6 +223,8 @@ The configuration for each task dataset is defined as follows: - `temperature`: The temperature for sampling. - `default_workflow_type`: Type of workflow logic applied to this dataset. If not specified, the `buffer.default_workflow_type` is used. - `default_reward_fn_type`: Reward function used during exploration. If not specified, the `buffer.default_reward_fn_type` is used. +- `workflow_args`: A dictionary of arguments used to supplement dataset-level parameters. + ### Trainer Input diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md index 2e4daeab0b..1b8e3fc56b 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md +++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md @@ -45,6 +45,15 @@ To handle differences in `Task` contents, Trinity-RFT provides a unified `Task` - **`raw_task`** (`Dict`): An record of raw data in `Dict` format. For highly customized workflow, you can directly use `raw_task` to initialize your `Workflow` instance without relying on the following fields. - **`format_args`** ({class}`trinity.common.config.FormatConfig`): Parameters to facilitate the construction of `Workflow` instances. For example, the `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. These settings come from the YAML configuration file and can be set in `buffer.explorer_input.task_set.format`. - **`rollout_args`** ({class}`trinity.common.config.GenerationConfig`): Parameters that control the rollout process, such as `temperature`. 
This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.rollout_args`. + - **`workflow_args`** (`Dict`): A dictionary of parameters to facilitate the construction of `Workflow` instances. Provides more flexibility than `format_args` and `rollout_args` by using a dictionary. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.workflow_args`. Normally, you do not need to set this field. + +```{tip} +`workflow`, `workflow_args` and `raw_task` provide different levels of customization. + +- `workflow` provides the global settings for all tasks that use the same workflow. (Global Level) +- `workflow_args` can be set for each task dataset, allowing different task datasets using the same workflow to behave differently. (Dataset Level) +- `raw_task` provides the ability to customize the behavior of each task, which is the most flexible. (Data Sample Level) +``` In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line contains JSON with `question` and `answer` fields representing the problem description and standard answer, respectively. For example: @@ -111,7 +120,7 @@ During initialization, `Workflow` receives the following parameters: You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow. ``` -Here’s an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization. +Here's an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization. 
```python class ExampleWorkflow(Workflow): @@ -188,6 +197,25 @@ class ExampleWorkflow(Workflow): pass ``` +For workflows that you intend to contribute to the Trinity-RFT project, place the above code in the `trinity/common/workflows` folder, e.g., `trinity/common/workflows/example_workflow.py`, and add the following lines to `trinity/common/workflows/__init__.py`: + +```python +# existing import lines +from .example_workflow import ExampleWorkflow + +__all__ = [ + # existing __all__ lines + "ExampleWorkflow", +] +``` + +For workflows that are not intended to be contributed to the Trinity-RFT project, you can simply place the above code in `trinity/plugins`. Trinity-RFT will automatically detect and load all custom modules in this folder. + +```{tip} +You can specify the directory where your custom modules are located by setting `--plugin-dir` when starting Trinity-RFT. If you don't specify `--plugin-dir`, Trinity-RFT will use `/trinity/plugins` as the default directory. +``` + + #### Avoid Re-initialization For heavy workflows, re-initializing every time can incurs extra computational costs. @@ -286,6 +314,126 @@ trinity run --config --- +## Adding New Config Entries for the Config Generator (Advanced) + +### Step 0: Understanding Streamlit + +Before adding new parameters to the Config Generator page, it is essential to familiarize yourself with the relevant API and mechanisms of [Streamlit](https://docs.streamlit.io/develop/api-reference). This project primarily utilizes various input components from Streamlit and employs `st.session_state` to store user-input parameters. + +### Step 1: Implement New Config Entries + +To illustrate the process of creating a new parameter setting for the Config Generator page, we will use `train_batch_size` as an example. + +1. Determine the appropriate scope for the parameter. 
Currently, parameters are categorized into four files: + - `trinity/manager/config_registry/buffer_config_manager.py` + - `trinity/manager/config_registry/explorer_config_manager.py` + - `trinity/manager/config_registry/model_config_manager.py` + - `trinity/manager/config_registry/trainer_config_manager.py` + + In this case, `train_batch_size` should be placed in the `buffer_config_manager.py` file. + +2. Create a parameter setting function using Streamlit. The function name must follow the convention of starting with 'set_', and the remainder of the name becomes the config name. + +3. Decorate the parameter setting function with the `CONFIG_GENERATORS.register_config` decorator. This decorator requires the following information: + - Default value of the parameter + - Visibility condition (if applicable) + - Additional config parameters (if needed) + +```{note} +The `CONFIG_GENERATORS.register_config` decorator automatically passes `key=config_name` as an argument to the registered configuration function. Ensure that your function accepts this keyword argument. 
+``` + +For `train_batch_size`, we will use the following settings: +- Default value: 96 +- Visibility condition: `lambda: st.session_state["trainer_gpu_num"] > 0` +- Additional config: `{"_train_batch_size_per_gpu": 16}` + + +Here's the complete code for the `train_batch_size` parameter: + +```python +@CONFIG_GENERATORS.register_config( + default_value=96, + visible=lambda: st.session_state["trainer_gpu_num"] > 0, + other_configs={"_train_batch_size_per_gpu": 16}, +) +def set_train_batch_size(**kwargs): + key = kwargs.get("key") + trainer_gpu_num = st.session_state["trainer_gpu_num"] + st.session_state[key] = ( + st.session_state["_train_batch_size_per_gpu"] * st.session_state["trainer_gpu_num"] + ) + + def on_change(): + st.session_state["_train_batch_size_per_gpu"] = max( + st.session_state[key] // st.session_state["trainer_gpu_num"], 1 + ) + + st.number_input( + "Train Batch Size", + min_value=trainer_gpu_num, + step=trainer_gpu_num, + help=_str_for_train_batch_size(), + on_change=on_change, + **kwargs, + ) +``` + +If the parameter requires validation, create a check function. For `train_batch_size`, we need to ensure it is divisible by `trainer_gpu_num`. If not, a warning should be displayed, and the parameter should be added to `unfinished_fields`. + +Decorate the check function with the `CONFIG_GENERATORS.register_check` decorator: + +```python +@CONFIG_GENERATORS.register_check() +def check_train_batch_size(unfinished_fields: set, key: str): + if st.session_state[key] % st.session_state["trainer_gpu_num"] != 0: + unfinished_fields.add(key) + st.warning(_str_for_train_batch_size()) +``` + +```{note} +The `CONFIG_GENERATORS.register_check` decorator automatically passes `key=config_name` and `unfinished_fields=self.unfinished_fields` to the registered check function. Ensure your function accepts these keyword arguments. 
+``` + +### Step 2: Integrating New Parameters into `config_manager.py` + +To successfully integrate new parameters into the `config_manager.py` file, please adhere to the following procedure: + +1. Parameter Categorization: + Determine the appropriate section for the new parameter based on its functionality. The config generator page is structured into two primary modes: + - Beginner Mode: Comprises "Essential Configs" and "Important Configs" sections. + - Expert Mode: Includes "Model", "Buffer", "Explorer and Synchronizer", and "Trainer" sections. + +2. Parameter Addition: + Incorporate the new parameter into the relevant section using the `self.get_configs` method within the `ConfigManager` class. + + Example: + ```python + class ConfigManager: + def _expert_buffer_part(self): + self.get_configs("total_epochs", "train_batch_size") + ``` + +3. YAML File Integration: + Locate the appropriate position for the new parameter within the YAML file structure. This should be done in the `generate_config` function and its associated sub-functions. + +4. Parameter Value Assignment: + Utilize `st.session_state` to retrieve the parameter value from the config generator page and assign it to the corresponding field in the YAML. + + Example: + ```python + class ConfigManager: + def _gen_buffer_config(self): + buffer_config = { + "batch_size": st.session_state["train_batch_size"], + # Additional configuration parameters + } + ``` + +By meticulously following these steps, you can ensure that new parameters are successfully added to the Config Generator page and properly integrated into the configuration system. This process maintains the integrity and functionality of the configuration management framework. + +--- + ## Check Code Style Before submitting the code, make sure it passes the code style check. 
Follow these steps: diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index 5620c38f8e..751f1a0c30 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -1,6 +1,7 @@ import os import unittest +import ray import torch from trinity.buffer.reader.sql_reader import SQLReader @@ -22,6 +23,7 @@ def test_create_sql_buffer(self) -> None: algorithm_type="ppo", path=f"sqlite:///{db_path}", storage_type=StorageType.SQL, + wrap_in_ray=True, ) config = BufferConfig( max_retry_times=3, @@ -45,3 +47,5 @@ def test_create_sql_buffer(self) -> None: for _ in range(total_num // read_batch_size): exps = sql_reader.read() self.assertEqual(len(exps), read_batch_size) + db_wrapper = ray.get_actor("sql-test_buffer") + self.assertIsNotNone(db_wrapper) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index 74b5d400e5..b0f354c4fa 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -12,7 +12,6 @@ get_unittest_dataset_config, ) from trinity.cli.launcher import explore -from trinity.common.constants import MonitorType class BaseExplorerCase(RayUnittestBase): @@ -23,7 +22,7 @@ def setUp(self): self.config.model.model_path = get_model_path() self.config.explorer.rollout_model.engine_type = "vllm_async" self.config.algorithm.repeat_times = 2 - self.config.monitor.monitor_type = MonitorType.TENSORBOARD + self.config.monitor.monitor_type = "tensorboard" self.config.project = "Trinity-unittest" self.config.checkpoint_root_dir = get_checkpoint_path() self.config.synchronizer.sync_interval = 2 diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py index 8a6e262a90..4c0e0349f5 100644 --- a/tests/explorer/runner_pool_test.py +++ b/tests/explorer/runner_pool_test.py @@ -2,7 +2,7 @@ import os import time import unittest -from typing import List +from typing import List, Tuple import ray import torch @@ -87,8 +87,8 @@ def init_process_group( def has_api_server(self) -> bool: return 
True - def api_server_ready(self) -> str: - return "http://localhosts:12345" + def api_server_ready(self) -> Tuple[str, str]: + return "http://localhosts:12345", "placeholder" class RunnerPoolTest(unittest.TestCase): diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index 8cce2f9e85..0812fb5e6e 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock from tests.tools import get_unittest_dataset_config -from trinity.common.workflows import MathWorkflow +from trinity.common.workflows import MathWorkflow, Workflow from trinity.common.workflows.workflow import Task @@ -15,6 +15,33 @@ class MockResponse: reward: float = 0.0 +class DummyWorkflow(Workflow): + def __init__(self, model, task: Task, auxiliary_models=None): + super().__init__(model, task, auxiliary_models) + self.obj = task.raw_task + self.output_format = task.workflow_args["output_format"] + + @property + def resettable(self): + return True + + def reset(self, task: Task): + self.obj = task.raw_task + self.output_format = task.workflow_args["output_format"] + + def run(self): + if self.output_format == "json": + import json + + return [json.dumps(self.obj)] + elif self.output_format == "yaml": + import yaml + + return [yaml.safe_dump(self.obj)] + else: + raise ValueError("Invalid output format") + + class WorkflowTest(unittest.TestCase): def test_math_workflow(self) -> None: model = MagicMock() @@ -150,3 +177,18 @@ def test_gsm8k_workflow(self) -> None: self.assertEqual(experiences[1].reward, -0.1) self.assertEqual(experiences[2].reward, -0.1) self.assertEqual(experiences[3].reward, 1.1) + + def test_workflow_resettable(self) -> None: + model = MagicMock() + json_task = Task( + workflow=DummyWorkflow, raw_task={"a": 1}, workflow_args={"output_format": "json"} + ) + yaml_task = Task( + workflow=DummyWorkflow, raw_task={"a": 1}, workflow_args={"output_format": "yaml"} + ) + workflow = 
json_task.to_workflow(model) + answer = workflow.run() + self.assertEqual(answer[0], '{"a": 1}') + workflow.reset(yaml_task) + answer = workflow.run() + self.assertEqual(answer[0], "a: 1\n") diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 5b2795d952..bf064785cd 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -15,7 +15,7 @@ get_unittest_dataset_config, ) from trinity.cli.launcher import bench, both, train -from trinity.common.constants import MonitorType, SyncMethod +from trinity.common.constants import SyncMethod class BaseTrainerCase(RayUnittestBase): @@ -30,7 +30,7 @@ def setUp(self): self.config.explorer.rollout_model.use_v1 = False self.config.project = "Trainer-unittest" self.config.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}" - self.config.monitor.monitor_type = MonitorType.TENSORBOARD + self.config.monitor.monitor_type = "tensorboard" self.config.checkpoint_root_dir = get_checkpoint_path() self.config.synchronizer.sync_interval = 2 self.config.synchronizer.sync_method = SyncMethod.NCCL diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/utils/plugin_test.py b/tests/utils/plugin_test.py new file mode 100644 index 0000000000..01aa2f3967 --- /dev/null +++ b/tests/utils/plugin_test.py @@ -0,0 +1,34 @@ +import unittest +from pathlib import Path + +import ray + +from trinity.common.workflows import WORKFLOWS +from trinity.utils.plugin_loader import load_plugins + + +@ray.remote +class PluginActor: + def run(self): + my_plugin_cls = WORKFLOWS.get("my_workflow") + return my_plugin_cls(None, None).run() + + +class TestPluginLoader(unittest.TestCase): + def test_load_plugins(self): + ray.init(ignore_reinit_error=True) + my_plugin_cls = WORKFLOWS.get("my_workflow") + self.assertIsNone(my_plugin_cls) + load_plugins(Path(__file__).resolve().parent / "plugins") + my_plugin_cls = WORKFLOWS.get("my_workflow") + 
self.assertIsNotNone(my_plugin_cls) + my_plugin = my_plugin_cls(None, None, None) + self.assertTrue(my_plugin.__module__.startswith("trinity.plugins")) + res = my_plugin.run() + self.assertEqual(res[0], "Hello world") + self.assertEqual(res[1], "Hi") + remote_plugin = PluginActor.remote() + remote_res = ray.get(remote_plugin.run.remote()) + self.assertEqual(remote_res[0], "Hello world") + self.assertEqual(remote_res[1], "Hi") + ray.shutdown(_exiting_interpreter=True) diff --git a/tests/utils/plugins/__init__.py b/tests/utils/plugins/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/utils/plugins/my_workflow.py b/tests/utils/plugins/my_workflow.py new file mode 100644 index 0000000000..b999590a01 --- /dev/null +++ b/tests/utils/plugins/my_workflow.py @@ -0,0 +1,12 @@ +from typing import List + +from trinity.common.workflows import WORKFLOWS, Workflow + + +@WORKFLOWS.register_module("my_workflow") +class MyWorkflow(Workflow): + def __init__(self, model, task, auxiliary_models=None): + super().__init__(model, task, auxiliary_models) + + def run(self) -> List: + return ["Hello world", "Hi"] diff --git a/trinity/buffer/buffer.py b/trinity/buffer/buffer.py index 9d77dbb379..90f658f07c 100644 --- a/trinity/buffer/buffer.py +++ b/trinity/buffer/buffer.py @@ -46,7 +46,7 @@ def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig file_read_type = algorithm_type else: file_read_type = "rollout" - return FILE_READERS.get(file_read_type)(storage_config, buffer_config) + return FILE_READERS.get(file_read_type)(storage_config, buffer_config) # type: ignore else: raise ValueError(f"{storage_config.storage_type} not supported.") diff --git a/trinity/buffer/db_wrapper.py b/trinity/buffer/db_wrapper.py new file mode 100644 index 0000000000..977aaae493 --- /dev/null +++ b/trinity/buffer/db_wrapper.py @@ -0,0 +1,105 @@ +import time +from typing import List, Optional + +import ray +from sqlalchemy import asc, create_engine, desc 
+from sqlalchemy.exc import OperationalError +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import NullPool + +from trinity.buffer.schema import Base, create_dynamic_table +from trinity.buffer.utils import retry_session +from trinity.common.config import BufferConfig, StorageConfig +from trinity.common.constants import ReadStrategy +from trinity.utils.log import get_logger + + +class DBWrapper: + """ + A wrapper of a SQL database. + + If `wrap_in_ray` in `StorageConfig` is `True`, this class will be run as a Ray Actor, + and provide a remote interface to the local database. + + For databases that do not support multi-processing read/write (e.g. sqlite, duckdb), we + recommend setting `wrap_in_ray` to `True` + """ + + def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: + self.logger = get_logger(__name__) + self.engine = create_engine(storage_config.path, poolclass=NullPool) + self.table_model_cls = create_dynamic_table( + storage_config.algorithm_type, storage_config.name + ) + + try: + Base.metadata.create_all(self.engine, checkfirst=True) + except OperationalError: + self.logger.warning("Failed to create database, assuming it already exists.") + + self.session = sessionmaker(bind=self.engine) + self.batch_size = config.read_batch_size + self.max_retry_times = config.max_retry_times + self.max_retry_interval = config.max_retry_interval + + @classmethod + def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig): + if storage_config.wrap_in_ray: + return ( + ray.remote(cls) + .options( + name=f"sql-{storage_config.name}", + get_if_exists=True, + ) + .remote(storage_config, config) + ) + else: + return cls(storage_config, config) + + def write(self, data: list) -> None: + with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session: + experience_models = [self.table_model_cls.from_experience(exp) for exp in data] + session.add_all(experience_models) + + def read(self, 
strategy: Optional[ReadStrategy] = None) -> List: + if strategy is None: + strategy = ReadStrategy.LFU + + if strategy == ReadStrategy.LFU: + sortOrder = (asc(self.table_model_cls.consumed), desc(self.table_model_cls.id)) + + elif strategy == ReadStrategy.LRU: + sortOrder = (desc(self.table_model_cls.id),) + + elif strategy == ReadStrategy.PRIORITY: + sortOrder = (desc(self.table_model_cls.priority), desc(self.table_model_cls.id)) + + else: + raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage") + + exp_list = [] + while len(exp_list) < self.batch_size: + if len(exp_list): + self.logger.info("waiting for experiences...") + time.sleep(1) + with retry_session( + self.session, self.max_retry_times, self.max_retry_interval + ) as session: + # get a batch of experiences from the database + experiences = ( + session.query(self.table_model_cls) + .filter(self.table_model_cls.reward.isnot(None)) + .order_by(*sortOrder) # TODO: very slow + .limit(self.batch_size - len(exp_list)) + .with_for_update() + .all() + ) + # update the consumed field + for exp in experiences: + exp.consumed += 1 + exp_list.extend([self.table_model_cls.to_experience(exp) for exp in experiences]) + self.logger.info(f"get {len(exp_list)} experiences:") + self.logger.info(f"reward = {[exp.reward for exp in exp_list]}") + self.logger.info(f"first prompt_text = {exp_list[0].prompt_text}") + self.logger.info(f"first response_text = {exp_list[0].response_text}") + return exp_list diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py index a360182f07..8490c44506 100644 --- a/trinity/buffer/queue.py +++ b/trinity/buffer/queue.py @@ -23,6 +23,7 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: if storage_config.path is not None and len(storage_config.path) > 0: sql_config = deepcopy(storage_config) sql_config.storage_type = StorageType.SQL + sql_config.wrap_in_ray = False self.sql_writer = SQLWriter(sql_config, self.config) else: self.sql_writer 
= None diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 58b762d3f2..8ba0ccea31 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -196,8 +196,8 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): self.reward_fn_key = meta.format.reward_fn_key self.task_type = meta.task_type - self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type) - self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type) + self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type) # type: ignore + self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type) # type: ignore self.total_epochs = meta.total_epochs if self.task_type == TaskType.EXPLORE else 1 def __len__(self): @@ -217,11 +217,12 @@ def read(self, strategy: Optional[ReadStrategy] = None): if self.reward_fn_key in sample else self.default_reward_fn_cls ) - assert workflow_class is not None, "`default_reward_fn_type` or `workflow_key` is required" + assert workflow_class is not None, "`default_workflow_type` or `workflow_key` is required" task = Task( workflow=workflow_class, format_args=self.meta.format, rollout_args=self.meta.rollout_args, + workflow_args=self.meta.workflow_args, is_eval=self.meta.task_type == TaskType.EVAL, reward_fn=reward_fn, raw_task=sample, diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index ffd013d4ef..3b26014fc4 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -16,13 +16,13 @@ class QueueReader(BufferReader): """Reader of the Queue buffer.""" - def __init__(self, meta: StorageConfig, config: BufferConfig): - assert meta.storage_type == StorageType.QUEUE + def __init__(self, storage_config: StorageConfig, config: BufferConfig): + assert storage_config.storage_type == StorageType.QUEUE self.config = config self.queue = QueueActor.options( - 
name=f"queue-{meta.name}", + name=f"queue-{storage_config.name}", get_if_exists=True, - ).remote(meta, config) + ).remote(storage_config, config) def read(self, strategy: Optional[ReadStrategy] = None) -> List: if strategy is not None and strategy != ReadStrategy.FIFO: diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py index 4da2920816..dcd9d942bb 100644 --- a/trinity/buffer/reader/sql_reader.py +++ b/trinity/buffer/reader/sql_reader.py @@ -1,21 +1,13 @@ """Reader of the SQL buffer.""" -import time from typing import List, Optional -from sqlalchemy import asc, create_engine, desc -from sqlalchemy.exc import OperationalError -from sqlalchemy.orm import sessionmaker -from sqlalchemy.pool import NullPool +import ray from trinity.buffer.buffer_reader import BufferReader -from trinity.buffer.schema import Base, create_dynamic_table -from trinity.buffer.utils import retry_session +from trinity.buffer.db_wrapper import DBWrapper from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import ReadStrategy, StorageType -from trinity.utils.log import get_logger - -logger = get_logger(__name__) class SQLReader(BufferReader): @@ -23,57 +15,11 @@ class SQLReader(BufferReader): def __init__(self, meta: StorageConfig, config: BufferConfig) -> None: assert meta.storage_type == StorageType.SQL - self.engine = create_engine(meta.path, poolclass=NullPool) - - self.table_model_cls = create_dynamic_table(meta.algorithm_type, meta.name) - try: - Base.metadata.create_all(self.engine, checkfirst=True) - except OperationalError: - logger.warning("Failed to create database, assuming it already exists.") - self.session = sessionmaker(bind=self.engine) - self.batch_size = config.read_batch_size - self.max_retry_times = config.max_retry_times - self.max_retry_interval = config.max_retry_interval + self.wrap_in_ray = meta.wrap_in_ray + self.db_wrapper = DBWrapper.get_wrapper(meta, config) def read(self, strategy: 
Optional[ReadStrategy] = None) -> List: - if strategy is None: - strategy = ReadStrategy.LFU - - if strategy == ReadStrategy.LFU: - sortOrder = (asc(self.table_model_cls.consumed), desc(self.table_model_cls.id)) - - elif strategy == ReadStrategy.LRU: - sortOrder = (desc(self.table_model_cls.id),) - - elif strategy == ReadStrategy.PRIORITY: - sortOrder = (desc(self.table_model_cls.priority), desc(self.table_model_cls.id)) - + if self.wrap_in_ray: + return ray.get(self.db_wrapper.read.remote(strategy)) else: - raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage") - - exp_list = [] - while len(exp_list) < self.batch_size: - if len(exp_list): - logger.info("waiting for experiences...") - time.sleep(1) - with retry_session( - self.session, self.max_retry_times, self.max_retry_interval - ) as session: - # get a batch of experiences from the database - experiences = ( - session.query(self.table_model_cls) - .filter(self.table_model_cls.reward.isnot(None)) - .order_by(*sortOrder) # TODO: very slow - .limit(self.batch_size - len(exp_list)) - .with_for_update() - .all() - ) - # update the consumed field - for exp in experiences: - exp.consumed += 1 - exp_list.extend([self.table_model_cls.to_experience(exp) for exp in experiences]) - logger.info(f"get {len(exp_list)} experiences:") - logger.info(f"reward = {[exp.reward for exp in exp_list]}") - logger.info(f"first prompt_text = {exp_list[0].prompt_text}") - logger.info(f"first response_text = {exp_list[0].response_text}") - return exp_list + return self.db_wrapper.read(strategy) diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py index e0b0bdf640..3e054d58c6 100644 --- a/trinity/buffer/writer/sql_writer.py +++ b/trinity/buffer/writer/sql_writer.py @@ -1,19 +1,12 @@ """Writer of the SQL buffer.""" -from sqlalchemy import create_engine -from sqlalchemy.exc import OperationalError -from sqlalchemy.orm import sessionmaker -from sqlalchemy.pool import NullPool +import ray 
from trinity.algorithm.algorithm import ALGORITHM_TYPE from trinity.buffer.buffer_writer import BufferWriter -from trinity.buffer.schema import Base, create_dynamic_table -from trinity.buffer.utils import retry_session +from trinity.buffer.db_wrapper import DBWrapper from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import StorageType -from trinity.utils.log import get_logger - -logger = get_logger(__name__) class SQLWriter(BufferWriter): @@ -25,23 +18,14 @@ def __init__(self, meta: StorageConfig, config: BufferConfig) -> None: # TODO: support other algorithms algorithm = ALGORITHM_TYPE.get(meta.algorithm_type) assert algorithm.use_rollout, "Only RFT buffer is supported for writing." - self.engine = create_engine(meta.path, poolclass=NullPool) - self.table_model_cls = create_dynamic_table(meta.algorithm_type, meta.name) - - try: - Base.metadata.create_all(self.engine, checkfirst=True) - except OperationalError: - logger.warning("Failed to create database, assuming it already exists.") - - self.session = sessionmaker(bind=self.engine) - self.batch_size = config.read_batch_size - self.max_retry_times = config.max_retry_times - self.max_retry_interval = config.max_retry_interval + self.wrap_in_ray = meta.wrap_in_ray + self.db_wrapper = DBWrapper.get_wrapper(meta, config) def write(self, data: list) -> None: - with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session: - experience_models = [self.table_model_cls.from_experience(exp) for exp in data] - session.add_all(experience_models) + if self.wrap_in_ray: + ray.get(self.db_wrapper.write.remote(data)) + else: + self.db_wrapper.write(data) def finish(self) -> None: # TODO: implement this diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 6a01bfb688..cf4a7882aa 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -11,6 +11,7 @@ from trinity.explorer.explorer import Explorer from trinity.trainer.trainer import 
Trainer from trinity.utils.log import get_logger +from trinity.utils.plugin_loader import load_plugins logger = get_logger(__name__) @@ -131,7 +132,8 @@ def activate_data_module(data_workflow_url: str, config_path: str): return -def run(config_path: str, dlc: bool = False): +def run(config_path: str, dlc: bool = False, plugin_dir: str = None): + load_plugins(plugin_dir) config = load_config(config_path) config.check_and_update() pprint(config) @@ -161,6 +163,11 @@ def run(config_path: str, dlc: bool = False): elif config.mode == "bench": bench(config) + if dlc: + from trinity.utils.dlc_utils import stop_ray_cluster + + stop_ray_cluster() + def studio(port: int = 8501): from streamlit.web import cli as stcli @@ -188,6 +195,12 @@ def main() -> None: # run command run_parser = subparsers.add_parser("run", help="Run RFT process.") run_parser.add_argument("--config", type=str, required=True, help="Path to the config file.") + run_parser.add_argument( + "--plugin-dir", + type=str, + default=None, + help="Path to the directory containing plugin modules.", + ) run_parser.add_argument( "--dlc", action="store_true", help="Specify when running in Aliyun PAI DLC." ) @@ -198,12 +211,10 @@ def main() -> None: "--port", type=int, default=8501, help="The port for Trinity-Studio." 
) - # TODO: add more commands like `monitor`, `label` - args = parser.parse_args() if args.command == "run": # TODO: support parse all args from command line - run(args.config, args.dlc) + run(args.config, args.dlc, args.plugin_dir) elif args.command == "studio": studio(args.port) diff --git a/trinity/common/config.py b/trinity/common/config.py index dd863edbd3..22d8f3d711 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -7,7 +7,6 @@ from omegaconf import OmegaConf from trinity.common.constants import ( - MonitorType, PromptType, ReadStrategy, StorageType, @@ -77,10 +76,14 @@ class StorageConfig: format: FormatConfig = field(default_factory=FormatConfig) index: int = 0 + # used for StorageType.SQL + wrap_in_ray: bool = True + # used for rollout tasks default_workflow_type: Optional[str] = None default_reward_fn_type: Optional[str] = None rollout_args: GenerationConfig = field(default_factory=GenerationConfig) + workflow_args: dict = field(default_factory=dict) # ! DO NOT SET, automatically set from algorithm.algorithm_type algorithm_type: Optional[str] = None @@ -303,8 +306,10 @@ class TrainerConfig: @dataclass class MonitorConfig: - # TODO: support multiple monitors (List[MonitorType]) - monitor_type: MonitorType = MonitorType.WANDB + # TODO: support multiple monitors (List[str]) + monitor_type: str = "tensorboard" + # the default args for monitor + monitor_args: Dict = field(default_factory=dict) # ! 
DO NOT SET, automatically generated as checkpoint_job_dir/monitor cache_dir: str = "" diff --git a/trinity/common/constants.py b/trinity/common/constants.py index 47b04f853b..3c49d65c21 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -23,6 +23,9 @@ def __getattr__(cls, name): return cls[name.upper()] return super().__getattr__(name) + def __call__(cls, value, *args, **kwargs): + return super().__call__(value.lower(), *args, **kwargs) + class CaseInsensitiveEnum(Enum, metaclass=CaseInsensitiveEnumMeta): pass @@ -47,11 +50,11 @@ class ReadStrategy(CaseInsensitiveEnum): """Pop Strategy.""" DEFAULT = None - FIFO = "FIFO" - RANDOM = "RANDOM" - LRU = "LRU" - LFU = "LFU" - PRIORITY = "PRIORITY" + FIFO = "fifo" + RANDOM = "random" + LRU = "lru" + LFU = "lfu" + PRIORITY = "priority" class StorageType(CaseInsensitiveEnum): diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py index 25cb927799..fd5670b390 100644 --- a/trinity/common/models/__init__.py +++ b/trinity/common/models/__init__.py @@ -64,9 +64,9 @@ def create_inference_models( else: raise ValueError(f"Unknown engine type: {config.explorer.rollout_model.engine_type}") - main_bundles = [{"GPU": 1, "CPU": 1} for _ in range(engine_num * tensor_parallel_size)] + main_bundles = [{"GPU": 1} for _ in range(engine_num * tensor_parallel_size)] auxiliary_bundles = [ - {"GPU": 1, "CPU": 1} + {"GPU": 1} for _ in range( sum( [ @@ -103,6 +103,7 @@ def create_inference_models( num_gpus=0 if config.explorer.rollout_model.tensor_parallel_size > 1 else 1, scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=pg, + placement_group_capture_child_tasks=True, placement_group_bundle_index=bundles_for_engine[0], ), ) @@ -121,6 +122,7 @@ def create_inference_models( bundles_for_engine = allocator.allocate(model_config.tensor_parallel_size) model_config.enable_openai_api = True model_config.engine_type = "vllm_async" + model_config.bundle_indices = ",".join([str(bid) 
for bid in bundles_for_engine]) engines.append( ray.remote(vLLMAysncRolloutModel) .options( @@ -128,6 +130,7 @@ def create_inference_models( num_gpus=0 if model_config.tensor_parallel_size > 1 else 1, scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=pg, + placement_group_capture_child_tasks=True, placement_group_bundle_index=bundles_for_engine[0], ), ) diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index b5104f2cc7..cb15b1ae3d 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -103,6 +103,11 @@ def get_ckp_version(self) -> int: return ray.get(self.model.get_ckp_version.remote()) def get_openai_client(self) -> openai.OpenAI: + """Get the OpenAI client. + + Returns: + openai.OpenAI: The OpenAI client, with an added `model_path` attribute that holds the model path. + """ if self.openai_client is not None: return self.openai_client if not ray.get(self.model.has_api_server.remote()): @@ -110,9 +115,9 @@ def get_openai_client(self) -> openai.OpenAI: "OpenAI API server is not running on current model." "Please set `enable_openai_api` to `True`." 
) - api_address = None + api_address, model_path = None, None while True: - api_address = ray.get(self.model.api_server_ready.remote()) + api_address, model_path = ray.get(self.model.api_server_ready.remote()) if api_address is not None: break else: @@ -127,4 +132,5 @@ def get_openai_client(self) -> openai.OpenAI: base_url=api_address, api_key="EMPTY", ) + setattr(self.openai_client, "model_path", model_path) # TODO: may be removed return self.openai_client diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py index 02ea52ec58..27faa4c44a 100644 --- a/trinity/common/models/vllm_async_model.py +++ b/trinity/common/models/vllm_async_model.py @@ -5,7 +5,7 @@ import os import re -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple, Union import aiohttp import torch @@ -319,26 +319,30 @@ async def run_api_server(self): async def has_api_server(self) -> bool: return self.config.enable_openai_api - async def api_server_ready(self) -> Optional[str]: + async def api_server_ready(self) -> Tuple[Union[str, None], Union[str, None]]: """Check if the OpenAI API server is ready. Returns: - str: The URL of the OpenAI API server. + api_url (str): The URL of the OpenAI API server. + model_path (str): The path of the model. 
""" if not await self.has_api_server(): - return None + return None, None try: async with aiohttp.ClientSession() as session: async with session.get( f"http://{self.api_server_host}:{self.api_server_port}/health" ) as response: if response.status == 200: - return f"http://{self.api_server_host}:{self.api_server_port}/v1" + return ( + f"http://{self.api_server_host}:{self.api_server_port}/v1", + self.config.model_path, + ) else: - return None + return None, None except Exception as e: self.logger.error(e) - return None + return None, None async def reset_prefix_cache(self) -> None: await self.async_llm.reset_prefix_cache() diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py index 92bf29a64e..f5b1c9a7b9 100644 --- a/trinity/common/workflows/__init__.py +++ b/trinity/common/workflows/__init__.py @@ -3,10 +3,11 @@ from .envs.alfworld.alfworld_workflow import AlfworldWorkflow from .envs.sciworld.sciworld_workflow import SciWorldWorkflow from .envs.webshop.webshop_workflow import WebShopWorkflow -from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task +from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task, Workflow __all__ = [ "Task", + "Workflow", "WORKFLOWS", "SimpleWorkflow", "MathWorkflow", diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 9786bd6b77..fc4a87556b 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -28,8 +28,9 @@ class Task: """A Task class that defines a task and its associated reward function / workflow.""" workflow: Type[Workflow] - format_args: FormatConfig + format_args: FormatConfig = field(default_factory=FormatConfig) rollout_args: GenerationConfig = field(default_factory=GenerationConfig) + workflow_args: dict = field(default_factory=dict) is_eval: bool = False reward_fn: Optional[Type[RewardFn]] = None raw_task: Optional[dict] = None # The raw data sample @@ -41,6 +42,10 @@ def to_workflow( 
Args: model (ModelWrapper): The rollout model for the workflow. + auxiliary_models (List[openai.OpenAI]): The auxiliary models for the workflow. + + Note: + `model_path` attribute is added to the `auxiliary_models` for use within the workflow. Returns: Workflow: The generated workflow object. diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 37257f71ce..0a897254fb 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -21,7 +21,7 @@ from trinity.explorer.runner_pool import RunnerPool from trinity.manager.manager import CacheManager from trinity.utils.log import get_logger -from trinity.utils.monitor import Monitor +from trinity.utils.monitor import MONITOR @ray.remote(name="explorer", concurrency_groups={"get_weight": 32, "setup_weight_sync_group": 1}) @@ -49,7 +49,7 @@ def __init__(self, config: Config): for eval_taskset_config in self.config.buffer.explorer_input.eval_tasksets: self.eval_tasksets.append(get_buffer_reader(eval_taskset_config, self.config.buffer)) self.runner_pool = self._init_runner_pool() - self.monitor = Monitor( + self.monitor = MONITOR.get(self.config.monitor.monitor_type)( project=self.config.project, name=self.config.name, role="explorer", diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py index 9ac2d36f16..80b8992b3b 100644 --- a/trinity/manager/config_manager.py +++ b/trinity/manager/config_manager.py @@ -7,22 +7,15 @@ import streamlit as st import yaml -from trinity.common.constants import ( - AlgorithmType, - MonitorType, - PromptType, - StorageType, - SyncMethod, -) -from trinity.common.rewards import REWARD_FUNCTIONS -from trinity.common.workflows.workflow import WORKFLOWS -from trinity.trainer.verl.ray_trainer import AdvantageEstimator +from trinity.common.constants import AlgorithmType, StorageType +from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS +from trinity.manager.config_registry.trainer_config_manager import 
use_critic class ConfigManager: def __init__(self): - self._init_default_config() self.unfinished_fields = set() + CONFIG_GENERATORS.set_unfinished_fields(self.unfinished_fields) st.set_page_config(page_title="Trinity-RFT Config Generator", page_icon=":robot:") st.title("Trinity-RFT Config Generator") if "_init_config_manager" not in st.session_state: @@ -44,1319 +37,256 @@ def __init__(self): st.session_state.is_running = False self.generate_config() - def _init_default_config(self): - self.default_config = { - "_init_config_manager": True, - "mode": "both", - "project": "Trinity-RFT", - "exp_name": "qwen2.5-1.5B", - "checkpoint_root_dir": "", - "monitor_type": MonitorType.TENSORBOARD.value, - # Algorithm Configs - "algorithm_type": AlgorithmType.PPO.value, - "_grouped_adv_repeat_times": 2, - "_not_grouped_adv_repeat_times": 1, - "repeat_times": 1, - "gamma": 1.0, - "lam": 1.0, - # Model Configs - "model_path": "", - "critic_model_path": "", - "max_prompt_tokens": 1024, - "max_response_tokens": 1024, - # Cluster Config - "node_num": 1, - "gpu_per_node": 8, - "total_gpu_num": 8, - "trainer_gpu_num": 6, - # Buffer Configs - "total_epochs": 20, - "_train_batch_size_per_gpu": 16, - "train_batch_size": 96, - "buffer_max_retry_times": 3, - "max_retry_interval": 1, - # Taskset Configs - "taskset_path": "", - "taskset_subset_name": None, - "taskset_split": "train", - "taskset_prompt_key": "question", - "taskset_response_key": "answer", - "temperature": 1.0, - "top_p": 1.0, # TODO: to be used - "top_k": -1, # TODO: to be used - "logprobs": 0, - # Eval Taskset Configs - "_eval_tasksets_num": 0, - # Explorer Input Configs - "default_workflow_type": "math_workflow", - "default_reward_fn_type": "math_reward", - "system_prompt": None, - "reply_prefix": None, - # Experience Buffer / DPO Dataset Configs - "_dpo_storage_type": StorageType.FILE.value, - "_not_dpo_storage_type": StorageType.QUEUE.value, - "storage_type": StorageType.QUEUE.value, - "_dpo_experience_buffer_path": "", 
- "_not_dpo_experience_buffer_path": "", - "experience_buffer_path": "", - "dpo_dataset_train_split": "train", - "dpo_dataset_prompt_type": PromptType.MESSAGES.value, - "dpo_dataset_prompt_key": "prompt", - "dpo_dataset_chosen_key": "chosen", - "dpo_dataset_rejected_key": "rejected", - # SFT Warmup Dataset Configs - "sft_warmup_dataset_path": "", - "sft_warmup_train_split": "train", - "sft_warmup_prompt_type": PromptType.MESSAGES.value, - "sft_warmup_messages_key": "messages", - "sft_warmup_prompt_key": "prompt", - "sft_warmup_response_key": "response", - # TrainerInput Configs - # TODO: read_experience_strategy - "sft_warmup_steps": 0, - # Explorer and Sync Configs - "runner_num": 32, - "max_timeout": 900, - "explorer_max_retry_times": 2, - "eval_interval": 1000, - "eval_on_latest_checkpoint": True, - # Rollout Model Configs - "engine_type": "vllm_async", - "engine_num": 2, - "tensor_parallel_size": 1, - "use_v1": True, - "enforce_eager": True, - "enable_prefix_caching": False, - "enable_chunked_prefill": False, - "gpu_memory_utilization": 0.9, - "dtype": "bfloat16", - "seed": 42, - # TODO: max_prompt_tokens - # TODO: max_response_tokens - # TODO: chat_template - "enable_thinking": False, - "enable_openai_api": False, - # TODO: Auxiliary Models Configs - # Synchronizer Configs - "_not_dpo_sync_method": SyncMethod.NCCL.value, - "sync_method": SyncMethod.NCCL.value, - "sync_interval": 10, - "sync_timeout": 1200, - # Trainer Configs - "trainer_type": "verl", - "_nccl_save_interval": 100, - "save_interval": 100, - # TODO: enable_preview - "_not_dpo_actor_use_kl_loss": True, - "actor_use_kl_loss": True, - "actor_kl_loss_coef": 0.001, - "actor_entropy_coef": 0.001, - "actor_grad_clip": 1.0, - "actor_clip_ratio": 0.2, - # veRL Trainer Configs - "training_args": [ - "balance_batch", - "gradient_checkpointing", - "remove_padding", - "dynamic_bsz", - ], - "ppo_epochs": 1, - "training_strategy": "fsdp", - "param_offload": False, - "optimizer_offload": False, - "resume_mode": 
"auto", - "resume_from_path": "", - "critic_warmup": 0, - "total_training_steps": None, - "default_hdfs_dir": None, - "remove_previous_ckpt_in_save": False, - "del_local_ckpt_after_load": False, - "max_actor_ckpt_to_keep": None, - "max_critic_ckpt_to_keep": None, - "adv_estimator": "gae", - "norm_adv_by_std_in_grpo": True, - "use_kl_in_reward": False, - "kl_penalty": "low_var_kl", - "kl_ctrl_type": "fixed", - "kl_ctrl_coef": 0.001, - "horizon": 10000, - "target_kl": 0.1, - "actor_ppo_micro_batch_size_per_gpu": 4, - "ref_log_prob_micro_batch_size_per_gpu": 8, - "actor_ulysses_sequence_parallel_size": 1, - "actor_lr": 1e-6, - "actor_warmup_style": "constant", - "actor_lr_warmup_steps_ratio": 0.0, - "actor_tau": 0.0, - "actor_opmd_baseline": "mean", - "actor_use_uid": False, - "actor_kl_loss_type": "low_var_kl", - "actor_checkpoint": ["model", "hf_model", "optimizer", "extra"], - "critic_lr": 1e-6, - "critic_warmup_style": "constant", - "critic_lr_warmup_steps_ratio": 0.0, - "critic_grad_clip": 1.0, - "critic_cliprange_value": 0.5, - "critic_ppo_micro_batch_size_per_gpu": 8, - "critic_ulysses_sequence_parallel_size": 1, - "critic_checkpoint": ["model", "optimizer", "extra"], - } - def reset_session_state(self): - for key, value in self.default_config.items(): + st.session_state["_init_config_manager"] = True + for key, value in CONFIG_GENERATORS.default_config.items(): st.session_state[key] = value def maintain_session_state(self): - for key in self.default_config: + st.session_state["_init_config_manager"] = True + for key in CONFIG_GENERATORS.default_config: st.session_state[key] = st.session_state[key] - eavl_dataset_keys = ["name", "path", "subset_name", "split", "prompt_key", "response_key"] + + eval_dataset_keys = [ + "name", + "path", + "subset_name", + "split", + "prompt_key", + "response_key", + "temperature", + "logprobs", + "n", + ] + last_idx, del_num = 0, 0 for idx in range(st.session_state["_eval_tasksets_num"]): - for key in eavl_dataset_keys: + if 
st.session_state.get(f"eval_taskset_{idx}_del_flag", False): + del_num += 1 + continue + for key in eval_dataset_keys: full_key = f"eval_taskset_{idx}_{key}" - st.session_state[full_key] = st.session_state[full_key] - - def _set_project(self): - st.text_input("Project", key="project") - - def _set_exp_name(self): - st.text_input("Experiment Name", key="exp_name") - - def _set_monitor_type(self): - st.selectbox( - "Monitor Type", - options=[monitor_type.value for monitor_type in MonitorType], - key="monitor_type", - ) - - def _set_model_path(self): - st.text_input("Model Path", key="model_path") - if not st.session_state["model_path"].strip(): - self.unfinished_fields.add("model_path") - st.warning("Please input model path.") - - def _set_critic_model_path(self): - if st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value: - st.text_input( - "Critic Model Path (defaults to `model_path`)", - key="critic_model_path", - ) - - def _set_checkpoint_root_dir(self): - st.text_input("Checkpoint Root Dir", key="checkpoint_root_dir") - if not st.session_state["checkpoint_root_dir"].strip(): # TODO: may auto generate - self.unfinished_fields.add("checkpoint_root_dir") - st.warning("Please input checkpoint root dir.") - elif not os.path.isabs(st.session_state["checkpoint_root_dir"].strip()): - self.unfinished_fields.add("checkpoint_root_dir") - st.warning("Please input an absolute path.") - - def _set_node_num(self): - st.number_input("Node Num", key="node_num", min_value=1, on_change=self._set_total_gpu_num) - - def _set_gpu_per_node(self): - st.number_input( - "GPU Per Node", - key="gpu_per_node", - min_value=1, - max_value=8, - on_change=self._set_total_gpu_num, - ) - - def _set_total_gpu_num(self): - st.session_state["total_gpu_num"] = ( - st.session_state["gpu_per_node"] * st.session_state["node_num"] - ) - self._set_trainer_gpu_num() - - def _set_trainer_gpu_num(self): - if st.session_state["mode"] == "both": - st.session_state["trainer_gpu_num"] = ( - 
st.session_state["total_gpu_num"] - - st.session_state["engine_num"] * st.session_state["tensor_parallel_size"] - ) - else: # model == train - st.session_state["trainer_gpu_num"] = st.session_state["total_gpu_num"] - - def _set_max_prompt_tokens(self): - st.number_input("Max Prompt Tokens", key="max_prompt_tokens", min_value=1) - - def _set_max_response_tokens(self): - st.number_input("Max Response Tokens", key="max_response_tokens", min_value=1) - - def _set_total_epochs(self): - st.number_input("Total Epochs", key="total_epochs", min_value=1) - - @property - def _str_for_train_batch_size(self): - trainer_gpu_num_str = ( - "`gpu_per_node * node_num - engine_num * tensor_parallel_size`" - if st.session_state["mode"] == "both" - else "`gpu_per_node * node_num`" - ) - return ( - f"Please ensure that `train_batch_size` can be divided by " - f"{trainer_gpu_num_str} = {st.session_state['trainer_gpu_num']}." - ) - - def _set_train_batch_size(self): - trainer_gpu_num = st.session_state["trainer_gpu_num"] - st.session_state["train_batch_size"] = ( - st.session_state["_train_batch_size_per_gpu"] * st.session_state["trainer_gpu_num"] - ) - - def on_change(): - st.session_state["_train_batch_size_per_gpu"] = max( - st.session_state["train_batch_size"] // st.session_state["trainer_gpu_num"], 1 - ) - - st.number_input( - "Train Batch Size", - key="train_batch_size", - min_value=trainer_gpu_num, - step=trainer_gpu_num, - help=self._str_for_train_batch_size, - on_change=on_change, - ) - - def _check_train_batch_size(self): - if st.session_state["train_batch_size"] % st.session_state["trainer_gpu_num"] != 0: - self.unfinished_fields.add("train_batch_size") - st.warning(self._str_for_train_batch_size) - - def _set_taskset_path(self): - st.text_input("Taskset Path", key="taskset_path") - if not st.session_state["taskset_path"].strip(): - self.unfinished_fields.add("taskset_path") - st.warning("Please input taskset path.") - - def _set_system_prompt(self): - st.text_area( - "System 
Prompt", - key="system_prompt", - placeholder="System prompt is used to guide the model behavior.", - ) - - def _set_reply_prefix(self): - st.text_area( - "Assistant Reply Prefix", - key="reply_prefix", - placeholder="""Assistant reply prefix is used to specify the initial content of model reply, """ - """and a common setting is: \nLet me solve this step by step. """, - ) - - def _set_taskset_args(self): - if st.session_state["taskset_path"] and "://" not in st.session_state["taskset_path"]: - subset_name_col, split_col = st.columns(2) - subset_name_col.text_input( - "Subset Name :orange-badge[(Needs review)]", - key="taskset_subset_name", - help="The subset name used for `datasets.load_datasets`, see " - "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.", - ) - split_col.text_input("Train Split :orange-badge[(Needs review)]", key="taskset_split") - prompt_key_col, response_key_col = st.columns(2) - prompt_key_col.text_input( - "Prompt Key :orange-badge[(Needs review)]", key="taskset_prompt_key" - ) - response_key_col.text_input( - "Response Key :orange-badge[(Needs review)]", key="taskset_response_key" - ) - self._set_configs_with_st_columns(["temperature", "logprobs"]) - - def _set_eval_taskset_idx(self, idx): # TODO: add delete - st.text_input( - "Taskset Name", - key=f"eval_taskset_{idx}_name", - ) - st.text_input( - "Eval Taskset Path", - key=f"eval_taskset_{idx}_path", - ) - if not st.session_state[f"eval_taskset_{idx}_path"].strip(): - st.warning("Please input the taskset path, or it will be ignored.") - subset_name_col, split_col = st.columns(2) - subset_name_col.text_input( - "Subset Name :orange-badge[(Needs review)]", - key=f"eval_taskset_{idx}_subset_name", - help="The subset name used for `datasets.load_datasets`, see " - "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.", - ) - 
split_col.text_input( - "Eval Split :orange-badge[(Needs review)]", - key=f"eval_taskset_{idx}_split", - ) - prompt_key_col, response_key_col = st.columns(2) - prompt_key_col.text_input( - "Prompt Key :orange-badge[(Needs review)]", - key=f"eval_taskset_{idx}_prompt_key", - ) - response_key_col.text_input( - "Response Key :orange-badge[(Needs review)]", - key=f"eval_taskset_{idx}_response_key", - ) - - def _set_eval_tasksets(self): - if st.button("Add Eval Taskset"): - st.session_state["_eval_tasksets_num"] += 1 - if st.session_state["_eval_tasksets_num"] > 0: - tabs = st.tabs( - [f"Eval Taskset {i + 1}" for i in range(st.session_state["_eval_tasksets_num"])] - ) - for idx, tab in enumerate(tabs): - with tab: - self._set_eval_taskset_idx(idx) - - def _set_default_workflow_type(self): - st.selectbox( - "Default Workflow Type :orange-badge[(Needs review)]", - WORKFLOWS.modules.keys(), - key="default_workflow_type", - help=r"""`simple_workflow`: call 'model.chat()' to get responses. - -`math_workflow`: call 'model.chat()' with a pre-defined system prompt to get responses. - -Other workflows: conduct multi-turn task for the given dataset. -""", - ) - - def _set_default_reward_fn_type(self): - st.selectbox( - "Default Reward Fn Type :orange-badge[(Needs review)]", - REWARD_FUNCTIONS.modules.keys(), - key="default_reward_fn_type", - help=r"""`accuracy_reward`: check the accuracy for math problems. - -`format_reward`: check if the response matches the format (default: `** *`). - -`math_reward`: `accuracy_reward` (1 or 0) + `format_reward` (+0.1 or -0.1). 
-""", - ) - - def _set_storage_type(self): - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - st.session_state["storage_type"] = st.session_state["_dpo_storage_type"] - storage_candidates = [StorageType.FILE.value, StorageType.SQL.value] - else: - st.session_state["storage_type"] = st.session_state["_not_dpo_storage_type"] - storage_candidates = [StorageType.QUEUE.value, StorageType.SQL.value] - - def on_change(): - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - st.session_state["_dpo_storage_type"] = st.session_state["storage_type"] - else: - st.session_state["_not_dpo_storage_type"] = st.session_state["storage_type"] - - st.selectbox( - "Storage Type", - storage_candidates, - key="storage_type", - on_change=on_change, - ) - - def _set_experience_buffer_path(self): # TODO - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - st.session_state["experience_buffer_path"] = st.session_state[ - "_dpo_experience_buffer_path" - ] - title = "DPO Dataset Path" - help_msg = r"""This path to DPO dataset, - -if `storage_type == StorageType.FILE`, this should be a path to a file, - -if `storage_type == StorageType.SQL`, this should be a path to database.""" - else: - st.session_state["experience_buffer_path"] = st.session_state[ - "_not_dpo_experience_buffer_path" - ] - title = "Experience Buffer Path" - help_msg = r"""This path is used for `trainer`, - -if `storage_type == StorageType.QUEUE`, default to `None`, - -if `storage_type == StorageType.SQL`, default to `sqlite:///{os.path.join(checkpoint_root_dir, '.cache', project_name, experiment_name)}/data.db`.""" - - def on_change(): - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - st.session_state["_dpo_experience_buffer_path"] = st.session_state[ - "experience_buffer_path" - ] - else: - st.session_state["_not_dpo_experience_buffer_path"] = st.session_state[ - "experience_buffer_path" - ] - - st.text_input( - title, - key="experience_buffer_path", - 
help=help_msg, - on_change=on_change, - ) - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - if not st.session_state["experience_buffer_path"].strip(): - self.unfinished_fields.add("experience_buffer_path") - st.warning("Please input DPO dataset path.") - - def _set_buffer_max_retry_times(self): - st.number_input("Max Retry Times", key="buffer_max_retry_times", min_value=1) - - def _set_max_retry_interval(self): - st.number_input("Max Retry Interval", key="max_retry_interval", min_value=1) - - def _set_dpo_dataset_kwargs(self): - dpo_dataset_train_split_col, dpo_dataset_prompt_type_col = st.columns(2) - dpo_dataset_train_split_col.text_input( - "DPO Dataset Train Split :orange-badge[(Needs review)]", key="dpo_dataset_train_split" - ) - dpo_dataset_prompt_type_col.selectbox( - "DPO Dataset Prompt Type :orange-badge[(Needs review)]", - [prompt_type.value for prompt_type in PromptType], - key="dpo_dataset_prompt_type", - ) - - ( - dpo_dataset_prompt_key_col, - dpo_dataset_chosen_key_col, - dpo_dataset_rejected_key_col, - ) = st.columns(3) - dpo_dataset_prompt_key_col.text_input( - "DPO Dataset Prompt Key :orange-badge[(Needs review)]", key="dpo_dataset_prompt_key" - ) - dpo_dataset_chosen_key_col.text_input( - "DPO Dataset Chosen Key :orange-badge[(Needs review)]", key="dpo_dataset_chosen_key" - ) - dpo_dataset_rejected_key_col.text_input( - "DPO Dataset Rejected Key :orange-badge[(Needs review)]", - key="dpo_dataset_rejected_key", - ) - - def _check_sft_warmup_dataset_path(self): - if st.session_state["sft_warmup_steps"]: - if not st.session_state["sft_warmup_dataset_path"].strip(): - self.unfinished_fields.add("sft_warmup_dataset_path") - st.warning("Please input SFT warmup dataset path when `sft_warmup_steps` is not 0") - - def _set_sft_warmup_dataset_path(self): - st.text_input("SFT Warmup Dataset Path", key="sft_warmup_dataset_path") - self._check_sft_warmup_dataset_path() - - def _set_sft_warmup_dataset_args(self): - if ( - 
st.session_state["sft_warmup_dataset_path"] - and "://" not in st.session_state["sft_warmup_dataset_path"] - ): # TODO - ( - sft_warmup_train_split_col, - sft_warmup_prompt_type_col, - ) = st.columns(2) - sft_warmup_train_split_col.text_input( - "SFT Dataset Train Split :orange-badge[(Needs review)]", - key="sft_warmup_train_split", - ) - sft_warmup_prompt_type_col.selectbox( - "SFT Dataset Prompt Type :orange-badge[(Needs review)]", - [prompt_type.value for prompt_type in PromptType], - key="sft_warmup_prompt_type", - ) - ( - sft_warmup_messages_key_col, - sft_warmup_prompt_key_col, - sft_warmup_response_key_col, - ) = st.columns( - 3 - ) # TODO: select by prompt type - sft_warmup_messages_key_col.text_input( - "SFT Dataset Messages Key :orange-badge[(Needs review)]", - key="sft_warmup_messages_key", - ) - sft_warmup_prompt_key_col.text_input( - "SFT Dataset Prompt Key :orange-badge[(Needs review)]", key="sft_warmup_prompt_key" - ) - sft_warmup_response_key_col.text_input( - "SFT Dataset Response Key :orange-badge[(Needs review)]", - key="sft_warmup_response_key", - ) - - def _set_engine_type(self): - st.selectbox("Explorer Engine Type", ["vllm_async", "vllm"], key="engine_type") - - @property - def _str_for_engine_num_and_tp_size(self): - return r"""and it must meet the following constraints: -```python -assert engine_num * tensor_parallel_size < gpu_per_node * node_num -if node_num > 1: - assert gpu_per_node % tensor_parallel_size == 0 - assert engine_num * tensor_parallel_size % gpu_per_node == 0 -```""" - - def _set_engine_num(self): - total_gpu_num = st.session_state["total_gpu_num"] - max_engine_num = (total_gpu_num - 1) // st.session_state["tensor_parallel_size"] - if st.session_state["engine_num"] > max_engine_num: - st.session_state["engine_num"] = max_engine_num - self._set_trainer_gpu_num() - st.number_input( - "Engine Num", - key="engine_num", - min_value=1, - max_value=max_engine_num, - help=f"`engine_num` is used to set the quantity of inference 
engines, " - f"{self._str_for_engine_num_and_tp_size}", - on_change=self._set_trainer_gpu_num, - ) - - def _set_tensor_parallel_size(self): - total_gpu_num = st.session_state["total_gpu_num"] - max_tensor_parallel_size = (total_gpu_num - 1) // st.session_state["engine_num"] - if st.session_state["tensor_parallel_size"] > max_tensor_parallel_size: - st.session_state["tensor_parallel_size"] = max_tensor_parallel_size - self._set_trainer_gpu_num() - st.number_input( - "Tensor Parallel Size", - key="tensor_parallel_size", - min_value=1, - max_value=max_tensor_parallel_size, - help=f"`tensor_parallel_size` is used to set the tensor parallel size of inference engines, " - f"{self._str_for_engine_num_and_tp_size}", - on_change=self._set_trainer_gpu_num, - ) - - def _check_engine_num_and_tp_size(self): - node_num = st.session_state["node_num"] - gpu_per_node = st.session_state["gpu_per_node"] - engine_num = st.session_state["engine_num"] - tensor_parallel_size = st.session_state["tensor_parallel_size"] - if node_num > 1: - if gpu_per_node % tensor_parallel_size != 0: - self.unfinished_fields.add("tensor_parallel_size") - st.warning( - "Please ensure that `tensor_parallel_size` is a factor of `gpu_per_node` when `node_num > 1`." - ) - if engine_num * tensor_parallel_size % gpu_per_node != 0: - self.unfinished_fields.add("engine_num") - st.warning( - "Please ensure that `engine_num * tensor_parallel_size` can be divided by `gpu_per_node` when `node_num > 1`." 
- ) - - def _set_repeat_times(self): # TODO - grouped_adv_algorithms = [ - AlgorithmType.GRPO.value, - AlgorithmType.OPMD.value, # TODO: may add rloo + last_full_key = f"eval_taskset_{last_idx}_{key}" + st.session_state[last_full_key] = st.session_state[full_key] + last_idx += 1 + st.session_state["_eval_tasksets_num"] -= del_num + + auxiliary_model_keys = [ + "model_path", + "engine_type", + "engine_num", + "tensor_parallel_size", + "gpu_memory_utilization", + "dtype", + "seed", + "use_v1", + "enforce_eager", + "enable_prefix_caching", + "enable_chunked_prefill", + "enable_thinking", + "enable_openai_api", ] - if st.session_state["algorithm_type"] in grouped_adv_algorithms: - min_repeat_times = 2 - st.session_state["repeat_times"] = st.session_state["_grouped_adv_repeat_times"] - else: - min_repeat_times = 1 - st.session_state["repeat_times"] = st.session_state["_not_grouped_adv_repeat_times"] - - def on_change(): - if st.session_state["algorithm_type"] in grouped_adv_algorithms: - st.session_state["_grouped_adv_repeat_times"] = st.session_state["repeat_times"] - else: - st.session_state["_not_grouped_adv_repeat_times"] = st.session_state["repeat_times"] - - st.number_input( - "Repeat Times", - key="repeat_times", - min_value=min_repeat_times, - help="`repeat_times` is used to set how many experiences each task can generate, " - "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.", - on_change=on_change, - ) - - def _set_sync_method(self): - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - st.session_state["sync_method"] = SyncMethod.CHECKPOINT.value - disabled = True - else: - st.session_state["sync_method"] = st.session_state["_not_dpo_sync_method"] - disabled = False - - def on_change(): - if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: - st.session_state["_not_dpo_sync_method"] = st.session_state["sync_method"] - - st.selectbox( - "Sync Method", - [sync_method.value for sync_method in SyncMethod], - 
key="sync_method", - help="""`nccl`: the explorer and trainer sync model weights once every `sync_interval` steps. - -`checkpoint`: the trainer saves the model checkpoint, and the explorer loads it at `sync_interval`.""", - disabled=disabled, - on_change=on_change, - ) - - def _set_sync_interval(self): - st.number_input( - "Sync Interval", - key="sync_interval", - min_value=1, - help="""The step interval at which the `explorer` and `trainer` synchronize model weight.""", - ) - - def _set_sync_timeout(self): - st.number_input( - "Sync Timeout", - key="sync_timeout", - min_value=1, - help="The timeout value for the synchronization operation.", - ) - - def _set_runner_num(self): - st.number_input("Runner Num", key="runner_num", min_value=1) - - def _set_dtype(self): - st.selectbox("Dtype", ["float16", "bfloat16", "float32"], key="dtype") - - def _set_temperature(self): - st.number_input("Temperature", key="temperature", min_value=0.0, max_value=2.0) - - def _set_top_p(self): - st.number_input("Top-p", key="top_p", min_value=0.0, max_value=1.0) - - def _set_top_k(self): - st.number_input( - "Top-k", - key="top_k", - min_value=-1, - max_value=512, - help="Integer that controls the number of top tokens to consider. 
Set to -1 to consider all tokens.", - ) - - def _set_seed(self): - st.number_input("Seed", key="seed", step=1) - - def _set_logprobs(self): - st.number_input("Logprobs", key="logprobs", min_value=0, max_value=20) - - def _set_use_v1(self): - st.checkbox("Use V1 Engine", key="use_v1") - - def _set_enable_prefix_caching(self): - st.checkbox("Prefix Caching", key="enable_prefix_caching") - - def _set_enforce_eager(self): - st.checkbox("Enforce Eager", key="enforce_eager") - - def _set_gpu_memory_utilization(self): - st.number_input( - "GPU Memory Utilization", key="gpu_memory_utilization", min_value=0.0, max_value=1.0 - ) - - def _set_enable_chunked_prefill(self): - st.checkbox("Chunked Prefill", key="enable_chunked_prefill") - - def _set_enable_thinking(self): - st.checkbox("Enable Thinking For Qwen3", key="enable_thinking") - - def _set_enable_openai_api(self): - st.checkbox("Enable OpenAI API", key="enable_openai_api") - - def _set_max_timeout(self): - st.number_input("Max Timeout", key="max_timeout", min_value=0) - - def _set_explorer_max_retry_times(self): - st.number_input("Explorer Max Retry Times", key="explorer_max_retry_times", min_value=0) - - def _set_trainer_type(self): - st.selectbox("Trainer Type", ["verl"], key="trainer_type") - - def _set_algorithm_type(self): - def on_change(): - if st.session_state["algorithm_type"] == AlgorithmType.PPO.value: - st.session_state["mode"] = "both" - st.session_state["adv_estimator"] = AdvantageEstimator.GAE.value - elif st.session_state["algorithm_type"] == AlgorithmType.GRPO.value: - st.session_state["mode"] = "both" - st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value - elif st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - st.session_state["mode"] = "train" - st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value - elif st.session_state["algorithm_type"] == AlgorithmType.OPMD.value: - st.session_state["mode"] = "both" - st.session_state["adv_estimator"] = 
AdvantageEstimator.GRPO.value - else: # TODO: add more algorithms - pass - self._set_trainer_gpu_num() - - st.selectbox( - "Algorithm Type", - [ - AlgorithmType.PPO.value, - AlgorithmType.GRPO.value, - AlgorithmType.DPO.value, - AlgorithmType.OPMD.value, - ], - key="algorithm_type", - on_change=on_change, - ) - - def _set_sft_warmup_steps(self): - st.number_input("SFT Warmup Steps", key="sft_warmup_steps", min_value=0) - - def _set_eval_interval(self): - st.number_input("Eval Interval", key="eval_interval", min_value=1) - - def _set_eval_on_latest_checkpoint(self): - st.checkbox("Eval on Latest Checkpoint", key="eval_on_latest_ckp") - - def _set_training_args(self): - st.multiselect( - "Training Args", - [ - "balance_batch", - "gradient_checkpointing", - "remove_padding", - "dynamic_bsz", - ], - key="training_args", - ) - - def _set_save_interval(self): - if ( - st.session_state["algorithm_type"] == AlgorithmType.DPO.value - or st.session_state["sync_method"] == SyncMethod.NCCL.value - ): - st.session_state["save_interval"] = st.session_state["_nccl_save_interval"] - freeze_save_interval = False - else: - st.session_state["save_interval"] = st.session_state["sync_interval"] - freeze_save_interval = True - - def on_change(): - if ( - st.session_state["algorithm_type"] == AlgorithmType.DPO.value - or st.session_state["sync_method"] == SyncMethod.NCCL.value - ): - st.session_state["_nccl_save_interval"] = st.session_state["save_interval"] - - st.number_input( - "Save Interval", - key="save_interval", - min_value=1, - help="Set to `sync_interval` when `algorithm_type != DPO && sync_method == checkpoint`", - disabled=freeze_save_interval, - on_change=on_change, - ) - - def _set_ppo_epochs(self): - st.number_input("PPO Epochs", key="ppo_epochs", min_value=1) - - def _set_training_strategy(self): - st.selectbox( - "Training Strategy", - ["fsdp", "megatron"], - key="training_strategy", - help="megatron is not tested", - ) - - def _set_param_offload(self): - 
st.checkbox("FSDP Param Offload", key="param_offload") - - def _set_optimizer_offload(self): - st.checkbox("FSDP Optimizer Offload", key="optimizer_offload") - - def _set_resume_mode(self): - st.selectbox("Resume Mode", ["disable", "auto", "resume_path"], key="resume_mode") - - def _set_resume_from_path(self): - if st.session_state["resume_mode"] == "resume_path": - st.text_input("Resume Path", key="resume_from_path") - if ( - not st.session_state["resume_from_path"].strip() - or "global_step_" not in st.session_state["resume_from_path"] - ): - self.unfinished_fields.add("resume_from_path") - st.warning("Please input a valid resume path when `resume_mode == resume_path`") - - def _set_critic_warmup(self): - st.number_input("Critic Warmup Steps", key="critic_warmup", min_value=0) - - def _set_total_training_steps(self): - st.number_input("Total Training Steps", key="total_training_steps", min_value=1) - - def _set_default_hdfs_dir(self): - st.text_input("Default HDFS Dir", key="default_hdfs_dir") - - def _set_remove_previous_ckpt_in_save(self): - st.checkbox("Remove Previous Checkpoint in Save", key="remove_previous_ckpt_in_save") - - def _set_del_local_ckpt_after_load(self): - st.checkbox("Delete Local Checkpoint After Load", key="del_local_ckpt_after_load") - - def _set_max_actor_ckpt_to_keep(self): - st.number_input("Max Actor Checkpoint to Keep", key="max_actor_ckpt_to_keep", min_value=1) - - def _set_max_critic_ckpt_to_keep(self): - st.number_input("Max Critic Checkpoint to Keep", key="max_critic_ckpt_to_keep", min_value=1) - - def _set_gamma(self): - st.number_input(r"Gamma :blue-badge[$\gamma$]", key="gamma") - - def _set_lam(self): - st.number_input(r"Lambda :blue-badge[$\lambda$]", key="lam") - - def _set_norm_adv_by_std_in_grpo(self): - st.checkbox("Norm Adv by Std in GRPO", key="norm_adv_by_std_in_grpo") - - def _set_use_kl_in_reward(self): - st.checkbox("Use KL in Reward", key="use_kl_in_reward") - - def _set_kl_penalty(self): - st.selectbox("KL 
Penalty", ["kl", "abs", "mse", "low_var_kl"], key="kl_penalty") - - def _set_kl_ctrl_type(self): - st.selectbox("KL Ctrl Type", ["fixed", "adaptive"], key="kl_ctrl_type") - - def _set_kl_ctrl_coef(self): - st.number_input("KL Ctrl Coef", key="kl_ctrl_coef", format="%.1e") - - def _set_horizon(self): - st.number_input("Horizon", key="horizon", min_value=1.0) - - def _set_target_kl(self): - st.number_input("Target KL", key="target_kl", format="%.1e") - - def _set_actor_ppo_micro_batch_size_per_gpu(self): - st.session_state["actor_ppo_micro_batch_size_per_gpu"] = min( - st.session_state["actor_ppo_micro_batch_size_per_gpu"], - st.session_state["_train_batch_size_per_gpu"], - ) - st.number_input( - "Micro Batch Size Per GPU :blue-badge[(Actor)]", - key="actor_ppo_micro_batch_size_per_gpu", - min_value=1, - max_value=st.session_state["_train_batch_size_per_gpu"], - ) - - def _set_ref_log_prob_micro_batch_size_per_gpu(self): - st.session_state["ref_log_prob_micro_batch_size_per_gpu"] = min( - st.session_state["ref_log_prob_micro_batch_size_per_gpu"], - st.session_state["_train_batch_size_per_gpu"], - ) - st.number_input( - "Micro Batch Size Per GPU :blue-badge[(Ref)]", - key="ref_log_prob_micro_batch_size_per_gpu", - min_value=1, - max_value=st.session_state["_train_batch_size_per_gpu"], - ) - - def _set_actor_ulysses_sequence_parallel_size(self): - st.number_input( - "Ulysses Sequence Parallel Size", - key="actor_ulysses_sequence_parallel_size", - min_value=1, - max_value=8, - ) - - def _set_actor_lr(self): - st.number_input( - "Learning Rate :blue-badge[(Actor)]", - key="actor_lr", - min_value=1e-7, - max_value=1e-3, - format="%.1e", - ) - - def _set_actor_warmup_style(self): - st.selectbox( - "LR Warmup Style :blue-badge[(Actor)]", - ["constant", "cosine"], - key="actor_warmup_style", - ) - - def _set_actor_lr_warmup_steps_ratio(self): - st.number_input( - "LR Warmup Steps Ratio :blue-badge[(Actor)]", - key="actor_lr_warmup_steps_ratio", - min_value=0.0, - 
max_value=1.0, - ) - - def _set_actor_grad_clip(self): - st.number_input( - "Grad Clip :blue-badge[(Actor)]", - key="actor_grad_clip", - min_value=0.0, - max_value=1.0, - help="Clipping by Norm", - ) - - def _set_actor_clip_ratio(self): - st.number_input( - r"Clip Ratio :blue-badge[$\epsilon$]", - key="actor_clip_ratio", - min_value=0.0, - max_value=1.0, - ) - - def _set_actor_entropy_coef(self): - st.number_input( - "Entropy Coeff", - key="actor_entropy_coef", - min_value=0.0, - max_value=1.0, - format="%.1e", - ) - - def _set_actor_use_kl_loss(self): - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - st.session_state["actor_use_kl_loss"] = True - else: - st.session_state["actor_use_kl_loss"] = st.session_state["_not_dpo_actor_use_kl_loss"] - - def on_change(): - st.session_state["_not_dpo_actor_use_kl_loss"] = st.session_state[ - "actor_use_kl_loss" - ] - - st.checkbox("Use KL Loss", key="actor_use_kl_loss", on_change=on_change) - - def _set_actor_kl_loss_coef(self): - st.number_input( - r"KL Loss Coef :blue-badge[$\beta$]", - key="actor_kl_loss_coef", - min_value=0.0, - max_value=1.0, - format="%.1e", - ) - - def _set_actor_kl_loss_type(self): - st.selectbox( - "KL Loss Type", - ["kl", "abs", "mse", "low_var_kl"], - key="actor_kl_loss_type", - ) - - def _set_actor_tau(self): - st.number_input( - "Tau for OPMD", - key="actor_tau", - min_value=0.0, - format="%.1e", - ) - - def _set_actor_opmd_baseline(self): - st.selectbox( - "OPMD Baseline", - ["mean", "logavgexp"], - key="actor_opmd_baseline", - ) - - def _set_actor_use_uid(self): - st.checkbox("Use UID for OPMD", key="actor_use_uid") - - def _set_actor_checkpoint(self): - st.multiselect( - "Checkpoint", - ["model", "hf_model", "optimizer", "extra"], - key="actor_checkpoint", - ) - - def _set_critic_ppo_micro_batch_size_per_gpu(self): - st.session_state["critic_ppo_micro_batch_size_per_gpu"] = min( - st.session_state["critic_ppo_micro_batch_size_per_gpu"], - 
st.session_state["_train_batch_size_per_gpu"], - ) - st.number_input( - "Micro Batch Size Per GPU :blue-badge[(Critic)]", - key="critic_ppo_micro_batch_size_per_gpu", - min_value=1, - max_value=st.session_state["_train_batch_size_per_gpu"], - ) - - def _set_critic_ulysses_sequence_parallel_size(self): - st.number_input( - "Ulysses Sequence Parallel Size", - key="critic_ulysses_sequence_parallel_size", - min_value=1, - max_value=8, - ) - - def _set_critic_lr(self): - st.number_input( - "Learning Rate :blue-badge[(Critic)]", - key="critic_lr", - min_value=1e-7, - max_value=1e-3, - format="%.1e", - ) - - def _set_critic_warmup_style(self): - st.selectbox( - "LR Warmup Style :blue-badge[(Critic)]", - ["constant", "cosine"], - key="critic_warmup_style", - ) - - def _set_critic_lr_warmup_steps_ratio(self): - st.number_input( - "LR Warmup Steps Ratio :blue-badge[(Critic)]", - key="critic_lr_warmup_steps_ratio", - min_value=0.0, - max_value=1.0, - ) - - def _set_critic_grad_clip(self): - st.number_input( - "Grad Clip :blue-badge[(Critic)]", - key="critic_grad_clip", - min_value=0.0, - max_value=1.0, - help="Clipping by Norm", - ) - - def _set_critic_cliprange_value(self): - st.number_input( - "Cliprange Value", - key="critic_cliprange_value", - min_value=0.0, - max_value=1.0, - ) - - def _set_critic_checkpoint(self): - st.multiselect( - "Checkpoint", - ["model", "hf_model", "optimizer", "extra"], - key="critic_checkpoint", - ) - - def _set_configs_with_st_columns( - self, config_names: List[str], columns_config: List[int] = None - ): - if columns_config is None: - columns_config = len(config_names) - columns = st.columns(columns_config) - for col, config_name in zip(columns, config_names): - with col: - getattr(self, f"_set_{config_name}")() + last_idx, del_num = 0, 0 + for idx in range(st.session_state["_auxiliary_models_num"]): + if st.session_state.get(f"auxiliary_model_{idx}_del_flag", False): + del_num += 1 + continue + for key in auxiliary_model_keys: + full_key = 
f"auxiliary_model_{idx}_{key}" + last_full_key = f"auxiliary_model_{last_idx}_{key}" + st.session_state[last_full_key] = st.session_state[full_key] + last_idx += 1 + st.session_state["_auxiliary_models_num"] -= del_num + + def get_configs(self, *config_names: str, columns_spec: List[int] = None): + CONFIG_GENERATORS.get_configs(*config_names, columns_spec=columns_spec) def beginner_mode(self): st.header("Essential Configs") - self._set_configs_with_st_columns(["project", "exp_name"], columns_config=[1, 3]) + self.get_configs("project", "exp_name", columns_spec=[1, 2]) - self._set_model_path() + self.get_configs("model_path") - self._set_checkpoint_root_dir() + self.get_configs("checkpoint_root_dir") - self._set_taskset_path() + if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: + self.get_configs("taskset_path") + else: + self.get_configs("experience_buffer_path") - self._set_configs_with_st_columns(["algorithm_type", "sft_warmup_steps", "monitor_type"]) + self.get_configs("algorithm_type", "sft_warmup_steps", "monitor_type") if st.session_state["sft_warmup_steps"] > 0: - self._set_sft_warmup_dataset_path() + self.get_configs("sft_warmup_dataset_path") st.header("Important Configs") - self._set_configs_with_st_columns( - ["node_num", "gpu_per_node", "engine_num", "tensor_parallel_size"] - if st.session_state["mode"] == "both" - else ["node_num", "gpu_per_node"] - ) - self._check_engine_num_and_tp_size() + self.get_configs("node_num", "gpu_per_node", "engine_num", "tensor_parallel_size") - self._set_configs_with_st_columns( - ["total_epochs", "train_batch_size", "ppo_epochs", "repeat_times"] - if st.session_state["mode"] == "both" - else ["total_epochs", "train_batch_size", "ppo_epochs"] - ) - self._check_train_batch_size() + self.get_configs("total_epochs", "train_batch_size", "ppo_epochs", "repeat_times") - self._set_configs_with_st_columns(["max_prompt_tokens", "max_response_tokens"]) + self.get_configs("storage_type", "max_prompt_tokens", 
"max_response_tokens") - self._set_configs_with_st_columns( - ["sync_interval", "eval_interval", "save_interval"] - if st.session_state["mode"] == "both" - else ["eval_interval", "save_interval"] - ) + self.get_configs("sync_interval", "eval_interval", "save_interval") if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: - self._set_taskset_args() + self.get_configs("taskset_args") else: - self._set_dpo_dataset_kwargs() + self.get_configs("dpo_dataset_kwargs") if st.session_state["sft_warmup_steps"] > 0: - self._set_sft_warmup_dataset_args() + self.get_configs("sft_warmup_dataset_args") - self._set_configs_with_st_columns(["default_workflow_type", "default_reward_fn_type"]) + self.get_configs("default_workflow_type", "default_reward_fn_type") - self._set_actor_use_kl_loss() - if st.session_state["actor_use_kl_loss"]: - self._set_configs_with_st_columns(["actor_kl_loss_coef", "actor_kl_loss_type"]) + self.get_configs("actor_use_kl_loss") + self.get_configs("actor_kl_loss_coef", "actor_kl_loss_type") - self._set_configs_with_st_columns( - [ - "actor_ppo_micro_batch_size_per_gpu", - "actor_lr", - "ref_log_prob_micro_batch_size_per_gpu", - ] + self.get_configs( + "actor_ppo_micro_batch_size_per_gpu", + "actor_lr", + "ref_log_prob_micro_batch_size_per_gpu", ) - use_critic = ( - st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value - ) # TODO: may apply to expert mode - if use_critic: - self._set_configs_with_st_columns(["critic_ppo_micro_batch_size_per_gpu", "critic_lr"]) + self.get_configs("critic_ppo_micro_batch_size_per_gpu", "critic_lr") def _expert_model_part(self): - self._set_configs_with_st_columns(["project", "exp_name"], columns_config=[1, 3]) + self.get_configs("project", "exp_name", columns_spec=[1, 2]) - self._set_model_path() - self._set_critic_model_path() + self.get_configs("model_path") + self.get_configs("critic_model_path") - self._set_checkpoint_root_dir() + self.get_configs("checkpoint_root_dir") - 
self._set_configs_with_st_columns(["monitor_type", "node_num", "gpu_per_node"]) - self._set_configs_with_st_columns(["max_prompt_tokens", "max_response_tokens"]) + self.get_configs("monitor_type", "node_num", "gpu_per_node") + self.get_configs("max_prompt_tokens", "max_response_tokens") def _expert_buffer_part(self): - self._set_configs_with_st_columns(["total_epochs", "train_batch_size"]) - self._check_train_batch_size() + self.get_configs("total_epochs", "train_batch_size") - self._set_configs_with_st_columns(["default_workflow_type", "default_reward_fn_type"]) - self._set_system_prompt() - self._set_reply_prefix() + self.get_configs("default_workflow_type", "default_reward_fn_type") + self.get_configs("system_prompt") + self.get_configs("reply_prefix") if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: with st.expander("Taskset Configs", expanded=True): - self._set_taskset_path() - self._set_taskset_args() + self.get_configs("taskset_path") + self.get_configs("taskset_args") else: with st.expander("DPO Dataset Configs", expanded=True): - self._set_experience_buffer_path() - self._set_dpo_dataset_kwargs() + self.get_configs("experience_buffer_path") + self.get_configs("storage_type") + self.get_configs("dpo_dataset_kwargs") with st.expander("Eval Tasksets Configs", expanded=True): - self._set_eval_tasksets() + self.get_configs("eval_tasksets") with st.expander("SFT Dataset Configs"): - self._set_sft_warmup_dataset_path() - self._set_sft_warmup_dataset_args() + self.get_configs("sft_warmup_dataset_path") + self.get_configs("sft_warmup_dataset_args") if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: with st.expander("Experiences Buffer Configs", expanded=True): - self._set_storage_type() - self._set_experience_buffer_path() + self.get_configs("storage_type") + self.get_configs("experience_buffer_path") self.buffer_advanced_tab = st.expander("Advanced Config") with self.buffer_advanced_tab: - 
self._set_configs_with_st_columns(["buffer_max_retry_times", "max_retry_interval"]) + self.get_configs("buffer_max_retry_times", "max_retry_interval") def _expert_explorer_part(self): - self._set_configs_with_st_columns(["sync_method", "sync_interval", "sync_timeout"]) - - self._set_configs_with_st_columns( - [ - "runner_num", - "max_timeout", - "explorer_max_retry_times", - ] - ) + self.get_configs("sync_method", "sync_interval", "sync_timeout") - self._set_configs_with_st_columns(["eval_interval", "eval_on_latest_checkpoint"]) + self.get_configs("runner_num", "max_timeout", "explorer_max_retry_times", "eval_interval") + + self.get_configs("eval_on_latest_checkpoint") with st.expander("Rollout Model Config", expanded=True): - self._set_configs_with_st_columns(["engine_type", "engine_num", "tensor_parallel_size"]) - self._check_engine_num_and_tp_size() + self.get_configs("engine_type", "engine_num", "tensor_parallel_size") - self._set_configs_with_st_columns(["gpu_memory_utilization", "dtype", "seed"]) + self.get_configs("gpu_memory_utilization", "dtype", "seed") - self._set_configs_with_st_columns( - ["use_v1", "enforce_eager", "enable_prefix_caching", "enable_chunked_prefill"] + self.get_configs( + "use_v1", "enforce_eager", "enable_prefix_caching", "enable_chunked_prefill" ) - self._set_configs_with_st_columns(["enable_thinking", "enable_openai_api"]) + self.get_configs("enable_thinking", "enable_openai_api") - with st.expander("Auxiliary Models", expanded=True): # TODO - pass + with st.expander("Auxiliary Models", expanded=True): + self.get_configs("auxiliary_models") def _expert_trainer_part(self): - self._set_configs_with_st_columns(["algorithm_type", "gamma", "lam"]) - self._set_configs_with_st_columns(["repeat_times", "save_interval"]) - self._check_sft_warmup_dataset_path() + self.get_configs("algorithm_type", "gamma", "lam") + self.get_configs("repeat_times", "save_interval") + self.get_configs("enable_preview") if st.session_state["trainer_type"] == 
"verl": self._expert_verl_trainer_part() - def _expert_verl_trainer_part(self): - rl_training_tab, rl_algorithm_tab, actor_ref_tab, critic_tab = st.tabs( - [ - "RL Training Config", - "RL Algorithm Config", - "Actor and Ref Config", - "Critic Config", - ] - ) - with rl_training_tab: - st.subheader("RL Training Config") - self._set_training_args() + def _expert_verl_training_part(self): + st.subheader("RL Training Config") + self.get_configs("training_args") - self._set_configs_with_st_columns(["ppo_epochs", "training_strategy", "resume_mode"]) + self.get_configs("ppo_epochs", "training_strategy", "resume_mode") - if st.session_state["training_strategy"] == "fsdp": - self._set_configs_with_st_columns(["param_offload", "optimizer_offload"]) - self._set_resume_from_path() + self.get_configs("param_offload", "optimizer_offload") + self.get_configs("resume_from_path") - with st.expander("Advanced Config"): - self._set_configs_with_st_columns(["critic_warmup", "total_training_steps"]) + with st.expander("Advanced Config"): + self.get_configs("critic_warmup", "total_training_steps") - self._set_default_hdfs_dir() + self.get_configs("default_hdfs_dir") - self._set_configs_with_st_columns( - ["remove_previous_ckpt_in_save", "del_local_ckpt_after_load"] - ) + self.get_configs("remove_previous_ckpt_in_save", "del_local_ckpt_after_load") - self._set_configs_with_st_columns( - ["max_actor_ckpt_to_keep", "max_critic_ckpt_to_keep"] - ) + self.get_configs("max_actor_ckpt_to_keep", "max_critic_ckpt_to_keep") - with rl_algorithm_tab: - st.subheader("RL Algorithm Config") - self._set_configs_with_st_columns(["norm_adv_by_std_in_grpo", "use_kl_in_reward"]) - self._set_configs_with_st_columns(["kl_penalty", "kl_ctrl_type", "kl_ctrl_coef"]) - self._set_configs_with_st_columns(["horizon", "target_kl"]) + def _expert_verl_algorithm_part(self): + st.subheader("RL Algorithm Config") + self.get_configs("norm_adv_by_std_in_grpo", "use_kl_in_reward") + self.get_configs("kl_penalty", 
"kl_ctrl_type", "kl_ctrl_coef") + self.get_configs("horizon", "target_kl") - with actor_ref_tab: - st.subheader("Actor Model Config") - self._set_configs_with_st_columns( - [ - "actor_ppo_micro_batch_size_per_gpu", - "ref_log_prob_micro_batch_size_per_gpu", - "actor_ulysses_sequence_parallel_size", - ] - ) + def _expert_verl_actor_part(self): + st.subheader("Actor Model Config") + self.get_configs( + "actor_ppo_micro_batch_size_per_gpu", + "ref_log_prob_micro_batch_size_per_gpu", + "actor_ulysses_sequence_parallel_size", + ) - self._set_configs_with_st_columns( - ["actor_lr", "actor_warmup_style", "actor_lr_warmup_steps_ratio"] - ) + self.get_configs("actor_lr", "actor_warmup_style", "actor_lr_warmup_steps_ratio") - self._set_configs_with_st_columns( - ["actor_grad_clip", "actor_clip_ratio", "actor_entropy_coef"] - ) + self.get_configs("actor_grad_clip", "actor_clip_ratio", "actor_entropy_coef") - self._set_actor_use_kl_loss() - if st.session_state["actor_use_kl_loss"]: - self._set_configs_with_st_columns(["actor_kl_loss_coef", "actor_kl_loss_type"]) + self.get_configs("actor_use_kl_loss", "actor_use_uid") + self.get_configs("actor_kl_loss_coef", "actor_kl_loss_type") - if st.session_state["algorithm_type"] == "opmd": - self._set_configs_with_st_columns( - ["actor_tau", "actor_opmd_baseline", "actor_use_uid"] - ) + self.get_configs("actor_tau", "actor_opmd_baseline") - self._set_actor_checkpoint() + self.get_configs("actor_checkpoint") - with critic_tab: - st.subheader("Critic Model Config") - self._set_configs_with_st_columns( - ["critic_ppo_micro_batch_size_per_gpu", "critic_ulysses_sequence_parallel_size"] - ) + def _expert_verl_critic_part(self): + st.subheader("Critic Model Config") + self.get_configs( + "critic_ppo_micro_batch_size_per_gpu", "critic_ulysses_sequence_parallel_size" + ) - self._set_configs_with_st_columns( - ["critic_lr", "critic_warmup_style", "critic_lr_warmup_steps_ratio"] - ) + self.get_configs("critic_lr", "critic_warmup_style", 
"critic_lr_warmup_steps_ratio") + + self.get_configs("critic_grad_clip", "critic_cliprange_value") + self.get_configs("critic_checkpoint") + + def _expert_verl_trainer_part(self): + name2func = { + "RL Training Config": self._expert_verl_training_part, + "RL Algorithm Config": self._expert_verl_algorithm_part, + "Actor and Ref Config": self._expert_verl_actor_part, + } + if use_critic(): + name2func["Critic Config"] = self._expert_verl_critic_part - self._set_configs_with_st_columns(["critic_grad_clip", "critic_cliprange_value"]) - self._set_critic_checkpoint() + tabs = st.tabs([name for name in name2func]) + for tab, func in zip(tabs, name2func.values()): + with tab: + func() def expert_mode(self): tab2func = { @@ -1455,7 +385,6 @@ def _generate_verl_config(self): }, "trainer": { "balance_batch": balance_batch, - "logger": ["tensorboard"], "resume_mode": st.session_state["resume_mode"], "resume_from_path": st.session_state["resume_from_path"], "default_hdfs_dir": st.session_state["default_hdfs_dir"], @@ -1467,7 +396,7 @@ def _generate_verl_config(self): }, } - if st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value: + if use_critic(): trainer_config["trainer"]["critic_warmup"] = st.session_state["critic_warmup"] trainer_config["critic"] = { "strategy": st.session_state["training_strategy"], @@ -1510,8 +439,8 @@ def _generate_verl_config(self): return trainer_config def _gen_buffer_config(self): + experience_buffer_path = st.session_state["experience_buffer_path"].strip() if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: - experience_buffer_path = st.session_state["experience_buffer_path"].strip() if ( not experience_buffer_path and st.session_state["storage_type"] == StorageType.SQL.value @@ -1527,7 +456,20 @@ def _gen_buffer_config(self): buffer_config = { "batch_size": st.session_state["train_batch_size"], "total_epochs": st.session_state["total_epochs"], - "explorer_input": { + "trainer_input": { + "experience_buffer": { + "name": 
"experience_buffer", + "storage_type": st.session_state["storage_type"], + "path": experience_buffer_path, + }, + "sft_warmup_steps": st.session_state["sft_warmup_steps"], + }, + "max_retry_times": st.session_state["buffer_max_retry_times"], + "max_retry_interval": st.session_state["max_retry_interval"], + } + + if st.session_state["mode"] != "train": + buffer_config["explorer_input"] = { "taskset": { "name": "taskset", "storage_type": StorageType.FILE.value, @@ -1548,31 +490,19 @@ def _gen_buffer_config(self): "default_reward_fn_type": st.session_state["default_reward_fn_type"], "system_prompt": st.session_state["system_prompt"], "reply_prefix": st.session_state["reply_prefix"], - }, - "trainer_input": { - "experience_buffer": { - "name": "experience_buffer", - "storage_type": st.session_state["storage_type"], - "path": experience_buffer_path, - }, - "sft_warmup_steps": st.session_state["sft_warmup_steps"], - }, - "max_retry_times": st.session_state["buffer_max_retry_times"], - "max_retry_interval": st.session_state["max_retry_interval"], - } - - for idx in range(st.session_state["_eval_tasksets_num"]): - if st.session_state[f"eval_taskset_{idx}_path"].strip(): - buffer_config["explorer_input"]["eval_tasksets"].append( - { - "name": st.session_state[f"eval_taskset_{idx}_name"], - "path": st.session_state[f"eval_taskset_{idx}_path"], - "subset_name": st.session_state[f"eval_taskset_{idx}_subset_name"], - "split": st.session_state[f"eval_taskset_{idx}_split"], - "prompt_key": st.session_state[f"eval_taskset_{idx}_prompt_key"], - "response_key": st.session_state[f"eval_taskset_{idx}_response_key"], - } - ) + } + for idx in range(st.session_state["_eval_tasksets_num"]): + if st.session_state[f"eval_taskset_{idx}_path"].strip(): + buffer_config["explorer_input"]["eval_tasksets"].append( + { + "name": st.session_state[f"eval_taskset_{idx}_name"], + "path": st.session_state[f"eval_taskset_{idx}_path"], + "subset_name": st.session_state[f"eval_taskset_{idx}_subset_name"], 
+ "split": st.session_state[f"eval_taskset_{idx}_split"], + "prompt_key": st.session_state[f"eval_taskset_{idx}_prompt_key"], + "response_key": st.session_state[f"eval_taskset_{idx}_response_key"], + } + ) if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: experience_buffer = buffer_config["trainer_input"]["experience_buffer"] experience_buffer["split"] = st.session_state["dpo_dataset_train_split"] @@ -1676,7 +606,7 @@ def generate_config(self): "trainer": { "trainer_type": st.session_state["trainer_type"], "save_interval": st.session_state["save_interval"], - "enable_preview": True, # TODO + "enable_preview": st.session_state["enable_preview"], "actor_use_kl_loss": st.session_state["actor_use_kl_loss"], "actor_kl_loss_coef": st.session_state["actor_kl_loss_coef"], "actor_entropy_coef": st.session_state["actor_entropy_coef"], @@ -1694,7 +624,7 @@ def generate_config(self): }, } - if st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value: + if use_critic(): config["model"]["critic_model_path"] = ( st.session_state["critic_model_path"].strip() if st.session_state["critic_model_path"].strip() diff --git a/trinity/manager/config_registry/__init__.py b/trinity/manager/config_registry/__init__.py new file mode 100644 index 0000000000..e62c565fb4 --- /dev/null +++ b/trinity/manager/config_registry/__init__.py @@ -0,0 +1,13 @@ +import trinity.manager.config_registry.buffer_config_manager as buffer_config_manager +import trinity.manager.config_registry.explorer_config_manager as explorer_config_manager +import trinity.manager.config_registry.model_config_manager as model_config_manager +import trinity.manager.config_registry.trainer_config_manager as trainer_config_manager +from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS + +__all__ = [ + "CONFIG_GENERATORS", + "buffer_config_manager", + "explorer_config_manager", + "model_config_manager", + "trainer_config_manager", +] diff --git 
a/trinity/manager/config_registry/buffer_config_manager.py b/trinity/manager/config_registry/buffer_config_manager.py new file mode 100644 index 0000000000..044f982e94 --- /dev/null +++ b/trinity/manager/config_registry/buffer_config_manager.py @@ -0,0 +1,433 @@ +import streamlit as st + +from trinity.common.constants import AlgorithmType, PromptType, StorageType +from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS +from trinity.common.workflows.workflow import WORKFLOWS +from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS + + +@CONFIG_GENERATORS.register_config(default_value=20) +def set_total_epochs(**kwargs): + st.number_input("Total Epochs", min_value=1, **kwargs) + + +def _str_for_train_batch_size(): + trainer_gpu_num_str = ( + "`gpu_per_node * node_num - engine_num * tensor_parallel_size`" + if st.session_state["mode"] == "both" + else "`gpu_per_node * node_num`" + ) + return ( + f"Please ensure that `train_batch_size` can be divided by " + f"{trainer_gpu_num_str} = {st.session_state['trainer_gpu_num']}." 
+ ) + + +@CONFIG_GENERATORS.register_config( + default_value=96, + visible=lambda: st.session_state["trainer_gpu_num"] > 0, + other_configs={"_train_batch_size_per_gpu": 16}, +) +def set_train_batch_size(**kwargs): + key = kwargs.get("key") + trainer_gpu_num = st.session_state["trainer_gpu_num"] + st.session_state[key] = ( + st.session_state["_train_batch_size_per_gpu"] * st.session_state["trainer_gpu_num"] + ) + + def on_change(): + st.session_state["_train_batch_size_per_gpu"] = max( + st.session_state[key] // st.session_state["trainer_gpu_num"], 1 + ) + + st.number_input( + "Train Batch Size", + min_value=trainer_gpu_num, + step=trainer_gpu_num, + help=_str_for_train_batch_size(), + on_change=on_change, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_check() +def check_train_batch_size(unfinished_fields: set, key: str): + if st.session_state[key] % st.session_state["trainer_gpu_num"] != 0: + unfinished_fields.add(key) + st.warning(_str_for_train_batch_size()) + + +@CONFIG_GENERATORS.register_config(default_value=3) +def set_buffer_max_retry_times(**kwargs): + st.number_input("Max Retry Times", min_value=1, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=1) +def set_max_retry_interval(**kwargs): + st.number_input("Max Retry Interval", min_value=1, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value="") +def set_taskset_path(**kwargs): + st.text_input("Taskset Path", **kwargs) + + +@CONFIG_GENERATORS.register_check() +def check_taskset_path(unfinished_fields: set, key: str): + if not st.session_state[key].strip(): + unfinished_fields.add(key) + st.warning("Please input taskset path.") + + +# def _set_temperature(self): +# st.number_input("Temperature", key="temperature", min_value=0.0, max_value=2.0) + +# def _set_top_p(self): +# st.number_input("Top-p", key="top_p", min_value=0.0, max_value=1.0) + +# def _set_top_k(self): +# st.number_input( +# "Top-k", +# key="top_k", +# min_value=-1, +# max_value=512, +# help="Integer that controls 
the number of top tokens to consider. Set to -1 to consider all tokens.", +# ) + +# def _set_logprobs(self): +# st.number_input("Logprobs", key="logprobs", min_value=0, max_value=20) + + +@CONFIG_GENERATORS.register_config( + visible=lambda: st.session_state["taskset_path"] + and "://" not in st.session_state["taskset_path"], + other_configs={ + "taskset_subset_name": None, + "taskset_split": "train", + "taskset_prompt_key": "question", + "taskset_response_key": "answer", + "temperature": 1.0, + "top_p": 1.0, # TODO: to be used + "top_k": -1, # TODO: to be used + "logprobs": 0, + }, +) +def set_taskset_args(**kwargs): + subset_name_col, split_col = st.columns(2) + subset_name_col.text_input( + "Subset Name :orange-badge[(Needs review)]", + key="taskset_subset_name", + help="The subset name used for `datasets.load_dataset`, see " + "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.", + ) + split_col.text_input("Train Split :orange-badge[(Needs review)]", key="taskset_split") + prompt_key_col, response_key_col = st.columns(2) + prompt_key_col.text_input("Prompt Key :orange-badge[(Needs review)]", key="taskset_prompt_key") + response_key_col.text_input( + "Response Key :orange-badge[(Needs review)]", key="taskset_response_key" + ) + # self._set_configs_with_st_columns(["temperature", "logprobs"]) + temperature_col, logprobs_col = st.columns(2) + temperature_col.number_input("Temperature", key="temperature", min_value=0.0, max_value=2.0) + logprobs_col.number_input("Logprobs", key="logprobs", min_value=0, max_value=20) + + +def _set_eval_taskset_idx(idx): + col1, col2 = st.columns([9, 1]) + col1.text_input( + "Taskset Name", + key=f"eval_taskset_{idx}_name", + ) + if col2.button("✖️", key=f"eval_taskset_{idx}_del_flag", type="primary"): + st.rerun() + st.text_input( + "Eval Taskset Path", + key=f"eval_taskset_{idx}_path", + ) + if not st.session_state[f"eval_taskset_{idx}_path"].strip(): + 
st.warning("Please input the taskset path, or it will be ignored.") + subset_name_col, split_col = st.columns(2) + subset_name_col.text_input( + "Subset Name :orange-badge[(Needs review)]", + key=f"eval_taskset_{idx}_subset_name", + help="The subset name used for `datasets.load_dataset`, see " + "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.", + ) + split_col.text_input( + "Eval Split :orange-badge[(Needs review)]", + key=f"eval_taskset_{idx}_split", + ) + prompt_key_col, response_key_col = st.columns(2) + prompt_key_col.text_input( + "Prompt Key :orange-badge[(Needs review)]", + key=f"eval_taskset_{idx}_prompt_key", + ) + response_key_col.text_input( + "Response Key :orange-badge[(Needs review)]", + key=f"eval_taskset_{idx}_response_key", + ) + + temperature_col, logprobs_col, n_col = st.columns(3) + temperature_col.number_input( + "Temperature", + key=f"eval_taskset_{idx}_temperature", + min_value=0.0, + max_value=1.0, + ) + logprobs_col.number_input( + "Logprobs", + key=f"eval_taskset_{idx}_logprobs", + min_value=0, + max_value=20, + ) + n_col.number_input( + "Eval repeat times", + key=f"eval_taskset_{idx}_n", + min_value=1, + max_value=20, + ) + + +@CONFIG_GENERATORS.register_config(other_configs={"_eval_tasksets_num": 0}) +def set_eval_tasksets(**kwargs): + if st.button("Add Eval Taskset"): + idx = st.session_state["_eval_tasksets_num"] + st.session_state[f"eval_taskset_{idx}_split"] = "test" + st.session_state[f"eval_taskset_{idx}_prompt_key"] = "prompt" + st.session_state[f"eval_taskset_{idx}_response_key"] = "response" + st.session_state[f"eval_taskset_{idx}_temperature"] = 0.1 + st.session_state["_eval_tasksets_num"] += 1 + if st.session_state["_eval_tasksets_num"] > 0: + tabs = st.tabs( + [f"Eval Taskset {i + 1}" for i in range(st.session_state["_eval_tasksets_num"])] + ) + for idx, tab in enumerate(tabs): + with tab: + _set_eval_taskset_idx(idx) + + 
+@CONFIG_GENERATORS.register_config(default_value="math_workflow") +def set_default_workflow_type(**kwargs): + st.selectbox( + "Default Workflow Type :orange-badge[(Needs review)]", + WORKFLOWS.modules.keys(), + help=r"""`simple_workflow`: call 'model.chat()' to get responses. + +`math_workflow`: call 'model.chat()' with a pre-defined system prompt to get responses. + +Other workflows: conduct multi-turn task for the given dataset. +""", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value="math_reward") +def set_default_reward_fn_type(**kwargs): + st.selectbox( + "Default Reward Fn Type :orange-badge[(Needs review)]", + REWARD_FUNCTIONS.modules.keys(), + help=r"""`accuracy_reward`: check the accuracy for math problems. + +`format_reward`: check if the response matches the format (default: `** *`). + +`math_reward`: `accuracy_reward` (1 or 0) + `format_reward` (+0.1 or -0.1). +""", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=None) +def set_system_prompt(**kwargs): + st.text_area( + "System Prompt", + placeholder="System prompt is used to guide the model behavior.", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=None) +def set_reply_prefix(**kwargs): + st.text_area( + "Assistant Reply Prefix", + placeholder="""Assistant reply prefix is used to specify the initial content of model reply, """ + """and a common setting is: \nLet me solve this step by step. 
""", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=StorageType.QUEUE.value, + other_configs={ + "_dpo_storage_type": StorageType.FILE.value, + "_not_dpo_storage_type": StorageType.QUEUE.value, + }, +) +def set_storage_type(**kwargs): + key = kwargs.get("key") + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state[key] = st.session_state["_dpo_storage_type"] + storage_candidates = [StorageType.FILE.value, StorageType.SQL.value] + else: + st.session_state[key] = st.session_state["_not_dpo_storage_type"] + storage_candidates = [StorageType.QUEUE.value, StorageType.SQL.value] + + def on_change(): + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state["_dpo_storage_type"] = st.session_state[key] + else: + st.session_state["_not_dpo_storage_type"] = st.session_state[key] + + st.selectbox( + "Storage Type", + storage_candidates, + on_change=on_change, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value="", + other_configs={ + "_dpo_experience_buffer_path": "", + "_not_dpo_experience_buffer_path": "", + }, +) +def set_experience_buffer_path(**kwargs): # TODO + key = kwargs.get("key") + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + if st.session_state["taskset_path"] and not st.session_state["_dpo_experience_buffer_path"]: + st.session_state["_dpo_experience_buffer_path"] = st.session_state["taskset_path"] + st.session_state[key] = st.session_state["_dpo_experience_buffer_path"] + title = "DPO Dataset Path" + help_msg = r"""This is the path to the DPO dataset, + +if `storage_type == StorageType.FILE`, this should be a path to a file, + +if `storage_type == StorageType.SQL`, this should be a path to a database.""" + else: + st.session_state[key] = st.session_state["_not_dpo_experience_buffer_path"] + title = "Experience Buffer Path" + help_msg = r"""This path is used for `trainer`, + +if `storage_type == StorageType.QUEUE`, default to `None`, + +if 
`storage_type == StorageType.SQL`, default to `sqlite:///{os.path.join(checkpoint_root_dir, '.cache', project_name, experiment_name)}/data.db`.""" + + def on_change(): + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state["_dpo_experience_buffer_path"] = st.session_state[key] + else: + st.session_state["_not_dpo_experience_buffer_path"] = st.session_state[key] + + st.text_input(title, help=help_msg, on_change=on_change, **kwargs) + + +@CONFIG_GENERATORS.register_check() +def check_experience_buffer_path(unfinished_fields: set, key: str): + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + if not st.session_state[key].strip(): + unfinished_fields.add(key) + st.warning("Please input DPO dataset path.") + + +@CONFIG_GENERATORS.register_config( + other_configs={ + "dpo_dataset_train_split": "train", + "dpo_dataset_prompt_type": PromptType.MESSAGES.value, + "dpo_dataset_prompt_key": "prompt", + "dpo_dataset_chosen_key": "chosen", + "dpo_dataset_rejected_key": "rejected", + } +) +def set_dpo_dataset_kwargs(**kwargs): + dpo_dataset_train_split_col, dpo_dataset_prompt_type_col = st.columns(2) + dpo_dataset_train_split_col.text_input( + "DPO Dataset Train Split :orange-badge[(Needs review)]", key="dpo_dataset_train_split" + ) + dpo_dataset_prompt_type_col.selectbox( + "DPO Dataset Prompt Type :orange-badge[(Needs review)]", + [prompt_type.value for prompt_type in PromptType], + key="dpo_dataset_prompt_type", + ) + + ( + dpo_dataset_prompt_key_col, + dpo_dataset_chosen_key_col, + dpo_dataset_rejected_key_col, + ) = st.columns(3) + dpo_dataset_prompt_key_col.text_input( + "DPO Dataset Prompt Key :orange-badge[(Needs review)]", key="dpo_dataset_prompt_key" + ) + dpo_dataset_chosen_key_col.text_input( + "DPO Dataset Chosen Key :orange-badge[(Needs review)]", key="dpo_dataset_chosen_key" + ) + dpo_dataset_rejected_key_col.text_input( + "DPO Dataset Rejected Key :orange-badge[(Needs review)]", + key="dpo_dataset_rejected_key", + 
) + + +@CONFIG_GENERATORS.register_config(default_value="") +def set_sft_warmup_dataset_path(**kwargs): + st.text_input("SFT Warmup Dataset Path", **kwargs) + + +@CONFIG_GENERATORS.register_check() +def check_sft_warmup_dataset_path(unfinished_fields: set, key: str): + if st.session_state["sft_warmup_steps"]: + if not st.session_state[key].strip(): + unfinished_fields.add(key) + st.warning("Please input SFT warmup dataset path when `sft_warmup_steps` is not 0") + + +@CONFIG_GENERATORS.register_config( + visible=lambda: st.session_state["sft_warmup_dataset_path"] + and "://" not in st.session_state["sft_warmup_dataset_path"], + other_configs={ + "sft_warmup_train_split": "train", + "sft_warmup_prompt_type": PromptType.MESSAGES.value, + "sft_warmup_messages_key": "messages", + "sft_warmup_prompt_key": "prompt", + "sft_warmup_response_key": "response", + }, +) +def set_sft_warmup_dataset_args(**kwargs): + ( + sft_warmup_train_split_col, + sft_warmup_prompt_type_col, + ) = st.columns(2) + sft_warmup_train_split_col.text_input( + "SFT Dataset Train Split :orange-badge[(Needs review)]", + key="sft_warmup_train_split", + ) + sft_warmup_prompt_type_col.selectbox( + "SFT Dataset Prompt Type :orange-badge[(Needs review)]", + [prompt_type.value for prompt_type in PromptType], + key="sft_warmup_prompt_type", + ) + ( + sft_warmup_messages_key_col, + sft_warmup_prompt_key_col, + sft_warmup_response_key_col, + ) = st.columns( + 3 + ) # TODO: select by prompt type + sft_warmup_messages_key_col.text_input( + "SFT Dataset Messages Key :orange-badge[(Needs review)]", + key="sft_warmup_messages_key", + ) + sft_warmup_prompt_key_col.text_input( + "SFT Dataset Prompt Key :orange-badge[(Needs review)]", key="sft_warmup_prompt_key" + ) + sft_warmup_response_key_col.text_input( + "SFT Dataset Response Key :orange-badge[(Needs review)]", + key="sft_warmup_response_key", + ) + + +# TODO: read_experience_strategy + + +@CONFIG_GENERATORS.register_config(default_value=0) +def 
set_sft_warmup_steps(**kwargs): + st.number_input("SFT Warmup Steps", min_value=0, **kwargs) diff --git a/trinity/manager/config_registry/config_registry.py b/trinity/manager/config_registry/config_registry.py new file mode 100644 index 0000000000..3b621a2de2 --- /dev/null +++ b/trinity/manager/config_registry/config_registry.py @@ -0,0 +1,209 @@ +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Set + +import streamlit as st + +from trinity.utils.registry import Registry + + +class ConfigRegistry(Registry): + """ + A registry for managing configuration settings and their associated functions. + """ + + def __init__(self, name: str): + super().__init__(name) + self._default_config = {} # Stores default values for configs + self._config_visibles = {} # Stores visibility predicates for configs + self.unfinished_fields = set() + + def set_unfinished_fields(self, unfinished_fields: set): + """ + Set the unfinished fields to track incomplete configurations. + + Args: + unfinished_fields (set): Set of field names that are not yet configured. + """ + self.unfinished_fields = unfinished_fields + + @property + def default_config(self) -> dict: + """ + Get the dictionary of default configuration values. + """ + return self._default_config + + def get(self, config_name: str): + """ + Retrieve a configuration function if its visibility condition (if any) is met. + + Args: + config_name (str): Name of the configuration to retrieve. + + Returns: + The configuration function if its visibility condition is met, else None. + """ + if config_name in self._config_visibles: + if not self._config_visibles[config_name](): + return None + return super().get(config_name) + + def get_check_func(self, config_name: str): + """ + Get the check function associated with a configuration. + + Args: + config_name (str): Name of the configuration. + + Returns: + The check function for the specified configuration. 
+ """ + check_func_name = f"check_{config_name}" + return super().get(check_func_name) + + def get_configs(self, *config_names: str, columns_spec: List[int] = None): + """ + Retrieve and display multiple configurations in Streamlit columns. + + Args: + *config_names (str): Names of configurations to retrieve. + columns_spec (List[int], optional): Configuration for Streamlit columns. + """ + config_pair = [] + for config_name in config_names: + config_func = self.get(config_name) + if config_func is not None: + config_pair.append((config_name, config_func)) + if len(config_pair) == 0: + return + + if columns_spec is None: + columns_spec = len(config_pair) + columns = st.columns(columns_spec) + for col, (_, config_func) in zip(columns, config_pair): + with col: + config_func() + for config_name, _ in config_pair: + check_func = self.get_check_func(config_name) + if check_func is not None: + check_func(unfinished_fields=self.unfinished_fields) + + def _register_config( + self, + config_name: str, + config_func: Callable[[None], None], + default_value: Optional[Any] = None, + visible: Optional[Callable[[], bool]] = None, + other_configs: Optional[Dict[str, Any]] = None, + ): + """ + Internal method to register a configuration and its associated function. + + Args: + config_name (str): Name of the configuration. + config_func (Callable): Function to set the configuration. + default_value (Any, optional): Default value for the configuration. + visible (Callable, optional): visible for when the config should be visible/applicable. + other_configs (Dict[str, Any], optional): Additional configurations to register. + """ + assert config_name not in self._default_config, f"{config_name} already exists." + self._default_config[config_name] = default_value + if visible is not None: + self._config_visibles[config_name] = visible + if other_configs is not None: + for name, value in other_configs.items(): + assert name not in self._default_config, f"{name} already exists." 
+ self._default_config[name] = value + super()._register_module(module_name=config_name, module_cls=config_func) + + def register_config( + self, + default_value: Optional[Any] = None, + config_func: Optional[Callable[[None], None]] = None, + visible: Optional[Callable[[], bool]] = None, + other_configs: Optional[Dict[str, Any]] = None, + ): + """ + Decorator to register a configuration function. + + The function name must start with 'set_', and the part after 'set_' becomes the config name. + + Note: This function will automatically pass `key=config_name` as an argument to the + registered configuration function. Ensure your function accepts this keyword argument. + + Args: + default_value (Any, optional): Default value for the configuration. + config_func (Callable, optional): The configuration function to register. + visible (Callable, optional): visible for when the config should be visible. + other_configs (Dict[str, Any], optional): Additional configurations to register. + + Returns: + A decorator function if config_func is None, else the registered config function. + """ + + # if config_func is None, should return a decorator function + def _register(config_func: Callable[[None], None]): + config_name = config_func.__name__ + prefix = "set_" + assert config_name.startswith( + prefix + ), f"Config function name should start with `{prefix}`, got {config_name}" + config_name = config_name[len(prefix) :] + config_func = partial(config_func, key=config_name) + self._register_config( + config_name=config_name, + config_func=config_func, + default_value=default_value, + visible=visible, + other_configs=other_configs, + ) + return config_func + + if config_func is not None: + return _register(config_func) + return _register + + def _register_check(self, config_name: str, check_func: Callable[[Set, str], None]): + """ + Internal method to register a check function for a configuration. + + Args: + config_name (str): Name of the configuration to check. 
+ check_func (Callable): Function to check the configuration. + """ + assert config_name in self._default_config, f"`{config_name}` is not registered." + super()._register_module(module_name=f"check_{config_name}", module_cls=check_func) + + def register_check(self, check_func: Callable[[Set, str], None] = None): + """ + Decorator to register a check function for a configuration. + + The function name must start with 'check_', and the part after 'check_' should match a config name. + + Note: This function will automatically pass `key=config_name` and `unfinished_fields=self.unfinished_fields` as arguments to the registered check function. Ensure your function accepts these keyword arguments. + + Args: + check_func (Callable, optional): The check function to register. + + Returns: + A decorator function if check_func is None, else the registered check function. + """ + + def _register(check_func: Callable[[Set, str], None]): + config_name = check_func.__name__ + prefix = "check_" + assert config_name.startswith( + prefix + ), f"Check function name must start with `{prefix}`, got {config_name}" + config_name = config_name[len(prefix) :] + check_func = partial(check_func, key=config_name) + self._register_check(config_name, check_func) + return check_func + + if check_func is not None: + return _register(check_func) + return _register + + +# Global registry for configuration generators +CONFIG_GENERATORS = ConfigRegistry("config_generators") diff --git a/trinity/manager/config_registry/explorer_config_manager.py b/trinity/manager/config_registry/explorer_config_manager.py new file mode 100644 index 0000000000..9393187f60 --- /dev/null +++ b/trinity/manager/config_registry/explorer_config_manager.py @@ -0,0 +1,298 @@ +import streamlit as st + +from trinity.common.constants import AlgorithmType, SyncMethod +from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS +from trinity.manager.config_registry.model_config_manager import set_trainer_gpu_num 
+ + +def explorer_visible() -> bool: + return st.session_state["mode"] == "both" + + +@CONFIG_GENERATORS.register_config(default_value=32, visible=explorer_visible) +def set_runner_num(**kwargs): + st.number_input("Runner Num", min_value=1, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=900, visible=explorer_visible) +def set_max_timeout(**kwargs): + st.number_input("Max Timeout", min_value=0, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=2, visible=explorer_visible) +def set_explorer_max_retry_times(**kwargs): + st.number_input("Explorer Max Retry Times", min_value=0, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=1000, visible=explorer_visible) +def set_eval_interval(**kwargs): + st.number_input("Eval Interval", min_value=1, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=True, visible=explorer_visible) +def set_eval_on_latest_checkpoint(**kwargs): + st.checkbox("Eval on Latest Checkpoint", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value="vllm_async", visible=explorer_visible) +def set_engine_type(**kwargs): + st.selectbox("Engine Type", ["vllm_async", "vllm"], **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=2, visible=explorer_visible) +def set_engine_num(**kwargs): + key = kwargs.get("key") + total_gpu_num = st.session_state["total_gpu_num"] + max_engine_num = (total_gpu_num - 1) // st.session_state["tensor_parallel_size"] + if st.session_state[key] > max_engine_num: + st.session_state[key] = max_engine_num + set_trainer_gpu_num() + st.number_input( + "Engine Num", + min_value=1, + max_value=max_engine_num, + on_change=set_trainer_gpu_num, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=1, visible=explorer_visible) +def set_tensor_parallel_size(**kwargs): + key = kwargs.get("key") + total_gpu_num = st.session_state["total_gpu_num"] + max_tensor_parallel_size = (total_gpu_num - 1) // st.session_state["engine_num"] + if st.session_state[key] > 
max_tensor_parallel_size: + st.session_state[key] = max_tensor_parallel_size + set_trainer_gpu_num() + st.number_input( + "Tensor Parallel Size", + min_value=1, + max_value=max_tensor_parallel_size, + on_change=set_trainer_gpu_num, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_check() +def check_tensor_parallel_size(unfinished_fields: set, key: str): + if st.session_state["trainer_gpu_num"] <= 0: + unfinished_fields.add("engine_num") + unfinished_fields.add("tensor_parallel_size") + st.warning( + "Please check the settings of each `engine_num` and `tensor_parallel_size` to ensure that at least one GPU is reserved for the `trainer`." + ) + elif ( + st.session_state["node_num"] > 1 + and st.session_state["trainer_gpu_num"] % st.session_state["gpu_per_node"] != 0 + ): + unfinished_fields.add("engine_num") + unfinished_fields.add("tensor_parallel_size") + st.warning( + "When `node_num > 1`, please check the settings of each `engine_num` and `tensor_parallel_size` to ensure that the number of GPUs reserved for the `trainer` is divisible by `gpu_per_node`" + ) + + +@CONFIG_GENERATORS.register_config(default_value=True, visible=explorer_visible) +def set_use_v1(**kwargs): + st.checkbox("Use V1 Engine", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=True, visible=explorer_visible) +def set_enforce_eager(**kwargs): + st.checkbox("Enforce Eager", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=False, visible=explorer_visible) +def set_enable_prefix_caching(**kwargs): + st.checkbox("Prefix Caching", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=False, visible=explorer_visible) +def set_enable_chunked_prefill(**kwargs): + st.checkbox("Chunked Prefill", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=0.9, visible=explorer_visible) +def set_gpu_memory_utilization(**kwargs): + st.number_input("GPU Memory Utilization", min_value=0.0, max_value=1.0, **kwargs) + + 
+@CONFIG_GENERATORS.register_config(default_value="bfloat16", visible=explorer_visible) +def set_dtype(**kwargs): + st.selectbox("Dtype", ["bfloat16", "float16", "float32"], **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=42, visible=explorer_visible) +def set_seed(**kwargs): + st.number_input("Seed", step=1, **kwargs) + + +# TODO: max_prompt_tokens +# TODO: max_response_tokens +# TODO: chat_template + + +@CONFIG_GENERATORS.register_config(default_value=False, visible=explorer_visible) +def set_enable_thinking(**kwargs): + st.checkbox("Enable Thinking For Qwen3", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=False, visible=explorer_visible) +def set_enable_openai_api(**kwargs): + st.checkbox("Enable OpenAI API", **kwargs) + + +def _set_auxiliary_model_idx(idx): + col1, col2 = st.columns([9, 1]) + col1.text_input( + "Model Path", + key=f"auxiliary_model_{idx}_model_path", + ) + if col2.button("✖️", key=f"auxiliary_model_{idx}_del_flag", type="primary"): + st.rerun() + + engine_type_col, engine_num_col, tensor_parallel_size_col = st.columns(3) + total_gpu_num = st.session_state["total_gpu_num"] + engine_type_col.selectbox( + "Engine Type", ["vllm_async"], key=f"auxiliary_model_{idx}_engine_type" + ) + engine_num_col.number_input( + "Engine Num", + min_value=1, + max_value=total_gpu_num - 1, + on_change=set_trainer_gpu_num, + key=f"auxiliary_model_{idx}_engine_num", + ) + tensor_parallel_size_col.number_input( + "Tensor Parallel Size", + min_value=1, + max_value=8, + on_change=set_trainer_gpu_num, + key=f"auxiliary_model_{idx}_tensor_parallel_size", + ) + + gpu_memory_utilization_col, dtype_col, seed_col = st.columns(3) + gpu_memory_utilization_col.number_input( + "GPU Memory Utilization", + min_value=0.0, + max_value=1.0, + key=f"auxiliary_model_{idx}_gpu_memory_utilization", + ) + dtype_col.selectbox( + "Dtype", ["bfloat16", "float16", "float32"], key=f"auxiliary_model_{idx}_dtype" + ) + seed_col.number_input("Seed", step=1, 
key=f"auxiliary_model_{idx}_seed") + + ( + use_v1_col, + enforce_eager_col, + enable_prefix_caching_col, + enable_chunked_prefill_col, + ) = st.columns(4) + use_v1_col.checkbox("Use V1 Engine", key=f"auxiliary_model_{idx}_use_v1") + enforce_eager_col.checkbox("Enforce Eager", key=f"auxiliary_model_{idx}_enforce_eager") + enable_prefix_caching_col.checkbox( + "Prefix Caching", key=f"auxiliary_model_{idx}_enable_prefix_caching" + ) + enable_chunked_prefill_col.checkbox( + "Chunked Prefill", key=f"auxiliary_model_{idx}_enable_chunked_prefill" + ) + + enable_thinking_col, enable_openai_api = st.columns(2) + enable_thinking_col.checkbox( + "Enable Thinking For Qwen3", key=f"auxiliary_model_{idx}_enable_thinking" + ) + enable_openai_api.checkbox("Enable OpenAI API", key=f"auxiliary_model_{idx}_enable_openai_api") + + +@CONFIG_GENERATORS.register_config(other_configs={"_auxiliary_models_num": 0}) +def set_auxiliary_models(**kwargs): + if st.button("Add Auxiliary Models"): + idx = st.session_state["_auxiliary_models_num"] + st.session_state[f"auxiliary_model_{idx}_engine_num"] = 1 + st.session_state[f"auxiliary_model_{idx}_tensor_parallel_size"] = 1 + st.session_state[f"auxiliary_model_{idx}_gpu_memory_utilization"] = 0.9 + st.session_state[f"auxiliary_model_{idx}_seed"] = 42 + st.session_state[f"auxiliary_model_{idx}_use_v1"] = True + st.session_state[f"auxiliary_model_{idx}_enforce_eager"] = True + st.session_state["_auxiliary_models_num"] += 1 + set_trainer_gpu_num() + if st.session_state["_auxiliary_models_num"] > 0: + tabs = st.tabs( + [f"Auxiliary Model {i + 1}" for i in range(st.session_state["_auxiliary_models_num"])] + ) + for idx, tab in enumerate(tabs): + with tab: + _set_auxiliary_model_idx(idx) + + +@CONFIG_GENERATORS.register_check() +def check_auxiliary_models(unfinished_fields: set, key: str): + if st.session_state["trainer_gpu_num"] <= 0: + unfinished_fields.add("engine_num") + unfinished_fields.add("tensor_parallel_size") + st.warning( + "Please check the 
settings of each `engine_num` and `tensor_parallel_size` to ensure that at least one GPU is reserved for the `trainer`." + ) + elif ( + st.session_state["node_num"] > 1 + and st.session_state["trainer_gpu_num"] % st.session_state["gpu_per_node"] != 0 + ): + unfinished_fields.add("engine_num") + unfinished_fields.add("tensor_parallel_size") + st.warning( + "When `node_num > 1`, please check the settings of each `engine_num` and `tensor_parallel_size` to ensure that the number of GPUs reserved for the `trainer` is divisible by `gpu_per_node`" + ) + + +# Synchronizer Configs + + +@CONFIG_GENERATORS.register_config( + default_value=SyncMethod.NCCL.value, + visible=explorer_visible, + other_configs={"_not_dpo_sync_method": SyncMethod.NCCL.value}, +) +def set_sync_method(**kwargs): + key = kwargs.get("key") + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state[key] = SyncMethod.CHECKPOINT.value + disabled = True + else: + st.session_state[key] = st.session_state["_not_dpo_sync_method"] + disabled = False + + def on_change(): + if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: + st.session_state["_not_dpo_sync_method"] = st.session_state[key] + + st.selectbox( + "Sync Method", + [sync_method.value for sync_method in SyncMethod], + help="""`nccl`: the explorer and trainer sync model weights once every `sync_interval` steps. 
+ +`checkpoint`: the trainer saves the model checkpoint, and the explorer loads it at `sync_interval`.""", + disabled=disabled, + on_change=on_change, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=10, visible=explorer_visible) +def set_sync_interval(**kwargs): + st.number_input( + "Sync Interval", + min_value=1, + help="""The step interval at which the `explorer` and `trainer` synchronize model weight.""", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=1200, visible=explorer_visible) +def set_sync_timeout(**kwargs): + st.number_input( + "Sync Timeout", + min_value=1, + help="The timeout value for the synchronization operation.", + **kwargs, + ) diff --git a/trinity/manager/config_registry/model_config_manager.py b/trinity/manager/config_registry/model_config_manager.py new file mode 100644 index 0000000000..837bf27679 --- /dev/null +++ b/trinity/manager/config_registry/model_config_manager.py @@ -0,0 +1,206 @@ +import os + +import streamlit as st + +from trinity.common.constants import AlgorithmType, MonitorType +from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS +from trinity.manager.config_registry.trainer_config_manager import use_critic +from trinity.trainer.verl.ray_trainer import AdvantageEstimator + + +def set_total_gpu_num(): + st.session_state["total_gpu_num"] = ( + st.session_state["gpu_per_node"] * st.session_state["node_num"] + ) + set_trainer_gpu_num() + + +def set_trainer_gpu_num(): + if st.session_state["mode"] == "both": + trainer_gpu_num = ( + st.session_state["total_gpu_num"] + - st.session_state["engine_num"] * st.session_state["tensor_parallel_size"] + ) + for idx in range(st.session_state["_auxiliary_models_num"]): + engine_num = st.session_state[f"auxiliary_model_{idx}_engine_num"] + tensor_parallel_size = st.session_state[f"auxiliary_model_{idx}_tensor_parallel_size"] + trainer_gpu_num -= engine_num * tensor_parallel_size + st.session_state["trainer_gpu_num"] = 
trainer_gpu_num + else: # mode == "train" + st.session_state["trainer_gpu_num"] = st.session_state["total_gpu_num"] + + +@CONFIG_GENERATORS.register_config(default_value="Trinity-RFT") +def set_project(**kwargs): + st.text_input("Project", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value="qwen2.5-1.5B") +def set_exp_name(**kwargs): + st.text_input("Experiment Name", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value="") +def set_checkpoint_root_dir(**kwargs): + st.text_input("Checkpoint Root Dir", **kwargs) + + +@CONFIG_GENERATORS.register_check() +def check_checkpoint_root_dir(unfinished_fields: set, key: str): + if not st.session_state[key].strip(): # TODO: may auto generate + unfinished_fields.add(key) + st.warning("Please input checkpoint root dir.") + elif not os.path.isabs(st.session_state[key].strip()): + unfinished_fields.add("checkpoint_root_dir") + st.warning("Please input an absolute path.") + + +@CONFIG_GENERATORS.register_config(default_value=MonitorType.TENSORBOARD.value) +def set_monitor_type(**kwargs): + st.selectbox( + "Monitor Type", + options=[monitor_type.value for monitor_type in MonitorType], + **kwargs, + ) + + +# Algorithm Configs + + +@CONFIG_GENERATORS.register_config( + default_value=AlgorithmType.PPO.value, + other_configs={"mode": "both", "adv_estimator": AdvantageEstimator.GAE.value}, +) +def set_algorithm_type(**kwargs): + def on_change(): + if st.session_state["algorithm_type"] == AlgorithmType.PPO.value: + st.session_state["mode"] = "both" + st.session_state["adv_estimator"] = AdvantageEstimator.GAE.value + elif st.session_state["algorithm_type"] == AlgorithmType.GRPO.value: + st.session_state["mode"] = "both" + st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value + elif st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state["mode"] = "train" + st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value + elif st.session_state["algorithm_type"] == 
AlgorithmType.OPMD.value: + st.session_state["mode"] = "both" + st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value + else: # TODO: add more algorithms + pass + set_trainer_gpu_num() + + st.selectbox( + "Algorithm Type", + [ + AlgorithmType.PPO.value, + AlgorithmType.GRPO.value, + AlgorithmType.DPO.value, + AlgorithmType.OPMD.value, + ], + key="algorithm_type", + on_change=on_change, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=1, + visible=lambda: st.session_state["mode"] == "both", + other_configs={ + "_grouped_adv_repeat_times": 2, + "_not_grouped_adv_repeat_times": 1, + }, +) +def set_repeat_times(**kwargs): # TODO + key = kwargs.get("key") + grouped_adv_algorithms = [ + AlgorithmType.GRPO.value, + AlgorithmType.OPMD.value, # TODO: may add rloo + ] + if st.session_state["algorithm_type"] in grouped_adv_algorithms: + min_repeat_times = 2 + st.session_state[key] = st.session_state["_grouped_adv_repeat_times"] + else: + min_repeat_times = 1 + st.session_state[key] = st.session_state["_not_grouped_adv_repeat_times"] + + def on_change(): + if st.session_state["algorithm_type"] in grouped_adv_algorithms: + st.session_state["_grouped_adv_repeat_times"] = st.session_state[key] + else: + st.session_state["_not_grouped_adv_repeat_times"] = st.session_state[key] + + st.number_input( + "Repeat Times", + min_value=min_repeat_times, + help="`repeat_times` is used to set how many experiences each task can generate, " + "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.", + on_change=on_change, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=1.0) +def set_gamma(**kwargs): + st.number_input(r"Gamma :blue-badge[$\gamma$]", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=1.0) +def set_lam(**kwargs): + st.number_input(r"Lambda :blue-badge[$\lambda$]", **kwargs) + + +# Model Configs + + +@CONFIG_GENERATORS.register_config(default_value="") +def set_model_path(**kwargs): + 
st.text_input("Model Path", **kwargs) + + +@CONFIG_GENERATORS.register_check() +def check_model_path(unfinished_fields: set, key: str): + if not st.session_state[key].strip(): + unfinished_fields.add(key) + st.warning("Please input model path.") + + +@CONFIG_GENERATORS.register_config( + default_value="", + visible=use_critic, +) +def set_critic_model_path(**kwargs): + st.text_input( + "Critic Model Path (defaults to `model_path`)", + key="critic_model_path", + ) + + +@CONFIG_GENERATORS.register_config(default_value=1024) +def set_max_prompt_tokens(**kwargs): + st.number_input("Max Prompt Tokens", min_value=1, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=1024) +def set_max_response_tokens(**kwargs): + st.number_input("Max Response Tokens", min_value=1, **kwargs) + + +# Cluster Config + + +@CONFIG_GENERATORS.register_config(default_value=1) +def set_node_num(**kwargs): + st.number_input("Node Num", min_value=1, on_change=set_total_gpu_num, **kwargs) + + +@CONFIG_GENERATORS.register_config( + default_value=8, other_configs={"total_gpu_num": 8, "trainer_gpu_num": 6} +) +def set_gpu_per_node(**kwargs): + st.number_input( + "GPU Per Node", + min_value=1, + max_value=8, + on_change=set_total_gpu_num, + **kwargs, + ) diff --git a/trinity/manager/config_registry/trainer_config_manager.py b/trinity/manager/config_registry/trainer_config_manager.py new file mode 100644 index 0000000000..d0f5d26897 --- /dev/null +++ b/trinity/manager/config_registry/trainer_config_manager.py @@ -0,0 +1,450 @@ +import streamlit as st + +from trinity.common.constants import AlgorithmType, SyncMethod +from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS +from trinity.trainer.verl.ray_trainer import AdvantageEstimator + + +def use_critic(): + return st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value + + +@CONFIG_GENERATORS.register_config(default_value="verl") +def set_trainer_type(**kwargs): + st.selectbox("Trainer Type", ["verl"], 
**kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=100, other_configs={"_nccl_save_interval": 100}) +def set_save_interval(**kwargs): + key = kwargs.get("key") + if ( + st.session_state["algorithm_type"] == AlgorithmType.DPO.value + or st.session_state["sync_method"] == SyncMethod.NCCL.value + ): + st.session_state[key] = st.session_state["_nccl_save_interval"] + freeze_save_interval = False + else: + st.session_state[key] = st.session_state["sync_interval"] + freeze_save_interval = True + + def on_change(): + if ( + st.session_state["algorithm_type"] == AlgorithmType.DPO.value + or st.session_state["sync_method"] == SyncMethod.NCCL.value + ): + st.session_state["_nccl_save_interval"] = st.session_state[key] + + st.number_input( + "Save Interval", + min_value=1, + help="Set to `sync_interval` when `algorithm_type != DPO && sync_method == checkpoint`", + disabled=freeze_save_interval, + on_change=on_change, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=True) +def set_enable_preview(**kwargs): + st.checkbox("Enable Preview", **kwargs) + + +def _actor_use_kl_loss_visible(): + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state["actor_use_kl_loss"] = True + return False + return True + + +@CONFIG_GENERATORS.register_config( + default_value=True, + visible=_actor_use_kl_loss_visible, + other_configs={"_not_dpo_actor_use_kl_loss": True}, +) +def set_actor_use_kl_loss(**kwargs): + key = kwargs.get("key") + st.session_state[key] = st.session_state["_not_dpo_actor_use_kl_loss"] + + def on_change(): + st.session_state["_not_dpo_actor_use_kl_loss"] = st.session_state[key] + + st.checkbox("Use KL Loss", on_change=on_change, **kwargs) + + +@CONFIG_GENERATORS.register_config( + default_value=0.001, visible=lambda: st.session_state["actor_use_kl_loss"] +) +def set_actor_kl_loss_coef(**kwargs): + st.number_input( + r"KL Loss Coef :blue-badge[$\beta$]", + min_value=0.0, + max_value=1.0, + format="%.1e", + 
**kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=0.001, visible=lambda: st.session_state["actor_use_kl_loss"] +) +def set_actor_entropy_coef(**kwargs): + st.number_input( + "Entropy Coeff", + min_value=0.0, + max_value=1.0, + format="%.1e", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=1.0) +def set_actor_grad_clip(**kwargs): + st.number_input( + "Grad Clip :blue-badge[(Actor)]", + min_value=0.0, + max_value=1.0, + help="Clipping by Norm", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=0.2) +def set_actor_clip_ratio(**kwargs): + st.number_input( + r"Clip Ratio :blue-badge[$\epsilon$]", + min_value=0.0, + max_value=1.0, + **kwargs, + ) + + +# veRL Trainer Configs + + +@CONFIG_GENERATORS.register_config( + default_value=[ + "balance_batch", + "gradient_checkpointing", + "remove_padding", + "dynamic_bsz", + ] +) +def set_training_args(**kwargs): + st.multiselect( + "Training Args", + [ + "balance_batch", + "gradient_checkpointing", + "remove_padding", + "dynamic_bsz", + ], + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=1) +def set_ppo_epochs(**kwargs): + st.number_input("PPO Epochs", min_value=1, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value="fsdp") +def set_training_strategy(**kwargs): + st.selectbox( + "Training Strategy", + ["fsdp", "megatron"], + help="megatron is not tested", + **kwargs, + ) + + +def use_fsdp(): + return st.session_state["training_strategy"] == "fsdp" + + +@CONFIG_GENERATORS.register_config(default_value=False, visible=use_fsdp) +def set_param_offload(**kwargs): + st.checkbox("FSDP Param Offload", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=False, visible=use_fsdp) +def set_optimizer_offload(**kwargs): + st.checkbox("FSDP Optimizer Offload", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value="auto") +def set_resume_mode(**kwargs): + st.selectbox("Resume Mode", ["disable", "auto", "resume_path"], **kwargs) 
+ + +@CONFIG_GENERATORS.register_config( + default_value="", visible=lambda: st.session_state["resume_mode"] == "resume_path" +) +def set_resume_from_path(**kwargs): + st.text_input("Resume Path", **kwargs) + + +@CONFIG_GENERATORS.register_check() +def check_resume_from_path(unfinished_fields: set, key: str): + if st.session_state["resume_mode"] == "resume_path" and ( + not st.session_state[key].strip() or "global_step_" not in st.session_state[key] + ): + unfinished_fields.add(key) + st.warning("Please input a valid resume path when `resume_mode == resume_path`") + + +@CONFIG_GENERATORS.register_config(default_value=0) +def set_critic_warmup(**kwargs): + st.number_input("Critic Warmup Steps", min_value=0, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=None) +def set_total_training_steps(**kwargs): + st.number_input("Total Training Steps", min_value=1, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=None) +def set_default_hdfs_dir(**kwargs): + st.text_input("Default HDFS Dir", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=False) +def set_remove_previous_ckpt_in_save(**kwargs): + st.checkbox("Remove Previous Checkpoint in Save", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=False) +def set_del_local_ckpt_after_load(**kwargs): + st.checkbox("Delete Local Checkpoint After Load", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=None) +def set_max_actor_ckpt_to_keep(**kwargs): + st.number_input("Max Actor Checkpoint to Keep", min_value=1, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=None) +def set_max_critic_ckpt_to_keep(**kwargs): + st.number_input("Max Critic Checkpoint to Keep", min_value=1, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=True) +def set_norm_adv_by_std_in_grpo(**kwargs): + st.checkbox("Norm Adv by Std in GRPO", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=False) +def set_use_kl_in_reward(**kwargs): + 
st.checkbox("Use KL in Reward", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value="low_var_kl") +def set_kl_penalty(**kwargs): + st.selectbox("KL Penalty", ["kl", "abs", "mse", "low_var_kl"], **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value="fixed") +def set_kl_ctrl_type(**kwargs): + st.selectbox("KL Ctrl Type", ["fixed", "adaptive"], **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=0.001) +def set_kl_ctrl_coef(**kwargs): + st.number_input("KL Ctrl Coef", format="%.1e", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=10000) +def set_horizon(**kwargs): + st.number_input("Horizon", min_value=1.0, **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=0.1) +def set_target_kl(**kwargs): + st.number_input("Target KL", format="%.1e", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value=4) +def set_actor_ppo_micro_batch_size_per_gpu(**kwargs): + key = kwargs.get("key") + max_value = st.session_state["_train_batch_size_per_gpu"] + st.session_state[key] = min(st.session_state[key], max_value) + st.number_input( + "Micro Batch Size Per GPU :blue-badge[(Actor)]", min_value=1, max_value=max_value, **kwargs + ) + + +@CONFIG_GENERATORS.register_config(default_value=8) +def set_ref_log_prob_micro_batch_size_per_gpu(**kwargs): + key = kwargs.get("key") + max_value = st.session_state["_train_batch_size_per_gpu"] + st.session_state[key] = min(st.session_state[key], max_value) + st.number_input( + "Micro Batch Size Per GPU :blue-badge[(Ref)]", min_value=1, max_value=max_value, **kwargs + ) + + +@CONFIG_GENERATORS.register_config(default_value=1) +def set_actor_ulysses_sequence_parallel_size(**kwargs): + st.number_input( + "Ulysses Sequence Parallel Size", + min_value=1, + max_value=8, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=1e-6) +def set_actor_lr(**kwargs): + st.number_input( + "Learning Rate :blue-badge[(Actor)]", + min_value=1e-7, + max_value=1e-3, + format="%.1e", + 
**kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value="constant") +def set_actor_warmup_style(**kwargs): + st.selectbox( + "LR Warmup Style :blue-badge[(Actor)]", + ["constant", "cosine"], + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=0.0) +def set_actor_lr_warmup_steps_ratio(**kwargs): + st.number_input( + "LR Warmup Steps Ratio :blue-badge[(Actor)]", + min_value=0.0, + max_value=1.0, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=0.0, visible=lambda: st.session_state["algorithm_type"] == "opmd" +) +def set_actor_tau(**kwargs): + st.number_input("Tau for OPMD", min_value=0.0, format="%.1e", **kwargs) + + +@CONFIG_GENERATORS.register_config( + default_value="mean", visible=lambda: st.session_state["algorithm_type"] == "opmd" +) +def set_actor_opmd_baseline(**kwargs): + st.selectbox( + "OPMD Baseline", + ["mean", "logavgexp"], + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=False, visible=lambda: st.session_state["algorithm_type"] == "opmd" +) +def set_actor_use_uid(**kwargs): + st.checkbox("Use UID for OPMD", **kwargs) + + +@CONFIG_GENERATORS.register_config(default_value="low_var_kl") +def set_actor_kl_loss_type(**kwargs): + st.selectbox( + "KL Loss Type", + ["kl", "abs", "mse", "low_var_kl"], + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=["model", "hf_model", "optimizer", "extra"]) +def set_actor_checkpoint(**kwargs): + st.multiselect( + "Checkpoint", + ["model", "hf_model", "optimizer", "extra"], + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=1e-6, visible=use_critic) +def set_critic_lr(**kwargs): + st.number_input( + "Learning Rate :blue-badge[(Critic)]", + min_value=1e-7, + max_value=1e-3, + format="%.1e", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value="constant", visible=use_critic) +def set_critic_warmup_style(**kwargs): + st.selectbox( + "LR Warmup Style :blue-badge[(Critic)]", + ["constant", 
"cosine"], + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=0.0, visible=use_critic) +def set_critic_lr_warmup_steps_ratio(**kwargs): + st.number_input( + "LR Warmup Steps Ratio :blue-badge[(Critic)]", + min_value=0.0, + max_value=1.0, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=1.0, visible=use_critic) +def set_critic_grad_clip(**kwargs): + st.number_input( + "Grad Clip :blue-badge[(Critic)]", + min_value=0.0, + max_value=1.0, + help="Clipping by Norm", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=0.5, visible=use_critic) +def set_critic_cliprange_value(**kwargs): + st.number_input( + "Cliprange Value", + min_value=0.0, + max_value=1.0, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=8, visible=use_critic) +def set_critic_ppo_micro_batch_size_per_gpu(**kwargs): + key = kwargs.get("key") + max_value = st.session_state["_train_batch_size_per_gpu"] + st.session_state[key] = min(st.session_state[key], max_value) + st.number_input( + "Micro Batch Size Per GPU :blue-badge[(Critic)]", + min_value=1, + max_value=max_value, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=1, visible=use_critic) +def set_critic_ulysses_sequence_parallel_size(**kwargs): + st.number_input( + "Ulysses Sequence Parallel Size", + min_value=1, + max_value=8, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=["model", "optimizer", "extra"], visible=use_critic +) +def set_critic_checkpoint(**kwargs): + st.multiselect( + "Checkpoint", + ["model", "hf_model", "optimizer", "extra"], + **kwargs, + ) diff --git a/trinity/plugins/__init__.py b/trinity/plugins/__init__.py new file mode 100644 index 0000000000..1b8629c9ca --- /dev/null +++ b/trinity/plugins/__init__.py @@ -0,0 +1 @@ +"""Add your custom modules to this directory.""" diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index 595084ac02..616234d0d6 100644 --- 
a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -314,11 +314,6 @@ def update_policy(self, data: DataProto): # noqa: C901 else: dataloader = batch.split(self.config.ppo_mini_batch_size) - # TODO: for pairwise_opmd and use_uid, is it necessary to somehow sort samples within batch by uid, - # to ensure that there are samples with the same uid within each micro-batch - # (at which level pairwise loss is computed)? - # (In comparison, advantage is computed at the level of batch, same for opmd, grpo, etc.) - metrics = {} for epoch in range(self.config.ppo_epochs): for batch_idx, data in enumerate(dataloader): diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index d040c329dd..e7eb8a209b 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -40,7 +40,7 @@ from trinity.common.config import Config from trinity.common.experience import Experiences from trinity.trainer.trainer import TrainEngineWrapper -from trinity.utils.monitor import Monitor +from trinity.utils.monitor import MONITOR class _InternalDataLoader: @@ -145,7 +145,7 @@ def __init__( ) self.init_workers() - self.logger = Monitor( + self.logger = MONITOR.get(global_config.monitor.monitor_type)( project=config.trainer.project_name, name=config.trainer.experiment_name, role="trainer", diff --git a/trinity/utils/dlc_utils.py b/trinity/utils/dlc_utils.py index 8d7a7d3a06..b250d856d6 100644 --- a/trinity/utils/dlc_utils.py +++ b/trinity/utils/dlc_utils.py @@ -9,6 +9,20 @@ logger = get_logger(__name__) +CLUSTER_ACTOR_NAME = "cluster_status" + + +@ray.remote +class ClusterStatus: + def __init__(self): + self.finished = False + + def finish(self) -> None: + self.finished = True + + def running(self) -> bool: + return not self.finished + def get_dlc_env_vars() -> dict: envs = { @@ -71,16 +85,40 @@ def setup_ray_cluster(namespace: str): logger.error(f"ret.stdout: {ret.stdout!r}") logger.error(f"ret.stderr: {ret.stderr!r}") sys.exit(1) 
+ + wait_for_ray_setup() + ray.init( + address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}", + namespace=namespace, + ignore_reinit_error=True, + ) if is_master: - wait_for_ray_setup() - ray.init( - address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}", - namespace=namespace, - ignore_reinit_error=True, - ) # master wait for worker nodes to join wait_for_ray_worker_nodes(env_vars["WORLD_SIZE"]) + else: + # worker waits on the cluster status actor + cluster_status = ClusterStatus.options( + name=CLUSTER_ACTOR_NAME, + get_if_exists=True, + ).remote() + while True: + if ray.get(cluster_status.running.remote()): + ret = subprocess.run("ray status", shell=True, capture_output=True) + print(ret.stdout.decode()) + time.sleep(5) + else: + logger.info("Ray cluster is not running, exiting.") + break + sys.exit(0) + - if not is_master: - # woker just exit - sys.exit(0) +def stop_ray_cluster(): + """ + Stop the ray cluster by sending a signal to the cluster status actor. + """ + cluster_status = ClusterStatus.options( + name=CLUSTER_ACTOR_NAME, + get_if_exists=True, + ).remote() + ray.get(cluster_status.finish.remote()) + logger.info("Stopping ray cluster...") diff --git a/trinity/utils/eval_utils.py b/trinity/utils/eval_utils.py index e3aa216eda..e80afaf59b 100644 --- a/trinity/utils/eval_utils.py +++ b/trinity/utils/eval_utils.py @@ -15,12 +15,19 @@ def simple_answer_parser(response: str) -> str: return parse(response) -def find_boxed_answer(string): +def find_boxed_answer(raw_answer, timeout=10): """ - Find answers from solutions where the answers are enclosed in LaTeX's `\boxed` tag + Find answers from solutions where the answers are enclosed in LaTeX's `\\boxed` tag + + Args: + raw_answer (`str`): raw answer from model + timeout (`int`): timeout in seconds for regex + + Returns: + `str`: answer if found, otherwise None """ pattern = r"\\boxed\s*(({(?:\\.|[^{}]|(?2))*})|(.))" - res = re.findall(pattern, string) + res = re.findall(pattern, raw_answer,
timeout=timeout) if res: answer = res[-1][0] # regard the last boxed as the answer if answer.startswith("{"): diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index 3044c6dcc8..f12a854335 100644 --- a/trinity/utils/monitor.py +++ b/trinity/utils/monitor.py @@ -1,5 +1,7 @@ """Monitor""" + import os +from abc import ABC, abstractmethod from typing import List, Optional, Union import numpy as np @@ -8,11 +10,13 @@ from torch.utils.tensorboard import SummaryWriter from trinity.common.config import Config -from trinity.common.constants import MonitorType from trinity.utils.log import get_logger +from trinity.utils.registry import Registry + +MONITOR = Registry("monitor") -class Monitor: +class Monitor(ABC): """Monitor""" def __init__( @@ -22,15 +26,25 @@ def __init__( role: str, config: Config = None, # pass the global Config for recording ) -> None: - if config.monitor.monitor_type == MonitorType.WANDB: - self.logger = WandbLogger(project, name, role, config) - elif config.monitor.monitor_type == MonitorType.TENSORBOARD: - self.logger = TensorboardLogger(project, name, role, config) - else: - raise ValueError(f"Unknown monitor type: {config.monitor.monitor_type}") + self.project = project + self.name = name + self.role = role + self.config = config + @abstractmethod def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int): - self.logger.log_table(table_name, experiences_table, step=step) + """Log a table""" + + @abstractmethod + def log(self, data: dict, step: int, commit: bool = False) -> None: + """Log metrics.""" + + @abstractmethod + def close(self) -> None: + """Close the monitor""" + + def __del__(self) -> None: + self.close() def calculate_metrics( self, data: dict[str, Union[List[float], float]], prefix: Optional[str] = None @@ -51,15 +65,9 @@ def calculate_metrics( metrics[key] = val return metrics - def log(self, data: dict, step: int, commit: bool = False) -> None: - """Log metrics.""" - self.logger.log(data, 
step=step, commit=commit) - - def close(self) -> None: - self.logger.close() - -class TensorboardLogger: +@MONITOR.register_module("tensorboard") +class TensorboardMonitor(Monitor): def __init__(self, project: str, name: str, role: str, config: Config = None) -> None: self.tensorboard_dir = os.path.join(config.monitor.cache_dir, "tensorboard") os.makedirs(self.tensorboard_dir, exist_ok=True) @@ -77,11 +85,9 @@ def log(self, data: dict, step: int, commit: bool = False) -> None: def close(self) -> None: self.logger.close() - def __del__(self) -> None: - self.logger.close() - -class WandbLogger: +@MONITOR.register_module("wandb") +class WandbMonitor(Monitor): def __init__(self, project: str, name: str, role: str, config: Config = None) -> None: self.logger = wandb.init( project=project, @@ -104,6 +110,3 @@ def log(self, data: dict, step: int, commit: bool = False) -> None: def close(self) -> None: self.logger.finish() - - def __del__(self) -> None: - self.logger.finish() diff --git a/trinity/utils/plugin_loader.py b/trinity/utils/plugin_loader.py new file mode 100644 index 0000000000..a5a779ae83 --- /dev/null +++ b/trinity/utils/plugin_loader.py @@ -0,0 +1,65 @@ +"""Load modules from custom directory""" + +import importlib +import os +import shutil +import sys +from pathlib import Path + +from trinity.utils.log import get_logger + +logger = get_logger(__name__) + + +def load_plugins(plugin_dir: str) -> None: + """ + Load plugin modules from a directory. 
+ """ + if plugin_dir is None: + plugin_dir = Path(__file__).parent.parent / "plugins" + if not os.path.exists(plugin_dir): + logger.error(f"--plugin-dir [{plugin_dir}] does not exist.") + return None + if not os.path.isdir(plugin_dir): + logger.error(f"--plugin-dir [{plugin_dir}] is not a directory.") + return None + + logger.info(f"Loading plugin modules from [{plugin_dir}]...") + for file in Path(plugin_dir).glob("*.py"): + if file.name.startswith("__"): + continue + logger.info(f"Loading plugin modules from [{file}]...") + # load modules from file (glob already yields paths prefixed with plugin_dir) + load_from_file(str(file)) + + +def load_from_file(file_path: str): + """ + Load modules from a Python file + + Args: + file_path (`str`): The python file path. + + Returns: + `Any`: The loaded module. + """ + module_name = os.path.splitext(os.path.basename(file_path))[0] + + full_module_name = f"trinity.plugins.{module_name}" + + spec = importlib.util.spec_from_file_location(full_module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load module from {file_path}") + + module = importlib.util.module_from_spec(spec) + + module.__package__ = "trinity.plugins" + + spec.loader.exec_module(module) + + if full_module_name in sys.modules: + raise ImportError(f"Module {module_name} already exists.") + sys.modules[full_module_name] = module + shutil.copy2(file_path, Path(__file__).parent.parent / "plugins") + logger.info(f"Load {file_path} as {full_module_name}") + return module diff --git a/trinity/utils/registry.py b/trinity/utils/registry.py index b31f6872bd..d5ee37f36e 100644 --- a/trinity/utils/registry.py +++ b/trinity/utils/registry.py @@ -1,21 +1,4 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# -------------------------------------------------------- -# Most of the code here has been modified from: -# https://github.com/modelscope/modelscope/blob/master/modelscope/utils/registry.py -# -------------------------------------------------------- +from typing import Any, Type from trinity.utils.log import get_logger @@ -25,59 +8,57 @@ # TODO: support lazy load # e.g. @MODULES.register_module("name", lazy=True) class Registry(object): - """This class is used to register some modules to registry by a repo - name.""" + """A class for registry.""" def __init__(self, name: str): """ - Initialization method. - - :param name: a registry repo name + Args: + name (`str`): The name of the registry. """ self._name = name self._modules = {} @property - def name(self): + def name(self) -> str: """ Get name of current registry. - :return: name of current registry. + Returns: + `str`: The name of current registry. """ return self._name @property - def modules(self): + def modules(self) -> dict: """ Get all modules in current registry. - :return: a dict storing modules in current registry. + Returns: + `dict`: A dict storing modules in current registry. """ return self._modules - def list(self): + def list(self) -> None: """Logging the list of module in current registry.""" for m in self._modules.keys(): logger.info(f"{self._name}\t{m}") - def get(self, module_key): + def get(self, module_key) -> Any: """ Get module named module_key from in current registry. If not found, return None. 
- :param module_key: specified module name - :return: module named module_key + Args: + module_key (`str`): specified module name + + Returns: + `Any`: the module object """ return self._modules.get(module_key, None) def _register_module(self, module_name=None, module_cls=None, force=False): """ Register module to registry. - - :param module_name: module name - :param module_cls: module class object - :param force: Whether to override an existing class with the - same name. Default: False. """ if module_name is None: @@ -89,25 +70,35 @@ def _register_module(self, module_name=None, module_cls=None, force=False): self._modules[module_name] = module_cls module_cls._name = module_name - def register_module(self, module_name: str = None, module_cls: type = None, force=False): + def register_module(self, module_name: str, module_cls: Type = None, force=False, lazy=False): """ - Register module class object to registry with the specified modulename. + Register module class object to registry with the specified module name. - :param module_name: module name - :param module_cls: module class object - :param force: Whether to override an existing class with - the same name. Default: False. + Args: + module_name (`str`): The module name. + module_cls (`Type`): module class object + force (`bool`): Whether to override an existing class with + the same name. Default: False. + lazy (`bool`): Whether to register the module class object lazily. + Default: False. 
Example: - >>> registry = Registry() - >>> @registry.register_module() - >>> class TextFormatter: - >>> pass - - >>> class TextFormatter2: - >>> pass - >>> registry.register_module( module_name='text_formatter2', - module_cls=TextFormatter2) + ```python + WORKFLOWS = Registry("workflows") + + # register a module using decorator + @WORKFLOWS.register_module(module_name="workflow_name") + class MyWorkflow(Workflow): + pass + + # or register a module directly + WORKFLOWS.register_module( + module_name="workflow_name", + module_cls=MyWorkflow, + force=True, + ) + ``` + """ if not (module_name is None or isinstance(module_name, str)): raise TypeError(f"module_name must be either of None, str," f"got {type(module_name)}") @@ -120,8 +111,10 @@ def _register(module_cls): """ Register module class object to registry. - :param module_cls: module class object - :return: module class object. + Args: + module_cls (`Type`): module class object + Returns: + `Type`: Decorated module class object. """ self._register_module(module_name=module_name, module_cls=module_cls, force=force) return module_cls