2 changes: 1 addition & 1 deletion README.md
@@ -246,7 +246,7 @@ trinity run --config <config_path>



-For example, below is the command for fine-tuning Qwen-2.5-1B-Instruct on GSM8k dataset using GRPO algorithm:
+For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:

```shell
trinity run --config examples/grpo_gsm8k/gsm8k.yaml
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/main.md
@@ -226,7 +226,7 @@ trinity run --config <config_path>



-For example, below is the command for fine-tuning Qwen-2.5-1B-Instruct on GSM8k dataset using GRPO algorithm:
+For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:

```shell
trinity run --config examples/grpo_gsm8k/gsm8k.yaml
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/example_dpo.md
@@ -6,7 +6,7 @@ This example describes DPO based on the Qwen-2.5-1.5B-Instruct model and [Human-

### Model Preparation

-Download the Qwen-2.5-1B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:
+Download the Qwen-2.5-1.5B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:

```shell
# Using Modelscope
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
@@ -7,7 +7,7 @@ This example shows how to run RFT with the Qwen-2.5-1.5B-Instruct model and GSM8

**Model Preparation.**

-Download the Qwen-2.5-1B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:
+Download the Qwen-2.5-1.5B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:

```bash
# Using Modelscope
1 change: 1 addition & 0 deletions pyproject.toml
@@ -34,6 +34,7 @@ dependencies = [
"fire",
"flask",
"requests",
"tensorboard",
]

[project.scripts]
2 changes: 1 addition & 1 deletion trinity/buffer/reader/file_reader.py
@@ -17,7 +17,7 @@


class FileReader(BufferReader):
"""Reader of the Queue buffer."""
"""Reader of the File buffer."""

def __init__(self, meta: DatasetConfig, config: BufferConfig) -> None:
assert meta.storage_type == StorageType.FILE
5 changes: 2 additions & 3 deletions trinity/cli/launcher.py
@@ -102,9 +102,8 @@ def both(config: Config) -> None:
logger.error(e)
logger.error("Evaluation failed.")
raise e

-ray.get(explorer.log_finalize.remote(step=explore_iter_num))
-ray.get(trainer.log_finalize.remote(step=train_iter_num))
+ray.get(explorer.flush_log.remote(step=explore_iter_num))
+ray.get(trainer.flush_log.remote(step=train_iter_num))


def activate_data_module(data_workflow_url: str, config_path: str):
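For context, `explorer` and `trainer` here are Ray actor handles, so `flush_log` is invoked through `.remote()` and the returned future is resolved with `ray.get`. Below is a minimal self-contained sketch of that calling pattern, using a hypothetical stand-in actor rather than Trinity's actual classes:

```python
import ray

ray.init(ignore_reinit_error=True)

@ray.remote
class Worker:  # hypothetical stand-in for the explorer/trainer actors
    def flush_log(self, step: int) -> None:
        print(f"flushing logs for step {step}")

worker = Worker.remote()
# .remote() returns an ObjectRef immediately; ray.get blocks until the call finishes
ray.get(worker.flush_log.remote(step=10))
ray.shutdown()
```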
22 changes: 10 additions & 12 deletions trinity/common/config.py
@@ -173,8 +173,7 @@ class ExplorerConfig:
@dataclass
class TrainerConfig:
trainer_type: str = "verl"
-trainer_data_type: str = "RFT"
-trainer_config_path: str = "examples/ppo_countdown/train_countdown.yaml"
+trainer_config_path: str = ""
eval_interval: int = 100
enable_preview: bool = True # enable rollout preview in wandb
trainer_config: Any = None
@@ -185,16 +184,6 @@ class TrainerConfig:
# warmup config
sft_warmup_iteration: int = 0

-def __post_init__(self):
-if self.trainer_type == "verl":
-from trinity.common.verl_config import load_config
-
-if not os.path.isfile(self.trainer_config_path):
-raise ValueError(f"Invalid trainer config path: {self.trainer_config_path}")
-self.trainer_config = load_config(self.trainer_config_path)
-else:
-raise ValueError(f"Invalid trainer type: {self.trainer_type}")
-

@dataclass
class MonitorConfig:
@@ -285,6 +274,15 @@ def _check_buffer(self) -> None:

def check_and_update(self) -> None:
"""Check and update the config."""
+if self.trainer.trainer_type == "verl":
+from trinity.common.verl_config import load_config
+
+if not os.path.isfile(self.trainer.trainer_config_path):
+raise ValueError(f"Invalid trainer config path: {self.trainer.trainer_config_path}")
+self.trainer.trainer_config = load_config(self.trainer.trainer_config_path)
+else:
+raise ValueError(f"Invalid trainer type: {self.trainer.trainer_type}")
+
# check mode
if self.mode not in ["explore", "train", "both"]:
raise ValueError(f"Invalid mode: {self.mode}")
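The net effect of moving this block is that constructing `TrainerConfig` no longer reads anything from disk; the verl YAML is loaded and validated only when `check_and_update()` is called on the assembled `Config`. Below is a minimal sketch of the same deferred-validation pattern, with simplified stand-in dataclasses rather than Trinity's actual ones:

```python
import os
from dataclasses import dataclass, field

@dataclass
class TrainerCfg:
    trainer_type: str = "verl"
    trainer_config_path: str = ""  # nothing is read or checked at construction time

@dataclass
class Cfg:
    trainer: TrainerCfg = field(default_factory=TrainerCfg)

    def check_and_update(self) -> None:
        # validation is deferred to this explicit call
        if self.trainer.trainer_type != "verl":
            raise ValueError(f"Invalid trainer type: {self.trainer.trainer_type}")
        if not os.path.isfile(self.trainer.trainer_config_path):
            raise ValueError(f"Invalid trainer config path: {self.trainer.trainer_config_path}")

cfg = Cfg()                 # fine: no filesystem access yet
try:
    cfg.check_and_update()  # raises here, since no config path was provided
except ValueError as e:
    print(e)
```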
9 changes: 6 additions & 3 deletions trinity/explorer/explorer.py
@@ -216,6 +216,9 @@ def explore_step(self) -> Tuple[bool, int]:

def eval(self) -> bool:
"""Evaluation on all evaluation data samples."""
+if self.eval_taskset is None:
+self.logger.warning("No evaluation data samples. Skip evaluation.")
+return True
self.logger.info("Evaluation started.")
st = time.time()
all_metrics = defaultdict(list)
@@ -248,6 +251,6 @@ def sync_weight(self) -> None:
else: # online weights update
self._online_weights_update()

-def log_finalize(self, step: int) -> None:
-"""Commit the logging results to wandb"""
-self.monitor.log({"dummy_log_explorer": step}, step=step, commit=True)
+def flush_log(self, step: int) -> None:
+"""Flush the log of the current step."""
+self.monitor.log({}, step=step, commit=True)
6 changes: 3 additions & 3 deletions trinity/trainer/trainer.py
@@ -105,9 +105,9 @@ def sync_weight(self) -> None:
if self.config.synchronizer.sync_method == "online":
self.engine.sync_weight()

-def log_finalize(self, step: int) -> None:
-"""Commit the logging results to wandb"""
-self.engine.logger.log({"dummy_log_trainer": step}, step=step, commit=True)
+def flush_log(self, step: int) -> None:
+"""Flush the log of the current step."""
+self.engine.logger.log({}, step=step, commit=True)


class TrainEngineWrapper(ABC):
8 changes: 5 additions & 3 deletions trinity/trainer/verl_trainer.py
@@ -306,7 +306,8 @@ def train_sft_iteration(self, experiences: Experiences) -> Tuple[bool, int]:
* self.config.trainer.sft_warmup_iteration
):
self.logger.log(
data={"sft_warmup_iteration": self.sft_iter_num}, step=self.global_steps
data={"sft_warmup_iteration": self.sft_iter_num},
step=self.global_steps,
)
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
@@ -443,11 +444,12 @@ def train_rft_iteration(self, experiences: Experiences) -> Tuple[bool, int]:
compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)
)

-# TODO: make a canonical logger that supports various backend
-self.logger.log(data=metrics, step=self.global_steps)
if self.config.enable_preview:
self._log_experiences(experiences)

+# TODO: make a canonical logger that supports various backend
+self.logger.log(data=metrics, step=self.global_steps)

self.global_steps += 1

if self.global_steps >= self.total_training_steps:
Expand Down