Skip to content

Commit a3dfe19

Browse files
authored
Simplify wandb log and update default trainer config (#22)
1 parent ffe22ac commit a3dfe19

File tree

11 files changed

+32
-29
lines changed

11 files changed

+32
-29
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ trinity run --config <config_path>
246246

247247

248248

249-
For example, below is the command for fine-tuning Qwen-2.5-1B-Instruct on GSM8k dataset using GRPO algorithm:
249+
For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:
250250

251251
```shell
252252
trinity run --config examples/grpo_gsm8k/gsm8k.yaml

docs/sphinx_doc/source/main.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ trinity run --config <config_path>
226226

227227

228228

229-
For example, below is the command for fine-tuning Qwen-2.5-1B-Instruct on GSM8k dataset using GRPO algorithm:
229+
For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:
230230

231231
```shell
232232
trinity run --config examples/grpo_gsm8k/gsm8k.yaml

docs/sphinx_doc/source/tutorial/example_dpo.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ This example describes DPO based on the Qwen-2.5-1.5B-Instruct model and [Human-
66

77
### Model Preparation
88

9-
Download the Qwen-2.5-1B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:
9+
Download the Qwen-2.5-1.5B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:
1010

1111
```shell
1212
# Using Modelscope

docs/sphinx_doc/source/tutorial/example_reasoning_basic.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ This example shows how to run RFT with the Qwen-2.5-1.5B-Instruct model and GSM8
77

88
**Model Preparation.**
99

10-
Download the Qwen-2.5-1B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:
10+
Download the Qwen-2.5-1.5B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:
1111

1212
```bash
1313
# Using Modelscope

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ dependencies = [
3434
"fire",
3535
"flask",
3636
"requests",
37+
"tensorboard",
3738
]
3839

3940
[project.scripts]

trinity/buffer/reader/file_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818

1919
class FileReader(BufferReader):
20-
"""Reader of the Queue buffer."""
20+
"""Reader of the File buffer."""
2121

2222
def __init__(self, meta: DatasetConfig, config: BufferConfig) -> None:
2323
assert meta.storage_type == StorageType.FILE

trinity/cli/launcher.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,8 @@ def both(config: Config) -> None:
102102
logger.error(e)
103103
logger.error("Evaluation failed.")
104104
raise e
105-
106-
ray.get(explorer.log_finalize.remote(step=explore_iter_num))
107-
ray.get(trainer.log_finalize.remote(step=train_iter_num))
105+
ray.get(explorer.flush_log.remote(step=explore_iter_num))
106+
ray.get(trainer.flush_log.remote(step=train_iter_num))
108107

109108

110109
def activate_data_module(data_workflow_url: str, config_path: str):

trinity/common/config.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,7 @@ class ExplorerConfig:
173173
@dataclass
174174
class TrainerConfig:
175175
trainer_type: str = "verl"
176-
trainer_data_type: str = "RFT"
177-
trainer_config_path: str = "examples/ppo_countdown/train_countdown.yaml"
176+
trainer_config_path: str = ""
178177
eval_interval: int = 100
179178
enable_preview: bool = True # enable rollout preview in wandb
180179
trainer_config: Any = None
@@ -185,16 +184,6 @@ class TrainerConfig:
185184
# warmup config
186185
sft_warmup_iteration: int = 0
187186

188-
def __post_init__(self):
189-
if self.trainer_type == "verl":
190-
from trinity.common.verl_config import load_config
191-
192-
if not os.path.isfile(self.trainer_config_path):
193-
raise ValueError(f"Invalid trainer config path: {self.trainer_config_path}")
194-
self.trainer_config = load_config(self.trainer_config_path)
195-
else:
196-
raise ValueError(f"Invalid trainer type: {self.trainer_type}")
197-
198187

199188
@dataclass
200189
class MonitorConfig:
@@ -285,6 +274,15 @@ def _check_buffer(self) -> None:
285274

286275
def check_and_update(self) -> None:
287276
"""Check and update the config."""
277+
if self.trainer.trainer_type == "verl":
278+
from trinity.common.verl_config import load_config
279+
280+
if not os.path.isfile(self.trainer.trainer_config_path):
281+
raise ValueError(f"Invalid trainer config path: {self.trainer.trainer_config_path}")
282+
self.trainer.trainer_config = load_config(self.trainer.trainer_config_path)
283+
else:
284+
raise ValueError(f"Invalid trainer type: {self.trainer_type}")
285+
288286
# check mode
289287
if self.mode not in ["explore", "train", "both"]:
290288
raise ValueError(f"Invalid mode: {self.mode}")

trinity/explorer/explorer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,9 @@ def explore_step(self) -> Tuple[bool, int]:
216216

217217
def eval(self) -> bool:
218218
"""Evaluation on all evaluation data samples."""
219+
if self.eval_taskset is None:
220+
self.logger.warning("No evaluation data samples. Skip evaluation.")
221+
return True
219222
self.logger.info("Evaluation started.")
220223
st = time.time()
221224
all_metrics = defaultdict(list)
@@ -248,6 +251,6 @@ def sync_weight(self) -> None:
248251
else: # online weights update
249252
self._online_weights_update()
250253

251-
def log_finalize(self, step: int) -> None:
252-
"""Commit the logging results to wandb"""
253-
self.monitor.log({"dummy_log_explorer": step}, step=step, commit=True)
254+
def flush_log(self, step: int) -> None:
255+
"""Flush the log of the current step."""
256+
self.monitor.log({}, step=step, commit=True)

trinity/trainer/trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,9 @@ def sync_weight(self) -> None:
105105
if self.config.synchronizer.sync_method == "online":
106106
self.engine.sync_weight()
107107

108-
def log_finalize(self, step: int) -> None:
109-
"""Commit the logging results to wandb"""
110-
self.engine.logger.log({"dummy_log_trainer": step}, step=step, commit=True)
108+
def flush_log(self, step: int) -> None:
109+
"""Flush the log of the current step."""
110+
self.engine.logger.log({}, step=step, commit=True)
111111

112112

113113
class TrainEngineWrapper(ABC):

0 commit comments

Comments
 (0)