Skip to content

Commit 94c3460

Browse files
committed
add benchmark mode
1 parent 2086bec commit 94c3460

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

trinity/cli/launcher.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ def bench(config: Config) -> None:
1919
try:
2020
ray.get(explorer.prepare.remote())
2121
ray.get(explorer.sync_weight.remote())
22-
_, step = ray.get(explorer.eval.remote())
23-
logger.info("Evaluation finished.")
22+
_, step = ray.get(explorer.benchmark.remote())
23+
logger.info("Benchmark finished.")
2424
ray.get(explorer.flush_log.remote(step=step))
2525
except Exception as e:
26-
logger.error(f"Evaluation failed: {e}")
26+
logger.error(f"Benchmark failed: {e}")
2727
raise e
2828

2929

trinity/common/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ class GlobalConfig:
107107
total_epochs: int = 1
108108
batch_size: int = 1
109109
eval_interval: int = 100
110+
eval_on_latest_ckp: bool = True
110111

111112

112113
@dataclass

trinity/explorer/explorer.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,25 @@ def wait():
261261
self.monitor.log(log_metrics, step=self.step_num) # type: ignore
262262
return True, self.step_num
263263

264+
def benchmark(self) -> Tuple[bool, int]:
    """Benchmark the model checkpoints.

    Two modes, selected by ``config.global_config.eval_on_latest_ckp``:

    * ``True``: call ``self.eval()`` once with whatever weights are already
      loaded; this method does not load a checkpoint itself in that mode.
    * ``False``: for every step ``n`` in ``[0, self.step_num]`` (ascending),
      if ``<config.model.checkpoint_path>/global_step_<n>`` is an existing,
      non-empty directory, set ``self.step_num = n``, load it through
      ``self._checkpoint_weights_update`` and call ``self.eval()``.

    Returns:
        ``(True, self.step_num)``. In the all-checkpoints mode
        ``self.step_num`` is left at the last step that was evaluated, or
        unchanged when no checkpoint directory was found.
    """
    latest_step = self.step_num

    # Benchmark on the latest checkpoint only.
    if self.config.global_config.eval_on_latest_ckp:
        self.eval()
        return True, self.step_num

    # Benchmark on all checkpoints up to and including the latest step.
    checkpoint_root = self.config.model.checkpoint_path
    evaluated = 0
    for step_num in range(latest_step + 1):
        path = os.path.join(checkpoint_root, f"global_step_{step_num}")
        # Skip steps without a saved (non-empty) checkpoint directory.
        if os.path.isdir(path) and os.listdir(path):
            self.logger.info(f"{path} exists.")
            # eval() is expected to report against self.step_num, so it
            # must be updated before the weights are loaded and evaluated.
            self.step_num = step_num
            self._checkpoint_weights_update(step_num=step_num)
            self.eval()
            evaluated += 1

    if not evaluated:
        # Make an empty benchmark run visible instead of silently
        # reporting success with nothing evaluated.
        self.logger.warning(
            f"No non-empty checkpoint found under {checkpoint_root} "
            f"for steps 0..{latest_step}; nothing was benchmarked."
        )
    return True, self.step_num
282+
264283
def sync_weight(self) -> None:
265284
"""Synchronize model weights."""
266285
# call this method before training start to load the latest model weights

0 commit comments

Comments
 (0)