diff --git a/README.md b/README.md
index 1f6f5bdd45..09988d6d81 100644
--- a/README.md
+++ b/README.md
@@ -246,7 +246,7 @@ trinity run --config
-For example, below is the command for fine-tuning Qwen-2.5-1B-Instruct on GSM8k dataset using GRPO algorithm:
+For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:
 
 ```shell
 trinity run --config examples/grpo_gsm8k/gsm8k.yaml
diff --git a/docs/sphinx_doc/source/main.md b/docs/sphinx_doc/source/main.md
index c277e9b116..a7e6684219 100644
--- a/docs/sphinx_doc/source/main.md
+++ b/docs/sphinx_doc/source/main.md
@@ -226,7 +226,7 @@ trinity run --config
-For example, below is the command for fine-tuning Qwen-2.5-1B-Instruct on GSM8k dataset using GRPO algorithm:
+For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:
 
 ```shell
 trinity run --config examples/grpo_gsm8k/gsm8k.yaml
diff --git a/docs/sphinx_doc/source/tutorial/example_dpo.md b/docs/sphinx_doc/source/tutorial/example_dpo.md
index 448cfd67fe..26451ab982 100644
--- a/docs/sphinx_doc/source/tutorial/example_dpo.md
+++ b/docs/sphinx_doc/source/tutorial/example_dpo.md
@@ -6,7 +6,7 @@ This example describes DPO based on the Qwen-2.5-1.5B-Instruct model and [Human-
 
 ### Model Preparation
 
-Download the Qwen-2.5-1B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:
+Download the Qwen-2.5-1.5B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:
 
 ```shell
 # Using Modelscope
diff --git a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
index 6893528ad4..884efd6d21 100644
--- a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
+++ b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
@@ -7,7 +7,7 @@ This example shows how to run RFT with the Qwen-2.5-1.5B-Instruct model and GSM8
 
 **Model Preparation.**
 
-Download the Qwen-2.5-1B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:
+Download the Qwen-2.5-1.5B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:
 
 ```bash
 # Using Modelscope
diff --git a/pyproject.toml b/pyproject.toml
index 3cb55daafb..717a88a00c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,6 +34,7 @@ dependencies = [
     "fire",
     "flask",
     "requests",
+    "tensorboard",
 ]
 
 [project.scripts]
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index f2e25cee31..dbb9565b30 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -17,7 +17,7 @@ class FileReader(BufferReader):
-    """Reader of the Queue buffer."""
+    """Reader of the File buffer."""
 
     def __init__(self, meta: DatasetConfig, config: BufferConfig) -> None:
         assert meta.storage_type == StorageType.FILE
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index ad65ab1223..59e0387172 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -102,9 +102,8 @@ def both(config: Config) -> None:
             logger.error(e)
             logger.error("Evaluation failed.")
             raise e
-
-    ray.get(explorer.log_finalize.remote(step=explore_iter_num))
-    ray.get(trainer.log_finalize.remote(step=train_iter_num))
+    ray.get(explorer.flush_log.remote(step=explore_iter_num))
+    ray.get(trainer.flush_log.remote(step=train_iter_num))
 
 
 def activate_data_module(data_workflow_url: str, config_path: str):
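The `log_finalize` to `flush_log` rename above also drops the placeholder keys (`dummy_log_explorer`, `dummy_log_trainer`) from the run history. A minimal sketch of the underlying pattern, assuming a wandb-style monitor; the project name and metric keys are illustrative, not part of Trinity:

```python
import wandb

run = wandb.init(project="trinity-demo")  # illustrative project name

# Metrics can accumulate for a step across several calls without committing.
wandb.log({"explore/reward_mean": 0.42}, step=10, commit=False)
wandb.log({"explore/kl": 0.01}, step=10, commit=False)

# An empty payload with commit=True flushes whatever is buffered for the step,
# so no throwaway key like "dummy_log_explorer" is written to the history.
wandb.log({}, step=10, commit=True)

run.finish()
```
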
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 1e0b0c2509..46551e55d3 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -173,8 +173,7 @@ class ExplorerConfig:
 @dataclass
 class TrainerConfig:
     trainer_type: str = "verl"
-    trainer_data_type: str = "RFT"
-    trainer_config_path: str = "examples/ppo_countdown/train_countdown.yaml"
+    trainer_config_path: str = ""
     eval_interval: int = 100
     enable_preview: bool = True  # enable rollout preview in wandb
     trainer_config: Any = None
@@ -185,16 +184,6 @@ class TrainerConfig:
     # warmup config
     sft_warmup_iteration: int = 0
 
-    def __post_init__(self):
-        if self.trainer_type == "verl":
-            from trinity.common.verl_config import load_config
-
-            if not os.path.isfile(self.trainer_config_path):
-                raise ValueError(f"Invalid trainer config path: {self.trainer_config_path}")
-            self.trainer_config = load_config(self.trainer_config_path)
-        else:
-            raise ValueError(f"Invalid trainer type: {self.trainer_type}")
-
 
 @dataclass
 class MonitorConfig:
@@ -285,6 +274,15 @@ def _check_buffer(self) -> None:
 
     def check_and_update(self) -> None:
         """Check and update the config."""
+        if self.trainer.trainer_type == "verl":
+            from trinity.common.verl_config import load_config
+
+            if not os.path.isfile(self.trainer.trainer_config_path):
+                raise ValueError(f"Invalid trainer config path: {self.trainer.trainer_config_path}")
+            self.trainer.trainer_config = load_config(self.trainer.trainer_config_path)
+        else:
+            raise ValueError(f"Invalid trainer type: {self.trainer.trainer_type}")
+
         # check mode
         if self.mode not in ["explore", "train", "both"]:
             raise ValueError(f"Invalid mode: {self.mode}")
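Moving the verl check out of `TrainerConfig.__post_init__` and into `Config.check_and_update` defers filesystem validation until the whole config is assembled, so constructing a default `TrainerConfig` (now with an empty `trainer_config_path`) no longer raises. A toy sketch of the resulting behavior; the field set is abridged and not the full Trinity definitions:

```python
import os
from dataclasses import dataclass, field


@dataclass
class TrainerConfig:
    trainer_type: str = "verl"
    trainer_config_path: str = ""  # may be filled in later, e.g. from the CLI


@dataclass
class Config:
    trainer: TrainerConfig = field(default_factory=TrainerConfig)

    def check_and_update(self) -> None:
        # Validation now runs once, after every field is final.
        if self.trainer.trainer_type == "verl":
            if not os.path.isfile(self.trainer.trainer_config_path):
                raise ValueError(
                    f"Invalid trainer config path: {self.trainer.trainer_config_path}"
                )
        else:
            raise ValueError(f"Invalid trainer type: {self.trainer.trainer_type}")


cfg = Config()  # with __post_init__ validation this would have raised already
try:
    cfg.check_and_update()  # the empty path fails here, at an explicit call site
except ValueError as err:
    print(err)
```
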
Skip evaluation.") + return True self.logger.info("Evaluation started.") st = time.time() all_metrics = defaultdict(list) @@ -248,6 +251,6 @@ def sync_weight(self) -> None: else: # online weights update self._online_weights_update() - def log_finalize(self, step: int) -> None: - """Commit the logging results to wandb""" - self.monitor.log({"dummy_log_explorer": step}, step=step, commit=True) + def flush_log(self, step: int) -> None: + """Flush the log of the current step.""" + self.monitor.log({}, step=step, commit=True) diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index f35c55af0a..925baa9400 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -105,9 +105,9 @@ def sync_weight(self) -> None: if self.config.synchronizer.sync_method == "online": self.engine.sync_weight() - def log_finalize(self, step: int) -> None: - """Commit the logging results to wandb""" - self.engine.logger.log({"dummy_log_trainer": step}, step=step, commit=True) + def flush_log(self, step: int) -> None: + """Flush the log of the current step.""" + self.engine.logger.log({}, step=step, commit=True) class TrainEngineWrapper(ABC): diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 89f649d3b6..ea7c632b35 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -306,7 +306,8 @@ def train_sft_iteration(self, experiences: Experiences) -> Tuple[bool, int]: * self.config.trainer.sft_warmup_iteration ): self.logger.log( - data={"sft_warmup_iteration": self.sft_iter_num}, step=self.global_steps + data={"sft_warmup_iteration": self.sft_iter_num}, + step=self.global_steps, ) with _timer("save_checkpoint", timing_raw): self._save_checkpoint() @@ -443,11 +444,12 @@ def train_rft_iteration(self, experiences: Experiences) -> Tuple[bool, int]: compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus) ) - # TODO: make a canonical logger that supports various backend - self.logger.log(data=metrics, step=self.global_steps) if self.config.enable_preview: self._log_experiences(experiences) + # TODO: make a canonical logger that supports various backend + self.logger.log(data=metrics, step=self.global_steps) + self.global_steps += 1 if self.global_steps >= self.total_training_steps: