@@ -124,7 +124,7 @@ def __init__(
124124 self ,
125125 tileable_graph : TileableGraph ,
126126 tile_context : TileContext ,
127- processed_chunks : Set [ChunkType ],
127+ processed_chunks : Set [str ],
128128 chunk_to_fetch : Dict [ChunkType , ChunkType ],
129129 add_nodes : Callable ,
130130 ):
@@ -301,11 +301,12 @@ def _iter(self):
301301
302302 if chunk_graph is not None :
303303 # last tiled chunks, add them to processed
304- # so that fetch chunk can be generated
305- processed_chunks = [
306- c .chunk if isinstance (c , FUSE_CHUNK_TYPE ) else c
304+ # so that fetch chunk can be generated.
305+ # Use chunk key as the key to make sure the copied chunk can be build to a fetch.
306+ processed_chunks = (
307+ c .chunk .key if isinstance (c , FUSE_CHUNK_TYPE ) else c .key
307308 for c in chunk_graph .result_chunks
308- ]
309+ )
309310 self ._processed_chunks .update (processed_chunks )
310311
311312 result_chunks = []
@@ -389,7 +390,7 @@ def __init__(
389390 self .tile_context = TileContext () if tile_context is None else tile_context
390391 self .tile_context .set_tileables (set (graph ))
391392
392- self ._processed_chunks : Set [ChunkType ] = set ()
393+ self ._processed_chunks : Set [str ] = set ()
393394 self ._chunk_to_fetch : Dict [ChunkType , ChunkType ] = dict ()
394395
395396 tiler_cls = Tiler if tiler_cls is None else tiler_cls
@@ -402,7 +403,7 @@ def __init__(
402403 )
403404
404405 def _process_node (self , entity : EntityType ):
405- if entity in self ._processed_chunks :
406+ if entity . key in self ._processed_chunks :
406407 if entity not in self ._chunk_to_fetch :
407408 # gen fetch
408409 fetch_chunk = build_fetch (entity ).data
@@ -413,7 +414,7 @@ def _process_node(self, entity: EntityType):
413414 def _select_inputs (self , inputs : List [ChunkType ]):
414415 new_inputs = []
415416 for inp in inputs :
416- if inp in self ._processed_chunks :
417+ if inp . key in self ._processed_chunks :
417418 # gen fetch
418419 if inp not in self ._chunk_to_fetch :
419420 fetch_chunk = build_fetch (inp ).data
@@ -424,7 +425,7 @@ def _select_inputs(self, inputs: List[ChunkType]):
424425 return new_inputs
425426
426427 def _if_add_node (self , node : EntityType , visited : Set ):
427- return node not in visited and node not in self ._processed_chunks
428+ return node not in visited and node . key not in self ._processed_chunks
428429
429430 def _build (self ) -> Iterable [Union [TileableGraph , ChunkGraph ]]:
430431 tile_iterator = iter (self .tiler )
0 commit comments