diff --git a/examples/bots/workflow/bots_reward.py b/examples/bots/workflow/bots_reward.py
index 61ea7789ed..e4bdf4b98e 100644
--- a/examples/bots/workflow/bots_reward.py
+++ b/examples/bots/workflow/bots_reward.py
@@ -1,8 +1,10 @@
 # Adapted from Reasoning360: https://github.com/LLM360/Reasoning360/blob/main/verl/utils/reward_score/naive_dapo.py
+import concurrent.futures
 import contextlib
 import math
 import re
+import resource
 from math import isclose
 from typing import Optional, Union

@@ -585,17 +587,25 @@ def should_allow_eval(expr: str):


 # @timeout(timeout_seconds=10)
 def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
-    are_equal = False
-    try:
+    def check_equal():
+        memory_size = 1024**3
+        resource.setrlimit(resource.RLIMIT_AS, (memory_size, memory_size))
+        expr = f"({ground_truth_normalized})-({given_normalized})"
         if should_allow_eval(expr):
             sympy_diff = _sympy_parse(expr)
             simplified = sympy.simplify(sympy_diff)
             if simplified == 0:
-                are_equal = True
-    except Exception:
-        pass
-    return are_equal
+                return True
+        return False
+
+    with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
+        future = executor.submit(check_equal)
+        try:
+            return future.result(timeout=10)
+        except (concurrent.futures.TimeoutError, Exception):
+            future.cancel()
+            return False


 def split_tuple(expr: str):
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index a74dfeb7dc..07f35b6219 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -113,6 +113,8 @@ def test_trainer(self):
         self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8)
         actor_kl_metrics = parser.metric_list("actor/kl")
         self.assertTrue(len(actor_kl_metrics) > 0)
+        actor_kl_loss = parser.metric_values("actor/kl_loss")
+        self.assertEqual(actor_kl_loss[0], 0.0)
         critic_kl_metrics = parser.metric_list("critic/kl")
         self.assertTrue(len(critic_kl_metrics) > 0)
         response_metrics = parser.metric_list("response_length")
@@ -138,6 +140,9 @@ def test_trainer(self):
         self.config.mode = "bench"
         self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
         self.config.explorer.bench_on_latest_checkpoint = False
+        self.config.buffer.explorer_input.taskset = None
+        self.config.buffer.explorer_input.tasksets = []
+        self.config.buffer.trainer_input.experience_buffer = None
         self.config.check_and_update()
         bench(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
diff --git a/trinity/buffer/pipelines/task_pipeline.py b/trinity/buffer/pipelines/task_pipeline.py
index a9b0017117..9107293902 100644
--- a/trinity/buffer/pipelines/task_pipeline.py
+++ b/trinity/buffer/pipelines/task_pipeline.py
@@ -5,6 +5,8 @@


 def check_and_run_task_pipeline(config: Config) -> Dict:
+    if config.mode not in {"explore", "train", "both"}:
+        return {}
     if config.data_processor.task_pipeline is None:
         return {}

diff --git a/trinity/common/config.py b/trinity/common/config.py
index 3f2c44eb53..bbf5fe7ba2 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -853,8 +853,8 @@ def _check_interval(self) -> None:
            )

     def _check_explorer_input(self) -> None:
-        if self.mode == "train":
-            # no need to check explorer_input in train mode
+        if self.mode in {"train", "serve"}:
+            # no need to check explorer_input in train/serve mode
             return

         explorer_input = self.buffer.explorer_input
@@ -864,12 +864,11 @@
                 raise ValueError("Do not support setting `taskset` and `tasksets` simultaneously!")
             explorer_input.tasksets = [explorer_input.taskset]
             explorer_input.taskset = None
-        elif len(explorer_input.tasksets) == 0:
+        elif self.mode != "bench" and len(explorer_input.tasksets) == 0:
             raise ValueError("At least one taskset should be provided in explorer_input!")

-        tasksets = explorer_input.tasksets
-        for i, taskset in enumerate(tasksets):
-            if self.mode != "train" and not taskset.path:
+        for i, taskset in enumerate(explorer_input.tasksets):
+            if not taskset.path:
                 raise ValueError(
                     "`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset."
                 )
@@ -914,6 +913,10 @@ def _check_explorer_input(self) -> None:
             set_if_none(dataset.rollout_args, "max_tokens", self.model.max_response_tokens)

     def _check_trainer_input(self) -> None:
+        if self.mode == "bench":
+            # no need to check trainer_input in bench mode
+            return
+
         trainer_input = self.buffer.trainer_input

         experience_buffer = trainer_input.experience_buffer
@@ -973,7 +976,7 @@ def _default_storage_path(self, storage_type: StorageType, name: str) -> str:
     def _check_data_processor(self) -> None:
         # check input/output buffers in pipelines
         experience_pipeline = self.data_processor.experience_pipeline
-        if experience_pipeline is not None:
+        if experience_pipeline is not None and self.mode in {"explore", "both", "serve"}:
             if experience_pipeline.save_input and experience_pipeline.input_save_path is None:
                 experience_pipeline.input_save_path = os.path.join(
                     self.buffer.cache_dir, "explorer_output.jsonl"  # type: ignore[arg-type]
@@ -983,10 +986,15 @@
                 )

         task_pipeline = self.data_processor.task_pipeline
-        if task_pipeline is not None:
+        if task_pipeline is not None and self.mode in {"explore", "train", "both"}:
             if task_pipeline.output is None:
                 if self.mode != "train":
-                    task_pipeline.output = self.buffer.explorer_input.tasksets[0]
+                    if len(self.buffer.explorer_input.tasksets) > 0:
+                        task_pipeline.output = self.buffer.explorer_input.tasksets[0]
+                    else:
+                        raise ValueError(
+                            "At least one taskset should be provided in explorer_input!"
+                        )
                 elif self.mode == "train" and self.algorithm.algorithm_type in {"dpo", "sft"}:
                     task_pipeline.output = self.buffer.trainer_input.experience_buffer
                 else:
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index eba8739026..eddcbba727 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -85,6 +85,7 @@ class FSDPConfig:
     wrap_policy: WrapPolicy = field(default_factory=WrapPolicy)
     fsdp_size: int = -1
     forward_prefetch: bool = False
+    model_dtype: Optional[str] = None


 @dataclass
@@ -163,8 +164,6 @@ class Actor:
     clip_ratio_high: Optional[float] = None
     entropy_coeff: float = 0.001
     use_kl_loss: bool = False
-    kl_loss_coef: float = 0.0
-    kl_loss_type: str = "low_var_kl"


 @dataclass
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 038c1dd5f9..6166e685fa 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -52,7 +52,9 @@ def __init__(self, config: Config):
         self.models, self.auxiliary_models = create_inference_models(config)
         self.experience_pipeline = self._init_experience_pipeline()
         self.taskset = (
-            TasksetScheduler(explorer_state, config) if self.config.mode != "serve" else None
+            TasksetScheduler(explorer_state, config)
+            if self.config.mode not in {"bench", "serve"}
+            else None
         )
         self.scheduler = None
         self.monitor = MONITOR.get(self.config.monitor.monitor_type)(
@@ -151,7 +153,8 @@ async def prepare(self) -> None:
         """Preparation before running."""
         try:
             # prepare experience pipeline
-            await self.experience_pipeline.prepare.remote()
+            if self.experience_pipeline:
+                await self.experience_pipeline.prepare.remote()
             self.logger.info("Experience pipeline is ready.")
             # make sure all rollout models are ready
             run_api_ref = [model.run_api_server.remote() for model in self.models]
@@ -406,6 +409,8 @@ async def is_alive(self) -> bool:

     def _init_experience_pipeline(self) -> ray.actor.ActorHandle:
         """Init experience pipeline for the explorer."""
+        if self.config.mode == "bench":
+            return None
         node_id = ray.get_runtime_context().get_node_id()
         return (
             ray.remote(ExperiencePipeline)
diff --git a/trinity/manager/synchronizer.py b/trinity/manager/synchronizer.py
index 94d11ae7f8..e5adb86f1c 100644
--- a/trinity/manager/synchronizer.py
+++ b/trinity/manager/synchronizer.py
@@ -44,7 +44,10 @@ def __init__(self, config: Config, module_ref: ray.actor.ActorHandle):
         self._modules = {module_ref}
         self._modules_lock = asyncio.Lock()
         asyncio.create_task(self._check_modules())
-        if self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT:
+        if (
+            self.config.mode != "bench"
+            and self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT
+        ):
             asyncio.create_task(self._find_latest_state_dict())

     async def add_module(self, module_ref: ray.actor.ActorHandle) -> None:
diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py
index fedc54bb55..3e508c9124 100644
--- a/trinity/trainer/verl/fsdp_workers.py
+++ b/trinity/trainer/verl/fsdp_workers.py
@@ -168,7 +168,10 @@ def __init__(self, config: DictConfig, role: str):
                     self.config.actor.ppo_micro_batch_size
                 )

-            if self.config.actor.ppo_micro_batch_size_per_gpu is not None:
+            if (
+                not self.config.actor.use_dynamic_bsz
+                and self.config.actor.ppo_micro_batch_size_per_gpu is not None
+            ):
                 assert (
                     self.config.actor.ppo_mini_batch_size
                     % self.config.actor.ppo_micro_batch_size_per_gpu
@@ -181,7 +184,11 @@ def __init__(self, config: DictConfig, role: str):
                 ), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"

         # normalize ref config
-        if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
+        if (
+            self._is_ref
+            and not self.config.ref.log_prob_use_dynamic_bsz
+            and self.config.ref.log_prob_micro_batch_size is not None
+        ):
             self.config.ref.log_prob_micro_batch_size //= (
                 self.device_mesh.size() // self.ulysses_sequence_parallel_size
             )
@@ -246,7 +253,7 @@ def _build_model_optimizer(  # noqa: C901
             else:
                 self.tokenizer.chat_template = self.config.model.custom_chat_template

-        torch_dtype = fsdp_config.get("model_dtype", None)
+        torch_dtype = fsdp_config.model_dtype
         if torch_dtype is None:
             torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
         else:
@@ -326,9 +333,6 @@ def _build_model_optimizer(  # noqa: C901
                 fused_kernels_backend=fused_kernels_backend,
             )

-            # some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2
-            actor_module.to(torch_dtype)
-
            if enable_gradient_checkpointing:
                 actor_module.gradient_checkpointing_enable(
                     gradient_checkpointing_kwargs={"use_reentrant": False}
@@ -971,7 +975,7 @@ def __init__(self, config):
             self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
             self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size

-        if self.config.ppo_micro_batch_size_per_gpu is not None:
+        if not self.config.use_dynamic_bsz and self.config.ppo_micro_batch_size_per_gpu is not None:
             assert (
                 self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0
             ), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
@@ -1020,7 +1024,7 @@ def _build_critic_model_optimizer(self, config):  # noqa: C901
         if self.rank == 0:
             print(f"Critic overriding config {override_config_kwargs}")

-        torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32")
+        torch_dtype = self.config.model.fsdp_config.model_dtype or "fp32"
         torch_dtype = PrecisionType.to_dtype(torch_dtype)

         from transformers import AutoConfig
@@ -1060,9 +1064,6 @@ def _build_critic_model_optimizer(self, config):  # noqa: C901
                 ulysses_sp_size=self.ulysses_sequence_parallel_size,
             )

-            # some parameters may not in torch_dtype
-            critic_module.to(torch_dtype)
-
             if config.model.get("enable_gradient_checkpointing", False):
                 critic_module.gradient_checkpointing_enable(
                     gradient_checkpointing_kwargs={"use_reentrant": False}
diff --git a/trinity/trainer/verl/megatron_checkpoint_manager.py b/trinity/trainer/verl/megatron_checkpoint_manager.py
index b65b943782..f24c8aa4ef 100644
--- a/trinity/trainer/verl/megatron_checkpoint_manager.py
+++ b/trinity/trainer/verl/megatron_checkpoint_manager.py
@@ -233,44 +233,50 @@ def save_checkpoint(  # noqa: C901
                 json.dump(transformer_config_dict, f, indent=2)

         if self.should_save_hf_model or save_as_hf:
-            # wait for everyone to dump to local
-            state_dict = self.weight_saver(
-                self.model,
-                self.hf_config,
-                dtype=self.param_dtype,
-                is_value_model=self.is_value_model,
-                tie_word_embeddings=self.share_embeddings_and_output_weights,
-            )
+            try:
+                # wait for everyone to dump to local
+                state_dict = self.weight_saver(
+                    self.model,
+                    self.hf_config,
+                    dtype=self.param_dtype,
+                    is_value_model=self.is_value_model,
+                    tie_word_embeddings=self.share_embeddings_and_output_weights,
+                )

-            torch.distributed.barrier()
-            if self.rank == 0:
-                # TODO: async save or use mbridge to save hf model
-                hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
-                import warnings
+                torch.distributed.barrier()
+                if self.rank == 0:
+                    # TODO: async save or use mbridge to save hf model
+                    hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
+                    import warnings

-                from accelerate import init_empty_weights
+                    from accelerate import init_empty_weights

-                with init_empty_weights(), warnings.catch_warnings():
-                    warnings.simplefilter("ignore")
-                    if "mistral7b-rm" in self.config.model.path:
-                        from transformers import MistralForSequenceClassification
+                    with init_empty_weights(), warnings.catch_warnings():
+                        warnings.simplefilter("ignore")
+                        if "mistral7b-rm" in self.config.model.path:
+                            from transformers import MistralForSequenceClassification

-                        model = MistralForSequenceClassification.from_pretrained(
-                            self.config.model.path
-                        )  # use score head instead of lm_head
-                        state_dict["score.weight"] = state_dict["score.weight"]
-                    else:
-                        from transformers import AutoModelForCausalLM
+                            model = MistralForSequenceClassification.from_pretrained(
+                                self.config.model.path
+                            )  # use score head instead of lm_head
+                            state_dict["score.weight"] = state_dict["score.weight"]
+                        else:
+                            from transformers import AutoModelForCausalLM

-                        model = AutoModelForCausalLM.from_pretrained(
-                            self.config.model.path, torch_dtype="auto"
-                        )
-                model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
-                log_with_rank(
-                    f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}",
-                    rank=self.rank,
-                    logger=logger,
-                    log_only_rank_0=True,
+                            model = AutoModelForCausalLM.from_pretrained(
+                                self.config.model.path, torch_dtype="auto"
+                            )
+                    model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
+                    log_with_rank(
+                        f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}",
+                        rank=self.rank,
+                        logger=logger,
+                        log_only_rank_0=True,
+                    )
+            except Exception:
+                logger.error(
+                    f"Failed to save Huggingface model to {local_path}, you can try to set `use_mbridge=true` to save it.",
+                    exc_info=True,
                 )

         ray.get(
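
For context only (not part of the patch above): a minimal standalone sketch of the subprocess-based, memory- and time-bounded sympy equality check that `are_equal_under_sympy` now uses. The helper names (`_bounded_equal`, `equal_with_timeout`) are illustrative, and the worker is defined at module level on the assumption that `concurrent.futures.ProcessPoolExecutor` must pickle the callable it runs.

```python
# Sketch, assuming Linux (the `resource` module) and sympy installed.
import concurrent.futures
import resource

import sympy
from sympy.parsing.sympy_parser import parse_expr


def _bounded_equal(ground_truth: str, given: str) -> bool:
    # Cap the worker's address space at 1 GiB so a pathological expression
    # cannot exhaust host memory; the limit applies only to this subprocess.
    memory_size = 1024**3
    resource.setrlimit(resource.RLIMIT_AS, (memory_size, memory_size))
    diff = parse_expr(f"({ground_truth})-({given})")
    return sympy.simplify(diff) == 0


def equal_with_timeout(ground_truth: str, given: str, timeout: float = 10.0) -> bool:
    # Run the check in a single-worker process pool and bound the wait.
    with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
        future = executor.submit(_bounded_equal, ground_truth, given)
        try:
            return future.result(timeout=timeout)
        except Exception:  # timeout, parse failure, or memory limit hit
            return False


if __name__ == "__main__":
    print(equal_with_timeout("1/2", "0.5"))  # True
```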