@@ -131,6 +131,12 @@ def __init__(
131131 self ._task_context = {}
132132 self ._available_band_resources = None
133133
134+ # For progress
135+ self ._pre_all_stages_progress = 0.0
136+ self ._pre_all_stages_tile_progress = 0
137+ self ._cur_stage_tile_progress = 0
138+ self ._cur_stage_output_object_refs = []
139+
134140 @classmethod
135141 async def create (
136142 cls ,
@@ -190,7 +196,12 @@ async def execute_subtask_graph(
190196 logger .info ("Stage %s start." , stage_id )
191197 context = self ._task_context
192198 output_meta_object_refs = []
193-
199+ self ._pre_all_stages_tile_progress = (
200+ self ._pre_all_stages_tile_progress + self ._cur_stage_tile_progress
201+ )
202+ self ._cur_stage_tile_progress = (
203+ self ._tile_context .get_all_progress () - self ._pre_all_stages_tile_progress
204+ )
194205 logger .info ("Submitting %s subtasks of stage %s." , len (subtask_graph ), stage_id )
195206 result_meta_keys = {
196207 chunk .key
@@ -219,6 +230,7 @@ async def execute_subtask_graph(
219230 continue
220231 elif output_count == 1 :
221232 output_object_refs = [output_object_refs ]
233+ self ._cur_stage_output_object_refs .extend (output_object_refs )
222234 if output_meta_keys :
223235 meta_object_ref , * output_object_refs = output_object_refs
224236 # TODO(fyrestone): Fetch(not get) meta object here.
@@ -249,6 +261,10 @@ async def execute_subtask_graph(
249261
250262 logger .info ("Waiting for stage %s complete." , stage_id )
251263 ray .wait (output_object_refs , fetch_local = False )
264+ # Just use `self._cur_stage_tile_progress` as current stage progress
265+ # because current stage is finished, its progress is 1.
266+ self ._pre_all_stages_progress += self ._cur_stage_tile_progress
267+ self ._cur_stage_output_object_refs .clear ()
252268 logger .info ("Stage %s is complete." , stage_id )
253269 return chunk_to_meta
254270
@@ -293,7 +309,19 @@ async def get_available_band_resources(self) -> Dict[BandType, Resource]:
293309
294310 async def get_progress (self ) -> float :
295311 """Get the execution progress."""
296- return 1
312+ stage_progress = 0.0
313+ total = len (self ._cur_stage_output_object_refs )
314+ if total > 0 :
315+ finished_objects , _ = ray .wait (
316+ self ._cur_stage_output_object_refs ,
317+ num_returns = total ,
318+ timeout = 0.1 ,
319+ fetch_local = False ,
320+ )
321+ stage_progress = (
322+ len (finished_objects ) / total * self ._cur_stage_tile_progress
323+ )
324+ return self ._pre_all_stages_progress + stage_progress
297325
298326 async def cancel (self ):
299327 """Cancel execution."""
0 commit comments