1313# limitations under the License.
1414
1515import dataclasses
16+ import functools
1617from typing import (
1718 Callable ,
1819 Dict ,
3536
3637
3738tile_gen_type = Generator [List [ChunkType ], List [ChunkType ], List [TileableType ]]
39+ DEFAULT_UPDATED_PROGRESS = 0.4
3840
3941
4042@dataclasses .dataclass
@@ -44,14 +46,84 @@ class _TileableHandler:
4446 last_need_processes : List [EntityType ] = None
4547
4648
49+ @dataclasses .dataclass
50+ class _TileableTileInfo :
51+ curr_iter : int
52+ # incremental progress for this iteration
53+ tile_progress : float
54+ # newly generated chunks by a tileable in this iteration
55+ generated_chunks : List [ChunkType ] = dataclasses .field (default_factory = list )
56+
57+
58+ class TileContext (Dict [TileableType , TileableType ]):
59+ _tileables = Set [TileableType ]
60+ _tileable_to_progress : Dict [TileableType , float ]
61+ _tileable_to_tile_infos : Dict [TileableType , List [_TileableTileInfo ]]
62+
63+ def __init__ (self , * args , ** kw ):
64+ super ().__init__ (* args , ** kw )
65+ self ._tileables = None
66+ self ._tileable_to_progress = dict ()
67+ self ._tileable_to_tile_infos = dict ()
68+
69+ def set_tileables (self , tileables : Set [TileableType ]):
70+ self ._tileables = tileables
71+
72+ def __setitem__ (self , key , value ):
73+ self ._tileable_to_progress .pop (key , None )
74+ return super ().__setitem__ (key , value )
75+
76+ def set_progress (self , tileable : TileableType , progress : float ):
77+ assert 0.0 <= progress <= 1.0
78+ last_progress = self ._tileable_to_progress .get (tileable , 0.0 )
79+ self ._tileable_to_progress [tileable ] = max (progress , last_progress )
80+
81+ def get_progress (self , tileable : TileableType ) -> float :
82+ if tileable in self :
83+ return 1.0
84+ else :
85+ return self ._tileable_to_progress .get (tileable , 0.0 )
86+
87+ def get_all_progress (self ) -> float :
88+ return sum (self .get_progress (t ) for t in self ._tileables ) / len (self ._tileables )
89+
90+ def record_tileable_tile_info (
91+ self , tileable : TileableType , curr_iter : int , generated_chunks : List [ChunkType ]
92+ ):
93+ if tileable not in self ._tileable_to_tile_infos :
94+ self ._tileable_to_tile_infos [tileable ] = []
95+ prev_progress = sum (
96+ info .tile_progress for info in self ._tileable_to_tile_infos [tileable ]
97+ )
98+ curr_progress = self .get_progress (tileable )
99+ infos = self ._tileable_to_tile_infos [tileable ]
100+ infos .append (
101+ _TileableTileInfo (
102+ curr_iter = curr_iter ,
103+ tile_progress = curr_progress - prev_progress ,
104+ generated_chunks = generated_chunks ,
105+ )
106+ )
107+
108+ def get_tileable_tile_infos (self ) -> Dict [TileableType , List [_TileableTileInfo ]]:
109+ return {t : self ._tileable_to_tile_infos .get (t , list ()) for t in self ._tileables }
110+
111+
112+ @dataclasses .dataclass
113+ class TileStatus :
114+ entities : List [EntityType ] = None
115+ progress : float = None
116+
117+
47118class Tiler :
119+ _cur_iter : int
48120 _cur_chunk_graph : Optional [ChunkGraph ]
49121 _tileable_handlers : Iterable [_TileableHandler ]
50122
51123 def __init__ (
52124 self ,
53125 tileable_graph : TileableGraph ,
54- tile_context : Dict [ TileableType , TileableType ] ,
126+ tile_context : TileContext ,
55127 processed_chunks : Set [ChunkType ],
56128 chunk_to_fetch : Dict [ChunkType , ChunkType ],
57129 add_nodes : Callable ,
@@ -60,13 +132,31 @@ def __init__(
60132 self ._tile_context = tile_context
61133 self ._processed_chunks = processed_chunks
62134 self ._chunk_to_fetch = chunk_to_fetch
63- self ._add_nodes = add_nodes
135+ self ._add_nodes = self ._wrap_add_nodes (add_nodes )
136+ self ._curr_iter = 0
64137 self ._cur_chunk_graph = None
65138 self ._tileable_handlers = (
66139 _TileableHandler (tileable , self ._tile_handler (tileable ))
67140 for tileable in tileable_graph .topological_iter ()
68141 )
69142
143+ def _wrap_add_nodes (self , add_nodes : Callable ):
144+ @functools .wraps (add_nodes )
145+ def inner (
146+ chunk_graph : ChunkGraph ,
147+ chunks : List [ChunkType ],
148+ visited : Set [ChunkType ],
149+ tileable : TileableType ,
150+ ):
151+ prev_chunks = set (chunk_graph )
152+ add_nodes (chunk_graph , chunks , visited )
153+ new_chunks = set (chunk_graph )
154+ self ._tile_context .record_tileable_tile_info (
155+ tileable , self ._curr_iter , list (new_chunks - prev_chunks )
156+ )
157+
158+ return inner
159+
70160 @staticmethod
71161 def _get_data (entity : EntityType ):
72162 return entity .data if hasattr (entity , "data" ) else entity
@@ -119,6 +209,17 @@ def _tile(
119209 ):
120210 try :
121211 need_process = next (tile_handler )
212+
213+ if isinstance (need_process , TileStatus ):
214+ # process tile that returns progress
215+ self ._tile_context .set_progress (tileable , need_process .progress )
216+ need_process = need_process .entities
217+ else :
218+ # if progress not specified, we just update 0.4 * rest progress
219+ progress = self ._tile_context .get_progress (tileable )
220+ new_progress = progress + (1.0 - progress ) * DEFAULT_UPDATED_PROGRESS
221+ self ._tile_context .set_progress (tileable , new_progress )
222+
122223 chunks = []
123224 if need_process is not None :
124225 for t in need_process :
@@ -127,7 +228,7 @@ def _tile(
127228 elif isinstance (t , TILEABLE_TYPE ):
128229 to_update_tileables .append (self ._get_data (t ))
129230 # not finished yet
130- self ._add_nodes (chunk_graph , chunks .copy (), visited )
231+ self ._add_nodes (chunk_graph , chunks .copy (), visited , tileable )
131232 next_tileable_handlers .append (
132233 _TileableHandler (tileable , tile_handler , need_process )
133234 )
@@ -145,8 +246,8 @@ def _tile(
145246 if chunks is None : # pragma: no cover
146247 raise ValueError (f"tileable({ out } ) is still coarse after tile" )
147248 chunks = [self ._get_data (c ) for c in chunks ]
148- self ._add_nodes (chunk_graph , chunks , visited )
149249 self ._tile_context [out ] = tiled_tileable
250+ self ._add_nodes (chunk_graph , chunks , visited , tileable )
150251
151252 def _gen_result_chunks (
152253 self ,
@@ -227,6 +328,8 @@ def _iter(self):
227328 # prune unused chunks
228329 prune_chunk_graph (chunk_graph )
229330
331+ self ._curr_iter += 1
332+
230333 return to_update_tileables
231334
232335 def __iter__ (self ):
@@ -278,12 +381,13 @@ def __init__(
278381 self ,
279382 graph : TileableGraph ,
280383 fuse_enabled : bool = True ,
281- tile_context : Dict [ TileableType , TileableType ] = None ,
384+ tile_context : TileContext = None ,
282385 tiler_cls : Union [Type [Tiler ], Callable ] = None ,
283386 ):
284387 super ().__init__ (graph )
285388 self .fuse_enabled = fuse_enabled
286- self .tile_context = dict () if tile_context is None else tile_context
389+ self .tile_context = TileContext () if tile_context is None else tile_context
390+ self .tile_context .set_tileables (set (graph ))
287391
288392 self ._processed_chunks : Set [ChunkType ] = set ()
289393 self ._chunk_to_fetch : Dict [ChunkType , ChunkType ] = dict ()
0 commit comments