1515import asyncio
1616import functools
1717import logging
18+ from dataclasses import dataclass
1819from typing import List , Dict , Any , Set
1920from .....core import ChunkGraph , Chunk , TileContext
2021from .....core .context import set_context
3031from .....serialization import serialize , deserialize
3132from .....typing import BandType
3233from .....utils import (
34+ calc_data_size ,
3335 lazy_import ,
3436 get_chunk_params ,
3537 get_chunk_key_to_data_keys ,
5658logger = logging .getLogger (__name__ )
5759
5860
61+ @dataclass
62+ class _RayChunkMeta :
63+ memory_size : int
64+
65+
5966class RayTaskState (RayRemoteObjectManager ):
6067 @classmethod
6168 def gen_name (cls , task_id : str ):
@@ -102,11 +109,14 @@ def execute_subtask(
102109 if output_meta_keys :
103110 output_meta = {}
104111 for chunk in subtask_chunk_graph .result_chunks :
105- if chunk .key in output_meta_keys :
112+ chunk_key = chunk .key
113+ if chunk_key in output_meta_keys and chunk_key not in output_meta :
106114 if isinstance (chunk .op , Fuse ):
107115 # fuse op
108116 chunk = chunk .chunk
109- output_meta [chunk .key ] = get_chunk_params (chunk )
117+ data = context [chunk_key ]
118+ memory_size = calc_data_size (data )
119+ output_meta [chunk_key ] = get_chunk_params (chunk ), memory_size
110120 assert len (output_meta_keys ) == len (output_meta )
111121 output_values .append (output_meta )
112122 output_values .extend (output .values ())
@@ -125,6 +135,7 @@ def __init__(
125135 task : Task ,
126136 tile_context : TileContext ,
127137 task_context : Dict [str , "ray.ObjectRef" ],
138+ task_chunks_meta : Dict [str , _RayChunkMeta ],
128139 task_state_actor : "ray.actor.ActorHandle" ,
129140 lifecycle_api : LifecycleAPI ,
130141 meta_api : MetaAPI ,
@@ -133,6 +144,7 @@ def __init__(
133144 self ._task = task
134145 self ._tile_context = tile_context
135146 self ._task_context = task_context
147+ self ._task_chunks_meta = task_chunks_meta
136148 self ._task_state_actor = task_state_actor
137149 self ._ray_executor = self ._get_ray_executor ()
138150
@@ -166,12 +178,16 @@ async def create(
166178 .remote ()
167179 )
168180 task_context = {}
169- await cls ._init_context (task_context , task_state_actor , session_id , address )
181+ task_chunks_meta = {}
182+ await cls ._init_context (
183+ task_context , task_chunks_meta , task_state_actor , session_id , address
184+ )
170185 return cls (
171186 config ,
172187 task ,
173188 tile_context ,
174189 task_context ,
190+ task_chunks_meta ,
175191 task_state_actor ,
176192 lifecycle_api ,
177193 meta_api ,
@@ -183,6 +199,7 @@ def destroy(self):
183199 self ._task = None
184200 self ._tile_context = None
185201 self ._task_context = None
202+ self ._task_chunks_meta = None
186203 self ._task_state_actor = None
187204 self ._ray_executor = None
188205
@@ -207,7 +224,7 @@ async def _get_apis(cls, session_id: str, address: str):
207224 )
208225
@staticmethod
@functools.lru_cache(maxsize=None)  # Specify maxsize=None to make it faster
def _get_ray_executor():
    """Return ``execute_subtask`` wrapped as a Ray remote function.

    The function takes no arguments, so the cache holds a single entry:
    the first call wraps ``execute_subtask`` with ``ray.remote`` and every
    later call returns that same object.
    """
    # Export remote function once.
    return ray.remote(execute_subtask)
@@ -216,13 +233,15 @@ def _get_ray_executor():
216233 async def _init_context (
217234 cls ,
218235 task_context : Dict [str , "ray.ObjectRef" ],
236+ task_chunks_meta : Dict [str , _RayChunkMeta ],
219237 task_state_actor : "ray.actor.ActorHandle" ,
220238 session_id : str ,
221239 address : str ,
222240 ):
223241 loop = asyncio .get_running_loop ()
224242 context = RayExecutionContext (
225243 task_context ,
244+ task_chunks_meta ,
226245 task_state_actor ,
227246 session_id ,
228247 address ,
@@ -293,7 +312,9 @@ async def execute_subtask_graph(
293312 logger .info ("Getting %s metas of stage %s." , meta_count , stage_id )
294313 meta_list = await asyncio .gather (* output_meta_object_refs )
295314 for meta in meta_list :
296- key_to_meta .update (meta )
315+ for key , (params , memory_size ) in meta .items ():
316+ key_to_meta [key ] = params
317+ self ._task_chunks_meta [key ] = _RayChunkMeta (memory_size = memory_size )
297318 assert len (key_to_meta ) == len (result_meta_keys )
298319 logger .info ("Got %s metas of stage %s." , meta_count , stage_id )
299320
@@ -304,9 +325,9 @@ async def execute_subtask_graph(
304325 chunk_key = chunk .key
305326 object_ref = task_context [chunk_key ]
306327 output_object_refs .add (object_ref )
307- chunk_meta = key_to_meta .get (chunk_key )
308- if chunk_meta is not None :
309- chunk_to_meta [chunk ] = ExecutionChunkResult (chunk_meta , object_ref )
328+ chunk_params = key_to_meta .get (chunk_key )
329+ if chunk_params is not None :
330+ chunk_to_meta [chunk ] = ExecutionChunkResult (chunk_params , object_ref )
310331
311332 logger .info ("Waiting for stage %s complete." , stage_id )
312333 # Patched the asyncio.to_thread for Python < 3.9 at mars/lib/aio/__init__.py
@@ -319,36 +340,42 @@ async def execute_subtask_graph(
319340 return chunk_to_meta
320341
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Publish chunk metas and lifecycle tracking when the context exits cleanly.

    If the ``async with`` body raised (``exc_type`` is not ``None``) nothing
    is updated and the method returns ``None``, so the exception is not
    suppressed.

    Otherwise, for every result tileable of the task's tileable graph:

    * the keys of its chunks (looked up in the tile context) are collected;
    * for each chunk whose key is present in the task context, a
      ``set_chunk_meta`` call is queued with the chunk's object ref and the
      ``memory_size`` recorded in ``self._task_chunks_meta``;
    * a lifecycle ``track`` call is queued for the tileable key and its
      chunk keys.

    The queued calls are then sent in batches, and finally the result
    tileables are incref'ed.
    """
    if exc_type is not None:
        return

    # Update info if no exception occurs.
    tileable_keys = []
    update_metas = []
    update_lifecycles = []
    for tileable in self._task.tileable_graph.result_tileables:
        tileable_keys.append(tileable.key)
        # The tile context is indexed by the unwrapped ``.data`` object when
        # the tileable has one.
        tileable = tileable.data if hasattr(tileable, "data") else tileable
        chunk_keys = []
        for chunk in self._tile_context[tileable].chunks:
            chunk_key = chunk.key
            chunk_keys.append(chunk_key)
            if chunk_key in self._task_context:
                # Some tileable graphs may have result chunks that are not
                # executed, for example:
                # r, b = cut(series, bins, retbins=True)
                # r_result = r.execute().fetch()
                # b_result = b.execute().fetch() <- This is the case
                object_ref = self._task_context[chunk_key]
                chunk_meta = self._task_chunks_meta[chunk_key]
                update_metas.append(
                    self._meta_api.set_chunk_meta.delay(
                        chunk,
                        bands=[],
                        object_ref=object_ref,
                        memory_size=chunk_meta.memory_size,
                    )
                )
        update_lifecycles.append(
            self._lifecycle_api.track.delay(tileable.key, chunk_keys)
        )
    await self._meta_api.set_chunk_meta.batch(*update_metas)
    await self._lifecycle_api.track.batch(*update_lifecycles)
    await self._lifecycle_api.incref_tileables(tileable_keys)
352379
353380 async def get_available_band_resources (self ) -> Dict [BandType , Resource ]:
354381 if self ._available_band_resources is None :
0 commit comments