File tree Expand file tree Collapse file tree 3 files changed +23
-3
lines changed
Expand file tree Collapse file tree 3 files changed +23
-3
lines changed Original file line number Diff line number Diff line change @@ -19,11 +19,11 @@ def bench(config: Config) -> None:
1919 try :
2020 ray .get (explorer .prepare .remote ())
2121 ray .get (explorer .sync_weight .remote ())
22- _ , step = ray .get (explorer .eval .remote ())
23- logger .info ("Evaluation finished." )
22+ _ , step = ray .get (explorer .benchmark .remote ())
23+ logger .info ("Benchmark finished." )
2424 ray .get (explorer .flush_log .remote (step = step ))
2525 except Exception as e :
26- logger .error (f"Evaluation failed: { e } " )
26+ logger .error (f"Benchmark failed: { e } " )
2727 raise e
2828
2929
Original file line number Diff line number Diff line change @@ -107,6 +107,7 @@ class GlobalConfig:
107107 total_epochs : int = 1
108108 batch_size : int = 1
109109 eval_interval : int = 100
110+ eval_on_latest_ckp : bool = True
110111
111112
112113@dataclass
Original file line number Diff line number Diff line change @@ -261,6 +261,25 @@ def wait():
261261 self .monitor .log (log_metrics , step = self .step_num ) # type: ignore
262262 return True , self .step_num
263263
264+ def benchmark (self ) -> Tuple [bool , int ]:
265+ """Benchmark the model checkpoints."""
266+ latest_step = self .step_num
267+
268+ # benchmark on the latest checkpoint
269+ if self .config .global_config .eval_on_latest_ckp :
270+ self .eval ()
271+ return True , self .step_num
272+
273+ # benchmark on all checkoints
274+ for step_num in range (latest_step + 1 ):
275+ path = os .path .join (self .config .model .checkpoint_path , f"global_step_{ step_num } " )
276+ if os .path .isdir (path ) and os .listdir (path ):
277+ self .logger .info (f"{ path } exists." )
278+ self .step_num = step_num
279+ self ._checkpoint_weights_update (step_num = step_num )
280+ self .eval ()
281+ return True , self .step_num
282+
264283 def sync_weight (self ) -> None :
265284 """Synchronize model weights."""
266285 # call this method before training start to load the latest model weights
You can’t perform that action at this time.
0 commit comments