From 114ad1906f52f278abc317af7e314352348cd470 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 19 Nov 2025 11:30:35 +0800 Subject: [PATCH 1/8] bug fix in benchmark ckpt loading and megatron hf save --- trinity/manager/synchronizer.py | 5 +- .../verl/megatron_checkpoint_manager.py | 72 ++++++++++--------- 2 files changed, 43 insertions(+), 34 deletions(-) diff --git a/trinity/manager/synchronizer.py b/trinity/manager/synchronizer.py index 94d11ae7f8..e5adb86f1c 100644 --- a/trinity/manager/synchronizer.py +++ b/trinity/manager/synchronizer.py @@ -44,7 +44,10 @@ def __init__(self, config: Config, module_ref: ray.actor.ActorHandle): self._modules = {module_ref} self._modules_lock = asyncio.Lock() asyncio.create_task(self._check_modules()) - if self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT: + if ( + self.config.mode != "bench" + and self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT + ): asyncio.create_task(self._find_latest_state_dict()) async def add_module(self, module_ref: ray.actor.ActorHandle) -> None: diff --git a/trinity/trainer/verl/megatron_checkpoint_manager.py b/trinity/trainer/verl/megatron_checkpoint_manager.py index b65b943782..98aed32f62 100644 --- a/trinity/trainer/verl/megatron_checkpoint_manager.py +++ b/trinity/trainer/verl/megatron_checkpoint_manager.py @@ -233,45 +233,51 @@ def save_checkpoint( # noqa: C901 json.dump(transformer_config_dict, f, indent=2) if self.should_save_hf_model or save_as_hf: - # wait for everyone to dump to local - state_dict = self.weight_saver( - self.model, - self.hf_config, - dtype=self.param_dtype, - is_value_model=self.is_value_model, - tie_word_embeddings=self.share_embeddings_and_output_weights, - ) + try: + # wait for everyone to dump to local + state_dict = self.weight_saver( + self.model, + self.hf_config, + dtype=self.param_dtype, + is_value_model=self.is_value_model, + tie_word_embeddings=self.share_embeddings_and_output_weights, + ) - torch.distributed.barrier() - if self.rank == 0: - # TODO: async save or use mbridge to save hf model - hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path) - import warnings + torch.distributed.barrier() + if self.rank == 0: + # TODO: async save or use mbridge to save hf model + hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path) + import warnings - from accelerate import init_empty_weights + from accelerate import init_empty_weights - with init_empty_weights(), warnings.catch_warnings(): - warnings.simplefilter("ignore") - if "mistral7b-rm" in self.config.model.path: - from transformers import MistralForSequenceClassification + with init_empty_weights(), warnings.catch_warnings(): + warnings.simplefilter("ignore") + if "mistral7b-rm" in self.config.model.path: + from transformers import MistralForSequenceClassification - model = MistralForSequenceClassification.from_pretrained( - self.config.model.path - ) # use score head instead of lm_head - state_dict["score.weight"] = state_dict["score.weight"] - else: - from transformers import AutoModelForCausalLM + model = MistralForSequenceClassification.from_pretrained( + self.config.model.path + ) # use score head instead of lm_head + state_dict["score.weight"] = state_dict["score.weight"] + else: + from transformers import AutoModelForCausalLM - model = AutoModelForCausalLM.from_pretrained( - self.config.model.path, torch_dtype="auto" - ) - model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict) - log_with_rank( - f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}", - rank=self.rank, - logger=logger, - log_only_rank_0=True, + model = AutoModelForCausalLM.from_pretrained( + self.config.model.path, torch_dtype="auto" + ) + model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict) + log_with_rank( + f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, + ) + except Exception as e: + logger.error( + f"Failed to save Huggingface model to {local_path}, you can try to set `use_mbridge=true` to save it." ) + logger.error(e) ray.get( self.checkpoint_monitor.register_thread_count.remote( From a935c2d7f717259401cb63b7a862e0c49690a1e3 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 19 Nov 2025 13:13:28 +0800 Subject: [PATCH 2/8] add `model_dtype` --- trinity/common/verl_config.py | 1 + trinity/trainer/verl/fsdp_workers.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 49a241393a..71028b988c 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -79,6 +79,7 @@ class FSDPConfig: wrap_policy: WrapPolicy = field(default_factory=WrapPolicy) fsdp_size: int = -1 forward_prefetch: bool = False + model_dtype: Optional[str] = None @dataclass diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 899e991432..ff87ba87c3 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -246,7 +246,7 @@ def _build_model_optimizer( # noqa: C901 else: self.tokenizer.chat_template = self.config.model.custom_chat_template - torch_dtype = fsdp_config.get("model_dtype", None) + torch_dtype = fsdp_config.model_dtype if torch_dtype is None: torch_dtype = torch.float32 if self._is_actor else torch.bfloat16 else: @@ -1014,7 +1014,7 @@ def _build_critic_model_optimizer(self, config): # noqa: C901 if self.rank == 0: print(f"Critic overriding config {override_config_kwargs}") - torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32") + torch_dtype = self.config.model.fsdp_config.model_dtype or "fp32" torch_dtype = PrecisionType.to_dtype(torch_dtype) from transformers import AutoConfig From 28586ebc2b57c775d8951423d640fa9169bbae76 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 19 Nov 2025 15:27:46 +0800 Subject: [PATCH 3/8] apply suggestions and remove to(dtype) in fsdp_workers --- trinity/trainer/verl/fsdp_workers.py | 6 ------ trinity/trainer/verl/megatron_checkpoint_manager.py | 6 +++--- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 9d643de041..4962e1462d 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -326,9 +326,6 @@ def _build_model_optimizer( # noqa: C901 fused_kernels_backend=fused_kernels_backend, ) - # some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2 - actor_module.to(torch_dtype) - if enable_gradient_checkpointing: actor_module.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False} @@ -1060,9 +1057,6 @@ def _build_critic_model_optimizer(self, config): # noqa: C901 ulysses_sp_size=self.ulysses_sequence_parallel_size, ) - # some parameters may not in torch_dtype - critic_module.to(torch_dtype) - if config.model.get("enable_gradient_checkpointing", False): critic_module.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False} diff --git a/trinity/trainer/verl/megatron_checkpoint_manager.py b/trinity/trainer/verl/megatron_checkpoint_manager.py index 98aed32f62..f24c8aa4ef 100644 --- a/trinity/trainer/verl/megatron_checkpoint_manager.py +++ b/trinity/trainer/verl/megatron_checkpoint_manager.py @@ -273,11 +273,11 @@ def save_checkpoint( # noqa: C901 logger=logger, log_only_rank_0=True, ) - except Exception as e: + except Exception: logger.error( - f"Failed to save Huggingface model to {local_path}, you can try to set `use_mbridge=true` to save it." + f"Failed to save Huggingface model to {local_path}, you can try to set `use_mbridge=true` to save it.", + exc_info=True, ) - logger.error(e) ray.get( self.checkpoint_monitor.register_thread_count.remote( From 777ab91a0b165ed49b26f04efcfbbc7ed6da1766 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 19 Nov 2025 16:02:41 +0800 Subject: [PATCH 4/8] 1. remove `kl_loss_coef` and `kl_loss_type` in verl_config. 2. check `micro_batch_size` when not using `use_dynamic_bsz` in fsdp_workers. --- trinity/common/verl_config.py | 2 -- trinity/trainer/verl/fsdp_workers.py | 13 ++++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index f6239fd75d..eddcbba727 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -164,8 +164,6 @@ class Actor: clip_ratio_high: Optional[float] = None entropy_coeff: float = 0.001 use_kl_loss: bool = False - kl_loss_coef: float = 0.0 - kl_loss_type: str = "low_var_kl" @dataclass diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 4962e1462d..3e508c9124 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -168,7 +168,10 @@ def __init__(self, config: DictConfig, role: str): self.config.actor.ppo_micro_batch_size ) - if self.config.actor.ppo_micro_batch_size_per_gpu is not None: + if ( + not self.config.actor.use_dynamic_bsz + and self.config.actor.ppo_micro_batch_size_per_gpu is not None + ): assert ( self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu @@ -181,7 +184,11 @@ def __init__(self, config: DictConfig, role: str): ), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" # normalize ref config - if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None: + if ( + self._is_ref + and not self.config.ref.log_prob_use_dynamic_bsz + and self.config.ref.log_prob_micro_batch_size is not None + ): self.config.ref.log_prob_micro_batch_size //= ( self.device_mesh.size() // self.ulysses_sequence_parallel_size ) @@ -968,7 +975,7 @@ def __init__(self, config): self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size - if self.config.ppo_micro_batch_size_per_gpu is not None: + if not self.config.use_dynamic_bsz and self.config.ppo_micro_batch_size_per_gpu is not None: assert ( self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0 ), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" From ae3eb7dec95a3cf857f80fb5f401efcf71e17a7d Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 19 Nov 2025 18:03:23 +0800 Subject: [PATCH 5/8] fix mode check --- trinity/cli/launcher.py | 3 ++- trinity/common/config.py | 15 +++++++++------ trinity/explorer/explorer.py | 6 +++++- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 468ab2df53..d2d4254894 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -37,6 +37,7 @@ def bench(config: Config) -> None: def explore(config: Config) -> None: """Run explorer.""" + check_and_run_task_pipeline(config) try: explorer = Explorer.get_actor(config) ray.get(explorer.prepare.remote()) @@ -81,6 +82,7 @@ def both(config: Config) -> None: the latest step. The specific number of experiences may vary for different algorithms and tasks. """ + check_and_run_task_pipeline(config) try: explorer = Explorer.get_actor(config) trainer = Trainer.get_actor(config) @@ -151,7 +153,6 @@ def run_stage(config: Config) -> None: ) pprint(config) try: - check_and_run_task_pipeline(config) MODE_MAP[config.mode](config) finally: if config.monitor.enable_ray_timeline: diff --git a/trinity/common/config.py b/trinity/common/config.py index 3f2c44eb53..7537084cc1 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -853,8 +853,8 @@ def _check_interval(self) -> None: ) def _check_explorer_input(self) -> None: - if self.mode == "train": - # no need to check explorer_input in train mode + if self.mode in {"train", "bench", "serve"}: + # no need to check explorer_input in train/bench/serve mode return explorer_input = self.buffer.explorer_input @@ -866,9 +866,8 @@ def _check_explorer_input(self) -> None: explorer_input.taskset = None elif len(explorer_input.tasksets) == 0: raise ValueError("At least one taskset should be provided in explorer_input!") - tasksets = explorer_input.tasksets - for i, taskset in enumerate(tasksets): + for i, taskset in enumerate(explorer_input.tasksets): if self.mode != "train" and not taskset.path: raise ValueError( "`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset." @@ -914,6 +913,10 @@ def _check_explorer_input(self) -> None: set_if_none(dataset.rollout_args, "max_tokens", self.model.max_response_tokens) def _check_trainer_input(self) -> None: + if self.mode in {"explore", "bench", "serve"}: + # no need to check trainer_input in train/bench/serve mode + return + trainer_input = self.buffer.trainer_input experience_buffer = trainer_input.experience_buffer @@ -973,7 +976,7 @@ def _default_storage_path(self, storage_type: StorageType, name: str) -> str: def _check_data_processor(self) -> None: # check input/output buffers in pipelines experience_pipeline = self.data_processor.experience_pipeline - if experience_pipeline is not None: + if experience_pipeline is not None and self.mode in {"explore", "both", "serve"}: if experience_pipeline.save_input and experience_pipeline.input_save_path is None: experience_pipeline.input_save_path = os.path.join( self.buffer.cache_dir, "explorer_output.jsonl" # type: ignore[arg-type] @@ -983,7 +986,7 @@ def _check_data_processor(self) -> None: ) task_pipeline = self.data_processor.task_pipeline - if task_pipeline is not None: + if task_pipeline is not None and self.mode in {"explore", "both"}: if task_pipeline.output is None: if self.mode != "train": task_pipeline.output = self.buffer.explorer_input.tasksets[0] diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 038c1dd5f9..acd4aa1b17 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -52,7 +52,9 @@ def __init__(self, config: Config): self.models, self.auxiliary_models = create_inference_models(config) self.experience_pipeline = self._init_experience_pipeline() self.taskset = ( - TasksetScheduler(explorer_state, config) if self.config.mode != "serve" else None + TasksetScheduler(explorer_state, config) + if self.config.mode not in {"bench", "serve"} + else None ) self.scheduler = None self.monitor = MONITOR.get(self.config.monitor.monitor_type)( @@ -406,6 +408,8 @@ async def is_alive(self) -> bool: def _init_experience_pipeline(self) -> ray.actor.ActorHandle: """Init experience pipeline for the explorer.""" + if self.config.mode == "bench": + return None node_id = ray.get_runtime_context().get_node_id() return ( ray.remote(ExperiencePipeline) From b6197d8c99a5592381ba817b14ebc577d66378cc Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 19 Nov 2025 19:58:13 +0800 Subject: [PATCH 6/8] add unittest for kl=0 in step1 --- tests/trainer/trainer_test.py | 2 ++ trinity/common/config.py | 2 +- trinity/explorer/explorer.py | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index a74dfeb7dc..f87c16429b 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -113,6 +113,8 @@ def test_trainer(self): self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8) actor_kl_metrics = parser.metric_list("actor/kl") self.assertTrue(len(actor_kl_metrics) > 0) + actor_kl_loss = parser.metric_values("actor/kl_loss") + self.assertEqual(actor_kl_loss[0], 0.0) critic_kl_metrics = parser.metric_list("critic/kl") self.assertTrue(len(critic_kl_metrics) > 0) response_metrics = parser.metric_list("response_length") diff --git a/trinity/common/config.py b/trinity/common/config.py index 7537084cc1..555ef4c3a7 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -914,7 +914,7 @@ def _check_explorer_input(self) -> None: def _check_trainer_input(self) -> None: if self.mode in {"explore", "bench", "serve"}: - # no need to check trainer_input in train/bench/serve mode + # no need to check trainer_input in explore/bench/serve mode return trainer_input = self.buffer.trainer_input diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index acd4aa1b17..6166e685fa 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -153,7 +153,8 @@ async def prepare(self) -> None: """Preparation before running.""" try: # prepare experience pipeline - await self.experience_pipeline.prepare.remote() + if self.experience_pipeline: + await self.experience_pipeline.prepare.remote() self.logger.info("Experience pipeline is ready.") # make sure all rollout models are ready run_api_ref = [model.run_api_server.remote() for model in self.models] From d68a19bf59e75979566e8cdbaa1753c7f71bcd37 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Thu, 20 Nov 2025 12:07:01 +0800 Subject: [PATCH 7/8] fix unittest and bots_reward --- examples/bots/workflow/bots_reward.py | 22 ++++++++++++++++------ trinity/buffer/pipelines/task_pipeline.py | 2 ++ trinity/cli/launcher.py | 3 +-- trinity/common/config.py | 19 ++++++++++++------- 4 files changed, 31 insertions(+), 15 deletions(-) diff --git a/examples/bots/workflow/bots_reward.py b/examples/bots/workflow/bots_reward.py index 61ea7789ed..e4bdf4b98e 100644 --- a/examples/bots/workflow/bots_reward.py +++ b/examples/bots/workflow/bots_reward.py @@ -1,8 +1,10 @@ # Adapted from Reasoning360: https://github.com/LLM360/Reasoning360/blob/main/verl/utils/reward_score/naive_dapo.py +import concurrent import contextlib import math import re +import resource from math import isclose from typing import Optional, Union @@ -585,17 +587,25 @@ def should_allow_eval(expr: str): # @timeout(timeout_seconds=10) def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str): - are_equal = False - try: + def check_equal(): + memory_size = 1024**3 + resource.setrlimit(resource.RLIMIT_AS, (memory_size, memory_size)) + expr = f"({ground_truth_normalized})-({given_normalized})" if should_allow_eval(expr): sympy_diff = _sympy_parse(expr) simplified = sympy.simplify(sympy_diff) if simplified == 0: - are_equal = True - except Exception: - pass - return are_equal + return True + return False + + with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: + future = executor.submit(check_equal) + try: + return future.result(timeout=10) + except (concurrent.futures.TimeoutError, Exception): + future.cancel() + return False def split_tuple(expr: str): diff --git a/trinity/buffer/pipelines/task_pipeline.py b/trinity/buffer/pipelines/task_pipeline.py index a9b0017117..9107293902 100644 --- a/trinity/buffer/pipelines/task_pipeline.py +++ b/trinity/buffer/pipelines/task_pipeline.py @@ -5,6 +5,8 @@ def check_and_run_task_pipeline(config: Config) -> Dict: + if config.mode not in {"explore", "train", "both"}: + return {} if config.data_processor.task_pipeline is None: return {} diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index d2d4254894..468ab2df53 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -37,7 +37,6 @@ def bench(config: Config) -> None: def explore(config: Config) -> None: """Run explorer.""" - check_and_run_task_pipeline(config) try: explorer = Explorer.get_actor(config) ray.get(explorer.prepare.remote()) @@ -82,7 +81,6 @@ def both(config: Config) -> None: the latest step. The specific number of experiences may vary for different algorithms and tasks. """ - check_and_run_task_pipeline(config) try: explorer = Explorer.get_actor(config) trainer = Trainer.get_actor(config) @@ -153,6 +151,7 @@ def run_stage(config: Config) -> None: ) pprint(config) try: + check_and_run_task_pipeline(config) MODE_MAP[config.mode](config) finally: if config.monitor.enable_ray_timeline: diff --git a/trinity/common/config.py b/trinity/common/config.py index 555ef4c3a7..65a2148593 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -853,8 +853,8 @@ def _check_interval(self) -> None: ) def _check_explorer_input(self) -> None: - if self.mode in {"train", "bench", "serve"}: - # no need to check explorer_input in train/bench/serve mode + if self.mode == "serve": + # no need to check explorer_input in serve mode return explorer_input = self.buffer.explorer_input @@ -864,7 +864,7 @@ def _check_explorer_input(self) -> None: raise ValueError("Do not support setting `taskset` and `tasksets` simultaneously!") explorer_input.tasksets = [explorer_input.taskset] explorer_input.taskset = None - elif len(explorer_input.tasksets) == 0: + elif self.mode not in {"bench", "train"} and len(explorer_input.tasksets) == 0: raise ValueError("At least one taskset should be provided in explorer_input!") for i, taskset in enumerate(explorer_input.tasksets): @@ -913,8 +913,8 @@ def _check_explorer_input(self) -> None: set_if_none(dataset.rollout_args, "max_tokens", self.model.max_response_tokens) def _check_trainer_input(self) -> None: - if self.mode in {"explore", "bench", "serve"}: - # no need to check trainer_input in explore/bench/serve mode + if self.mode in {"bench", "serve"}: + # no need to check trainer_input in bench/serve mode return trainer_input = self.buffer.trainer_input @@ -986,10 +986,15 @@ def _check_data_processor(self) -> None: ) task_pipeline = self.data_processor.task_pipeline - if task_pipeline is not None and self.mode in {"explore", "both"}: + if task_pipeline is not None and self.mode in {"explore", "train", "both"}: if task_pipeline.output is None: if self.mode != "train": - task_pipeline.output = self.buffer.explorer_input.tasksets[0] + if len(self.buffer.explorer_input.tasksets) > 0: + task_pipeline.output = self.buffer.explorer_input.tasksets[0] + else: + raise ValueError( + "At least one taskset should be provided in explorer_input!" + ) elif self.mode == "train" and self.algorithm.algorithm_type in {"dpo", "sft"}: task_pipeline.output = self.buffer.trainer_input.experience_buffer else: From 7d6096141014b744a410445b7e13c92f380730e9 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Thu, 20 Nov 2025 12:29:15 +0800 Subject: [PATCH 8/8] fix unittest --- tests/trainer/trainer_test.py | 3 +++ trinity/common/config.py | 10 +++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index f87c16429b..07f35b6219 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -140,6 +140,9 @@ def test_trainer(self): self.config.mode = "bench" self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT self.config.explorer.bench_on_latest_checkpoint = False + self.config.buffer.explorer_input.taskset = None + self.config.buffer.explorer_input.tasksets = [] + self.config.buffer.trainer_input.experience_buffer = None self.config.check_and_update() bench(self.config) parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) diff --git a/trinity/common/config.py b/trinity/common/config.py index 65a2148593..bbf5fe7ba2 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -853,7 +853,7 @@ def _check_interval(self) -> None: ) def _check_explorer_input(self) -> None: - if self.mode == "serve": + if self.mode in {"train", "serve"}: # no need to check explorer_input in serve mode return @@ -864,11 +864,11 @@ def _check_explorer_input(self) -> None: raise ValueError("Do not support setting `taskset` and `tasksets` simultaneously!") explorer_input.tasksets = [explorer_input.taskset] explorer_input.taskset = None - elif self.mode not in {"bench", "train"} and len(explorer_input.tasksets) == 0: + elif self.mode != "bench" and len(explorer_input.tasksets) == 0: raise ValueError("At least one taskset should be provided in explorer_input!") for i, taskset in enumerate(explorer_input.tasksets): - if self.mode != "train" and not taskset.path: + if not taskset.path: raise ValueError( "`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset." ) @@ -913,8 +913,8 @@ def _check_explorer_input(self) -> None: set_if_none(dataset.rollout_args, "max_tokens", self.model.max_response_tokens) def _check_trainer_input(self) -> None: - if self.mode in {"bench", "serve"}: - # no need to check trainer_input in bench/serve mode + if self.mode == "bench": + # no need to check trainer_input in bench mode return trainer_input = self.buffer.trainer_input