33import os
44import time
55from collections import defaultdict
6- from typing import List , Optional
6+ from typing import List , Optional , Tuple
77
88import ray
99import torch
@@ -149,16 +149,20 @@ def get_weight(self, name: str) -> torch.Tensor:
149149
def explore(self) -> None:
    """Explore the entire dataset.

    Calls `explore_step()` repeatedly and synchronizes weights after every
    step that reports more tasks remain. Stops as soon as a step reports
    that there is nothing left to explore.
    """
    while True:
        # `explore_step()` returns `(explore_status, explore_iter_num)`.
        # The status must be re-read on every pass: evaluating it only once
        # before the loop would spin forever on `sync_weight()` without
        # ever exploring again.
        explore_status, _ = self.explore_step()
        if not explore_status:
            break
        self.sync_weight()
    self.logger.info("Explorer finished.")
155156
156- def explore_step (self ) -> bool :
157+ def explore_step (self ) -> Tuple [ bool , int ] :
157158 """Explore for one step.
158159
159160 Different from `explore()` which consumes all tasks in the task set,
160161 `explore_step()` only consume `sync_iteration_interval * batch_size`
161162 number of tasks.
163+ explore_status:
164+ explore_status: whether there are more tasks to explore.
165+ explore_iter_num: the number of explore iterations
162166 """
163167 if self .task_iter is None :
164168 self .task_iter = iter (self .taskset )
@@ -175,7 +179,7 @@ def explore_step(self) -> bool:
175179 self .runner_pool .run_tasks (tasks )
176180 except StopIteration :
177181 self .logger .warning ("No more tasks in the task set. Stop exploring." )
178- return False
182+ return False , self . iteration
179183
180184 # wait for all tasks of this step to finish
181185 while self .runner_pool .has_next ():
@@ -190,7 +194,7 @@ def explore_step(self) -> bool:
190194 self .runner_pool .run_tasks (next (self .task_iter )) # type: ignore
191195 except StopIteration :
192196 self .logger .warning ("No more tasks in the task set. Stop exploring." )
193- return False
197+ return False , self . iteration
194198 else :
195199 for metric_name , metric_value in status .metric .items ():
196200 all_metrics [metric_name ].append (metric_value )
@@ -208,11 +212,11 @@ def explore_step(self) -> bool:
208212 )
209213
210214 self .logger .info ("Explore step finished." )
211- return True
215+ return True , self . iteration
212216
213- def eval (self ) -> bool :
217+ def eval (self , step ) -> bool :
214218 """Evaluation on all evaluation data samples."""
215- self .logger .info ("\n \n Evaluation started.\n \n " )
219+ self .logger .info ("Evaluation started." )
216220 st = time .time ()
217221 all_metrics = defaultdict (list )
218222
@@ -231,11 +235,9 @@ def eval(self) -> bool:
231235 for metric_name , metric_value in status .metric .items ():
232236 all_metrics [metric_name ].append (metric_value )
233237
234- self .logger .info ("Evaluation finished." )
235-
236238 log_metrics = self .monitor .calculate_metrics (all_metrics , prefix = "eval" ) # type: ignore
237239 log_metrics ["eval/total_time" ] = time .time () - st
238- self .monitor .log (log_metrics , step = self . iteration ) # type: ignore
240+ self .monitor .log (log_metrics , step = step ) # type: ignore
239241 return True
240242
241243 def sync_weight (self ) -> None :
0 commit comments