diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index 27f95c1383..1aaebe3aeb 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -23,6 +23,7 @@ class ActorModel:
     override_config: Dict[str, Any] = field(default_factory=dict)
     enable_gradient_checkpointing: bool = True
     use_remove_padding: bool = False
+    use_fused_kernels: bool = False
 
 
 @dataclass
diff --git a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py
index 39171d5561..7c0db027a0 100644
--- a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py
+++ b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py
@@ -6,7 +6,7 @@ from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow, Task
 
 EXAMPLE_PROMPT = """
-Observation: 
+Observation:
 -= Welcome to TextWorld, ALFRED! =-
 
 You are in the middle of a room. Looking quickly around you, you see a cabinet 4, a cabinet 3, a cabinet 2, a cabinet 1, a countertop 1, a garbagecan 1, a handtowelholder 2, a handtowelholder 1, a sinkbasin 2, a sinkbasin 1, a toilet 1, a toiletpaperhanger 1, and a towelholder 1.
 
@@ -88,7 +88,7 @@ def parse_action(response):
         action = response.split("<action>")[1].split("</action>")[0].strip()
         return action
     except Exception as e:
-        print("Error parsing action:", e)
+        print(f"Error parsing action: {e}, response = {response}")
         return ""
 
 
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 1f2f34c8ac..d7e07406a2 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -182,6 +182,7 @@ async def explore(self) -> str:
         self.eval_explore_step_num = None
         while True:
             try:
+                self.logger.info(f"Explore step {self.explore_step_num + 1} started.")
                 if (
                     self.eval_explore_step_num is None
                     and self.explore_step_num % self.config.explorer.eval_interval == 0
diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py
index de4305a9cc..b9ba995985 100644
--- a/trinity/manager/config_manager.py
+++ b/trinity/manager/config_manager.py
@@ -64,7 +64,7 @@ def maintain_session_state(self):
 
         def maintain_list_state(prefix, key_list):
             last_idx, del_num = 0, 0
-            for idx in range(st.session_state[f"_{prefix}_num"]):
+            for idx in range(st.session_state[f"_{prefix}s_num"]):
                 if st.session_state.get(f"{prefix}_{idx}_del_flag", False):
                     del_num += 1
                     continue
@@ -73,7 +73,7 @@ def maintain_list_state(prefix, key_list):
                     last_full_key = f"{prefix}_{last_idx}_{key}"
                     st.session_state[last_full_key] = st.session_state[full_key]
                 last_idx += 1
-            st.session_state[f"_{prefix}_num"] -= del_num
+            st.session_state[f"_{prefix}s_num"] -= del_num
 
         self.eval_dataset_keys = [
             "name",
@@ -86,7 +86,7 @@ def maintain_list_state(prefix, key_list):
             "logprobs",
             "n",
         ]
-        maintain_list_state("eval_tasksets", self.eval_dataset_keys)
+        maintain_list_state("eval_taskset", self.eval_dataset_keys)
 
         self.inference_model_keys = [
             "model_path",
@@ -103,7 +103,7 @@ def maintain_list_state(prefix, key_list):
             "enable_thinking",
             "enable_openai_api",
         ]
-        maintain_list_state("auxiliary_models", self.inference_model_keys)
+        maintain_list_state("auxiliary_model", self.inference_model_keys)
 
     def get_configs(self, *config_names: str, columns_spec: List[int] = None):
         CONFIG_GENERATORS.get_configs(*config_names, columns_spec=columns_spec)
@@ -356,7 +356,6 @@ def _generate_verl_config(self):
                 ],
                 "use_dynamic_bsz": use_dynamic_bsz,
                 "ppo_max_token_len_per_gpu": ppo_max_token_len_per_gpu,
-                "kl_loss_type": st.session_state["actor_kl_loss_type"],
"ppo_epochs": st.session_state["ppo_epochs"], "shuffle": False, "ulysses_sequence_parallel_size": st.session_state[ diff --git a/trinity/manager/config_registry/trainer_config_manager.py b/trinity/manager/config_registry/trainer_config_manager.py index 9b3e5f3ea9..95630c3305 100644 --- a/trinity/manager/config_registry/trainer_config_manager.py +++ b/trinity/manager/config_registry/trainer_config_manager.py @@ -265,15 +265,6 @@ def set_actor_lr_warmup_steps_ratio(**kwargs): ) -@CONFIG_GENERATORS.register_config(default_value="low_var_kl") -def set_actor_kl_loss_type(**kwargs): - st.selectbox( - "KL Loss Type", - ["kl", "abs", "mse", "low_var_kl"], - **kwargs, - ) - - @CONFIG_GENERATORS.register_config(default_value=["model", "hf_model", "optimizer", "extra"]) def set_actor_checkpoint(**kwargs): st.multiselect( diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index 91f681e47f..646bfdac4b 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -57,13 +57,18 @@ def need_sync(self) -> bool: def sync_weight(self) -> None: """Sync the model weight.""" if self.config.synchronizer.sync_method == SyncMethod.NCCL: + self.logger.info( + f"Trainer synchronizing weights at step {self.engine.train_step_num} starting.." + ) if self.explorer_ref is None: self.explorer_ref = ray.get_actor(self.config.explorer.name) explorer_status = ray.get(self.explorer_ref.running_status.remote()) if explorer_status == RunningStatus.STOPPED: self.logger.warning("Explorer has already stopped. Skipping sync weight.") return - self.logger.info(f"Trainer synchronizing weights at step {self.engine.train_step_num}.") + self.logger.info( + f"Trainer synchronizing weights at step {self.engine.train_step_num} end." + ) self.engine.sync_weight() def flush_log(self, step: int) -> None: diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 7c789a98d2..e59914203a 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -38,6 +38,7 @@ from trinity.common.config import Config from trinity.common.experience import Experiences from trinity.trainer.trainer import TrainEngineWrapper +from trinity.utils.log import get_logger from trinity.utils.monitor import MONITOR @@ -146,13 +147,14 @@ def __init__( ray_worker_group_cls, ) self.init_workers() - self.logger = MONITOR.get(global_config.monitor.monitor_type)( + self.monitor = MONITOR.get(global_config.monitor.monitor_type)( project=config.trainer.project_name, name=config.trainer.experiment_name, role=global_config.trainer.name, config=global_config, ) self.reset_experiences_example_table() + self.logger = get_logger(__name__) def _validate_config(self): # TODO algorithm = ALGORITHM_TYPE.get(self.algorithm_config.algorithm_type) @@ -276,7 +278,7 @@ def prepare(self): if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): val_metrics = self._validate() pprint(f"Initial validation metrics: {val_metrics}") - self.logger.log(data=val_metrics, step=self.global_steps) + self.monitor.log(data=val_metrics, step=self.global_steps) if self.config.trainer.get("val_only", False): return @@ -286,6 +288,7 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize def train_step(self) -> bool: # noqa C901 + self.logger.info(f"Training at step {self.global_steps + 1} started.") metrics = {} try: batch, sample_metrics, exp_samples = 
@@ -294,6 +297,7 @@ def train_step(self) -> bool:  # noqa C901
             print("No more data to train. Stop training.")
             return False
         self.global_steps += 1
+        self.logger.info(f"Sampling at step {self.global_steps} done.")
         timing_raw = {}
         algorithm_config = self.algorithm_manager.get_current_algorithm_config(self.global_steps)
         algorithm = ALGORITHM_TYPE.get(algorithm_config.algorithm_type)
@@ -356,8 +360,10 @@ def train_step(self) -> bool:  # noqa C901
                 self.config.trainer.save_freq > 0
                 and self.global_steps % self.config.trainer.save_freq == 0
             ):
+                self.logger.info(f"Saving at step {self.global_steps}.")
                 with _timer("save_checkpoint", timing_raw):
                     self._save_checkpoint()
+                self.logger.info(f"Saved at step {self.global_steps}.")
 
             # collect metrics
             if self.algorithm.use_advantage:  # TODO
@@ -372,7 +378,7 @@ def train_step(self) -> bool:  # noqa C901
             self._log_experiences(exp_samples)
 
             # TODO: make a canonical logger that supports various backend
-            self.logger.log(data=metrics, step=self.global_steps)
+            self.monitor.log(data=metrics, step=self.global_steps)
 
             train_status = self.global_steps < self.total_training_steps
             if not train_status or self.algorithm_manager.need_save(self.global_steps):
@@ -380,8 +386,11 @@ def train_step(self) -> bool:  # noqa C901
                     self.config.trainer.save_freq == 0
                     or self.global_steps % self.config.trainer.save_freq != 0
                 ):
+                    self.logger.info(f"Saving at step {self.global_steps}.")
                     with _timer("save_checkpoint", timing_raw):
                         self._save_checkpoint()
+                    self.logger.info(f"Saved at step {self.global_steps}.")
+        self.logger.info(f"Training at step {self.global_steps} finished.")
         return train_status
 
     def _log_single_experience(
@@ -412,7 +421,7 @@ def _log_single_experience(
     def _log_experiences(self, samples: List[Dict]) -> None:
         self.sample_exps_to_log.extend(samples)
         if self.global_steps % self.config.trainer.sync_freq == 0:
-            self.logger.log_table(
+            self.monitor.log_table(
                 "rollout_examples", pd.DataFrame(self.sample_exps_to_log), self.global_steps
             )
             self.reset_experiences_example_table()
diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py
index 965fb7e4df..e83df10b8f 100644
--- a/trinity/utils/monitor.py
+++ b/trinity/utils/monitor.py
@@ -81,6 +81,7 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
         """Log metrics."""
         for key in data:
             self.logger.add_scalar(key, data[key], step)
+        self.console_logger.info(f"Step {step}: {data}")
 
     def close(self) -> None:
         self.logger.close()
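
For reviewers, the intent of the `maintain_list_state` fix in `config_manager.py`: widget keys are registered under the singular prefix plus an index (`eval_taskset_0_name`), while the element count lives under the pluralized prefix (`_eval_tasksets_num`). Passing the plural form as `prefix` made per-item lookups such as `eval_tasksets_0_del_flag` miss every key. Below is a minimal sketch of the fixed convention, with a plain dict standing in for `st.session_state`, made-up example data, and the elided copy loop approximated:

```python
# Sketch only: a plain dict stands in for st.session_state, and the example
# data is made up. Key names mirror the convention used in config_manager.py.
session_state = {
    "_eval_tasksets_num": 2,          # count key: pluralized prefix
    "eval_taskset_0_name": "gsm8k",   # per-item keys: singular prefix + index
    "eval_taskset_0_del_flag": True,  # item 0 is flagged for deletion
    "eval_taskset_1_name": "math",
}


def maintain_list_state(prefix: str, key_list: list) -> None:
    """Compact the indexed items, dropping those flagged for deletion."""
    last_idx, del_num = 0, 0
    # The old code read f"_{prefix}_num" with a plural prefix argument, so
    # per-item keys came out plural too; the fix keeps the prefix singular
    # and pluralizes only the count key.
    for idx in range(session_state[f"_{prefix}s_num"]):
        if session_state.get(f"{prefix}_{idx}_del_flag", False):
            del_num += 1
            continue
        if idx != last_idx:  # approximates the elided copy loop in the diff
            for key in key_list:
                session_state[f"{prefix}_{last_idx}_{key}"] = session_state[
                    f"{prefix}_{idx}_{key}"
                ]
        last_idx += 1
    session_state[f"_{prefix}s_num"] -= del_num


maintain_list_state("eval_taskset", ["name"])
assert session_state["_eval_tasksets_num"] == 1
assert session_state["eval_taskset_0_name"] == "math"
```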
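
Similarly, the `verl_trainer.py` rename untangles two channels that previously shared the name `self.logger`: the metrics backend returned by `MONITOR.get(...)` (now `self.monitor`) and a plain console logger from `trinity.utils.log.get_logger` (now `self.logger`). A minimal sketch of that split follows; `MetricsMonitor` and `TrainerSketch` are illustrative stand-ins, not the project's classes, and only `get_logger` loosely mirrors `trinity.utils.log.get_logger`:

```python
# Sketch only: illustrates the monitor/logger split, not Trinity's actual API.
import logging


def get_logger(name: str) -> logging.Logger:
    # Stand-in for trinity.utils.log.get_logger: a configured console logger.
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger


class MetricsMonitor:
    """Stand-in for a MONITOR backend: records scalar metrics per step."""

    def log(self, data: dict, step: int) -> None:
        for key, value in data.items():
            print(f"[monitor] step={step} {key}={value}")


class TrainerSketch:
    def __init__(self) -> None:
        self.monitor = MetricsMonitor()     # metrics backend (was self.logger)
        self.logger = get_logger(__name__)  # human-readable progress logging

    def train_step(self, step: int) -> None:
        self.logger.info(f"Training at step {step} started.")
        self.monitor.log(data={"actor/loss": 0.1}, step=step)  # metrics, not console
        self.logger.info(f"Training at step {step} finished.")


if __name__ == "__main__":
    TrainerSketch().train_step(1)
```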