1313# limitations under the License.
1414
1515import asyncio
16+ import functools
1617import logging
1718from typing import List , Dict , Any , Set
1819from .....core import ChunkGraph , Chunk , TileContext
@@ -123,22 +124,22 @@ def __init__(
123124 config : ExecutionConfig ,
124125 task : Task ,
125126 tile_context : TileContext ,
126- ray_executor : "ray.remote_function.RemoteFunction" ,
127+ task_context : Dict [ str , "ray.ObjectRef" ] ,
127128 task_state_actor : "ray.actor.ActorHandle" ,
128129 lifecycle_api : LifecycleAPI ,
129130 meta_api : MetaAPI ,
130131 ):
131132 self ._config = config
132133 self ._task = task
133134 self ._tile_context = tile_context
134- self ._ray_executor = ray_executor
135+ self ._task_context = task_context
135136 self ._task_state_actor = task_state_actor
137+ self ._ray_executor = self ._get_ray_executor ()
136138
137139 # api
138140 self ._lifecycle_api = lifecycle_api
139141 self ._meta_api = meta_api
140142
141- self ._task_context = {}
142143 self ._available_band_resources = None
143144
144145 # For progress
@@ -158,19 +159,19 @@ async def create(
158159 tile_context : TileContext ,
159160 ** kwargs ,
160161 ) -> "TaskExecutor" :
161- ray_executor = ray .remote (execute_subtask )
162162 lifecycle_api , meta_api = await cls ._get_apis (session_id , address )
163163 task_state_actor = (
164164 ray .remote (RayTaskState )
165165 .options (name = RayTaskState .gen_name (task .task_id ))
166166 .remote ()
167167 )
168- await cls ._init_context (task_state_actor , session_id , address )
168+ task_context = {}
169+ await cls ._init_context (task_context , task_state_actor , session_id , address )
169170 return cls (
170171 config ,
171172 task ,
172173 tile_context ,
173- ray_executor ,
174+ task_context ,
174175 task_state_actor ,
175176 lifecycle_api ,
176177 meta_api ,
@@ -184,13 +185,29 @@ async def _get_apis(cls, session_id: str, address: str):
184185 MetaAPI .create (session_id , address ),
185186 )
186187
    @staticmethod
    @functools.lru_cache(maxsize=1)
    def _get_ray_executor():
        """Return the Ray remote function wrapping ``execute_subtask``.

        ``lru_cache(maxsize=1)`` memoizes the handle, so ``ray.remote`` runs
        only once per process and the remote function is exported to the Ray
        cluster a single time, then shared by every executor instance.
        """
        # Export remote function once.
        return ray.remote(execute_subtask)
    @classmethod
    async def _init_context(
        cls,
        task_context: Dict[str, "ray.ObjectRef"],
        task_state_actor: "ray.actor.ActorHandle",
        session_id: str,
        address: str,
    ):
        """Build a ``RayExecutionContext`` and install it as the current context.

        ``task_context`` maps chunk keys to Ray object refs and is shared with
        the executor so subtask outputs become visible to the context.

        NOTE(review): ``address`` is passed three times to
        ``RayExecutionContext`` — presumably distinct address slots
        (e.g. supervisor/local/current) that happen to coincide here;
        confirm against ``RayExecutionContext.__init__``.
        """
        # Bind the context to the currently running event loop.
        loop = asyncio.get_running_loop()
        context = RayExecutionContext(
            task_context,
            task_state_actor,
            session_id,
            address,
            address,
            address,
            loop=loop,
        )
        await context.init()
        # Make this context the process-wide current execution context.
        set_context(context)
@@ -204,7 +221,7 @@ async def execute_subtask_graph(
204221 context : Any = None ,
205222 ) -> Dict [Chunk , ExecutionChunkResult ]:
206223 logger .info ("Stage %s start." , stage_id )
207- context = self ._task_context
224+ task_context = self ._task_context
208225 output_meta_object_refs = []
209226 self ._pre_all_stages_tile_progress = (
210227 self ._pre_all_stages_tile_progress + self ._cur_stage_tile_progress
@@ -221,7 +238,7 @@ async def execute_subtask_graph(
221238 for subtask in subtask_graph .topological_iter ():
222239 subtask_chunk_graph = subtask .chunk_graph
223240 key_to_input = await self ._load_subtask_inputs (
224- stage_id , subtask , subtask_chunk_graph , context
241+ stage_id , subtask , subtask_chunk_graph , task_context
225242 )
226243 output_keys = self ._get_subtask_output_keys (subtask_chunk_graph )
227244 output_meta_keys = result_meta_keys & output_keys
@@ -245,32 +262,34 @@ async def execute_subtask_graph(
245262 meta_object_ref , * output_object_refs = output_object_refs
246263 # TODO(fyrestone): Fetch(not get) meta object here.
247264 output_meta_object_refs .append (meta_object_ref )
248- context .update (zip (output_keys , output_object_refs ))
265+ task_context .update (zip (output_keys , output_object_refs ))
249266 logger .info ("Submitted %s subtasks of stage %s." , len (subtask_graph ), stage_id )
250267
251268 key_to_meta = {}
252269 if len (output_meta_object_refs ) > 0 :
253270 # TODO(fyrestone): Optimize update meta by fetching partial meta.
271+ meta_count = len (output_meta_object_refs )
272+ logger .info ("Getting %s metas of stage %s." , meta_count , stage_id )
254273 meta_list = await asyncio .gather (* output_meta_object_refs )
255274 for meta in meta_list :
256275 key_to_meta .update (meta )
257276 assert len (key_to_meta ) == len (result_meta_keys )
258- logger .info (
259- "Got %s metas of stage %s." , len (output_meta_object_refs ), stage_id
260- )
277+ logger .info ("Got %s metas of stage %s." , meta_count , stage_id )
261278
262279 chunk_to_meta = {}
263- output_object_refs = []
280+ # ray.wait requires the object ref list is unique.
281+ output_object_refs = set ()
264282 for chunk in chunk_graph .result_chunks :
265283 chunk_key = chunk .key
266- object_ref = context [chunk_key ]
267- output_object_refs .append (object_ref )
284+ object_ref = task_context [chunk_key ]
285+ output_object_refs .add (object_ref )
268286 chunk_meta = key_to_meta .get (chunk_key )
269287 if chunk_meta is not None :
270288 chunk_to_meta [chunk ] = ExecutionChunkResult (chunk_meta , object_ref )
271289
272290 logger .info ("Waiting for stage %s complete." , stage_id )
273- ray .wait (output_object_refs , fetch_local = False )
291+ # Patched the asyncio.to_thread for Python < 3.9 at mars/lib/aio/__init__.py
292+ await asyncio .to_thread (ray .wait , list (output_object_refs ), fetch_local = False )
274293 # Just use `self._cur_stage_tile_progress` as current stage progress
275294 # because current stage is finished, its progress is 1.
276295 self ._pre_all_stages_progress += self ._cur_stage_tile_progress
@@ -289,14 +308,20 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
289308 chunk_keys = []
290309 for chunk in self ._tile_context [tileable ].chunks :
291310 chunk_keys .append (chunk .key )
292- object_ref = self ._task_context [chunk .key ]
293- update_metas .append (
294- self ._meta_api .set_chunk_meta .delay (
295- chunk ,
296- bands = [],
297- object_ref = object_ref ,
311+ if chunk .key in self ._task_context :
312+ # Some tileable graph may have result chunks that not be executed,
313+ # for example:
314+ # r, b = cut(series, bins, retbins=True)
315+ # r_result = r.execute().fetch()
316+ # b_result = b.execute().fetch() <- This is the case
317+ object_ref = self ._task_context [chunk .key ]
318+ update_metas .append (
319+ self ._meta_api .set_chunk_meta .delay (
320+ chunk ,
321+ bands = [],
322+ object_ref = object_ref ,
323+ )
298324 )
299- )
300325 update_lifecycles .append (
301326 self ._lifecycle_api .track .delay (tileable .key , chunk_keys )
302327 )
@@ -325,7 +350,7 @@ async def get_progress(self) -> float:
325350 finished_objects , _ = ray .wait (
326351 self ._cur_stage_output_object_refs ,
327352 num_returns = total ,
328- timeout = 0.1 ,
353+ timeout = 0 , # Avoid blocking the asyncio loop.
329354 fetch_local = False ,
330355 )
331356 stage_progress = (
0 commit comments