From 74b677833ec40982257340a602e0e61b0d766b70 Mon Sep 17 00:00:00 2001
From: yuchang
Date: Mon, 28 Apr 2025 14:00:56 +0800
Subject: [PATCH 1/5] make 'get_exp_strategy' effective in sql

---
 .gitignore                          | 3 +++
 examples/grpo_gsm8k/gsm8k.yaml      | 1 +
 trinity/buffer/reader/sql_reader.py | 2 +-
 trinity/common/config.py            | 1 +
 trinity/common/verl_config.py       | 1 -
 trinity/trainer/trainer.py          | 8 ++++++--
 trinity/utils/monitor.py            | 2 +-
 7 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/.gitignore b/.gitignore
index 5ffba098fc..646848ade7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -94,3 +94,6 @@ modules.rst

 # wandb
 wandb/
+
+# checkpoints
+checkpoints/
diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml
index 63850d5d24..fd6a9b5c44 100644
--- a/examples/grpo_gsm8k/gsm8k.yaml
+++ b/examples/grpo_gsm8k/gsm8k.yaml
@@ -69,6 +69,7 @@ trainer:
   trainer_config_path: 'examples/grpo_gsm8k/train_gsm8k.yaml'
   sft_warmup_iteration: 0 # Set to integer to enable sft warmup
   eval_interval: 50
+  # get_exp_strategy: 'LFU'
 monitor:
   cache_root_dir: ""
   project: "Trinity-RFT-gsm8k"
diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py
index d7a826fdfa..e5c249f441 100644
--- a/trinity/buffer/reader/sql_reader.py
+++ b/trinity/buffer/reader/sql_reader.py
@@ -49,7 +49,7 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
             sortOrder = (desc(self.table_model_cls.priority), desc(self.table_model_cls.id))
         else:
-            raise NotImplementedError("Unsupported strategy by SQLStorage")
+            raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage")

         exp_list = []
         while len(exp_list) < self.batch_size:
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 587c7c5fe8..bf3b215745 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -181,6 +181,7 @@ class TrainerConfig:

     # train algorithm
     algorithm_type: AlgorithmType = AlgorithmType.PPO
+    get_exp_strategy: Optional[str] = None

     # warmup config
     sft_warmup_iteration: int = 0
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index 966fde0391..46676e62a8 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -244,7 +244,6 @@ class Trainer:
     training_rollout_mode: str = "parallel"
     enable_exp_buffer: bool = True
     steps_per_epoch: int = 1280
-    get_exp_strategy: Optional[str] = None
     sync_freq: int = 0
     sft_warmup_iteration: int = 0
     max_actor_ckpt_to_keep: Optional[int] = None
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index 925baa9400..cc0c90e0d5 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -13,7 +13,7 @@

 from trinity.buffer import get_buffer_reader
 from trinity.common.config import Config
-from trinity.common.constants import AlgorithmType
+from trinity.common.constants import AlgorithmType, ReadStrategy
 from trinity.common.experience import Experiences
 from trinity.utils.log import get_logger

@@ -82,7 +82,11 @@ def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple
                 )
             )
         else:
-            exps = self.train_buffer.read()
+            if self.config.trainer.get_exp_strategy:
+                strategy = ReadStrategy(self.config.trainer.get_exp_strategy)
+            else:
+                strategy = None
+            exps = self.train_buffer.read(strategy=strategy)
             if algo_type.is_rft():
                 return self.engine.train_rft_iteration(
                     Experiences.gather_experiences(
diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py
index 23b96a3c11..413bff7261 100644
--- a/trinity/utils/monitor.py
+++ b/trinity/utils/monitor.py
@@ -4,9 +4,9 @@

 import numpy as np
 import pandas as pd
-import wandb
 from torch.utils.tensorboard import SummaryWriter

+import wandb
 from trinity.common.constants import MonitorType
 from trinity.utils.log import get_logger


From 6bfe66caf9564b97319162a2b02b00e6d620424c Mon Sep 17 00:00:00 2001
From: yuchang
Date: Mon, 28 Apr 2025 14:33:25 +0800
Subject: [PATCH 2/5] fix pre-commit problem

---
 pyproject.toml           | 3 +++
 trinity/utils/monitor.py | 2 +-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 8a6f226c77..d2a52385d6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -93,3 +93,6 @@ exclude = '''
   | dist
 )/
 '''
+
+[tool.isort]
+known_third_party = ["wandb"]
diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py
index 413bff7261..23b96a3c11 100644
--- a/trinity/utils/monitor.py
+++ b/trinity/utils/monitor.py
@@ -4,9 +4,9 @@

 import numpy as np
 import pandas as pd
+import wandb
 from torch.utils.tensorboard import SummaryWriter

-import wandb
 from trinity.common.constants import MonitorType
 from trinity.utils.log import get_logger


From e434818ae7be78800f5c1d6e6239c0da8fff65b0 Mon Sep 17 00:00:00 2001
From: yuchang
Date: Mon, 28 Apr 2025 15:20:31 +0800
Subject: [PATCH 3/5] fix figure show problem

---
 .../source/tutorial/example_data_functionalities.md   | 2 +-
 docs/sphinx_doc/source/tutorial/example_multi_turn.md | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
index d7528fa0df..20242312f2 100644
--- a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
+++ b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
@@ -244,7 +244,7 @@ You can set more config items for this OP (e.g. notification when annotation is

 When you start running with the RFT config, the data module will start the OP `human_preference_annotation_mapper`, and then you can find a new project on the "Projects" page of the label-studio server.

-![]("../../assets/data-projects.png")
+![](../../assets/data-projects.png)

 You can click and enter into this project, and all the samples that need to be annotated are listed on the page.

diff --git a/docs/sphinx_doc/source/tutorial/example_multi_turn.md b/docs/sphinx_doc/source/tutorial/example_multi_turn.md
index 9002fffadd..d70528b6ed 100644
--- a/docs/sphinx_doc/source/tutorial/example_multi_turn.md
+++ b/docs/sphinx_doc/source/tutorial/example_multi_turn.md
@@ -122,5 +122,5 @@ and include them in the init files in `trinity/common/workflows/__init__.py`

 Then you are all set! It should be pretty simple😄, and both environments converge.

-![]("../../assets/alfworld_reward_curve.png")
-![]("../../assets/webshop_reward_curve.png")
+![](../../assets/alfworld_reward_curve.png)
+![](../../assets/webshop_reward_curve.png)

From 388d72c2fe57f900b4e6ccd77765e598dc943fe1 Mon Sep 17 00:00:00 2001
From: yuchang
Date: Mon, 28 Apr 2025 15:51:08 +0800
Subject: [PATCH 4/5] fix explore

---
 trinity/explorer/explorer.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index ecf26c6366..6f6c44b69c 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -149,8 +149,10 @@ def get_weight(self, name: str) -> torch.Tensor:

     def explore(self) -> None:
         """Explore the entire dataset."""
-        explore_status, _ = self.explore_step()
-        while explore_status:
+        while True:
+            explore_status, _ = self.explore_step()
+            if not explore_status:
+                break
             self.sync_weight()
         self.logger.info("Explorer finished.")


From af38c8511b754d577a15f360294f0155d07e368c Mon Sep 17 00:00:00 2001
From: yuchang
Date: Mon, 28 Apr 2025 17:19:16 +0800
Subject: [PATCH 5/5] fix input_param bug

---
 trinity/trainer/trainer.py | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index cc0c90e0d5..fb35d087a1 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -81,28 +81,28 @@ def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple
                     pad_token_id=self.config.buffer.pad_token_id,  # type: ignore
                 )
             )
-        else:
+        elif algo_type.is_rft():
             if self.config.trainer.get_exp_strategy:
                 strategy = ReadStrategy(self.config.trainer.get_exp_strategy)
             else:
                 strategy = None
             exps = self.train_buffer.read(strategy=strategy)
-            if algo_type.is_rft():
-                return self.engine.train_rft_iteration(
-                    Experiences.gather_experiences(
-                        exps,
-                        pad_token_id=self.config.buffer.pad_token_id,  # type: ignore
-                    )
+            return self.engine.train_rft_iteration(
+                Experiences.gather_experiences(
+                    exps,
+                    pad_token_id=self.config.buffer.pad_token_id,  # type: ignore
                 )
-            elif algo_type.is_dpo():
-                return self.engine.train_dpo_iteration(
-                    Experiences.gather_dpo_experiences(
-                        exps,
-                        pad_token_id=self.config.buffer.pad_token_id,  # type: ignore
-                    )
+            )
+        elif algo_type.is_dpo():
+            exps = self.train_buffer.read()
+            return self.engine.train_dpo_iteration(
+                Experiences.gather_dpo_experiences(
+                    exps,
+                    pad_token_id=self.config.buffer.pad_token_id,  # type: ignore
                 )
-            else:
-                raise ValueError(f"Unsupported algorithm type: {algo_type}")
+            )
+        else:
+            raise ValueError(f"Unsupported algorithm type: {algo_type}")

     def sync_weight(self) -> None:
         """Sync the model weight."""
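Usage note (a minimal sketch, not part of the patches themselves): with this series applied, the read
strategy is driven by the new `trainer.get_exp_strategy` field added to TrainerConfig in
trinity/common/config.py; the trainer converts it via ReadStrategy(...) and passes it to
train_buffer.read(strategy=...) on the RFT path only (after PATCH 5/5, the DPO branch reads without a
strategy). The commented line in examples/grpo_gsm8k/gsm8k.yaml suggests 'LFU' as a value; whether
'LFU' is actually a member of trinity.common.constants.ReadStrategy is an assumption here, so treat
the snippet below as illustrative:

    trainer:
      trainer_config_path: 'examples/grpo_gsm8k/train_gsm8k.yaml'
      sft_warmup_iteration: 0
      eval_interval: 50
      get_exp_strategy: 'LFU'   # assumed value; converted to ReadStrategy('LFU') and forwarded to train_buffer.read(strategy=...)

If the field is left unset (the default None), the trainer passes strategy=None and the SQL reader's
default behavior applies; a strategy the SQL storage does not support is reported by the
NotImplementedError in sql_reader.py, which after PATCH 1/5 names the offending strategy in its message.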