66import ray
77import ray .actor
88import ray .util .dask .scheduler
9- from dask .core import get_dependencies
109
1110from doreisa import Timestep
1211
@@ -30,6 +29,15 @@ class ChunkRef:
3029 _all_chunks : ray .ObjectRef | None = None
3130
3231
@dataclass
class ScheduledByOtherActor:
    """Placeholder for a task owned by a different scheduling actor.

    When the task graph is partitioned and each part is sent to one actor,
    tasks that fall outside an actor's partition are represented by this
    marker instead of their real definition.
    """

    # Identifier of the scheduling actor responsible for this task.
    actor_id: int
40+
3341class GraphInfo :
3442 """
3543 Information about graphs and their scheduling.
@@ -124,7 +132,7 @@ def __init__(self, actor_id: int) -> None:
124132 # For scheduling
125133 self .new_graph_available = asyncio .Event ()
126134 self .graph_infos : dict [int , GraphInfo ] = {}
127- self .partitionned_graphs : dict [int , tuple [ dict , dict [ str , int ]] ] = {}
135+ self .partitionned_graphs : dict [int , dict ] = {}
128136
129137 def ready (self ) -> None :
130138 pass
@@ -188,18 +196,18 @@ async def add_chunk(
188196 else :
189197 await self .chunks_ready_event .wait ()
190198
def store_graph(self, graph_id: int, dsk: dict) -> None:
    """Hold on to a partitioned task graph until `schedule_graph` runs it.

    Separating storage from scheduling makes it possible to measure
    precisely how long shipping the graph to every actor takes; an
    efficient communication method may replace this later if needed.
    """
    # Keyed by graph id so schedule_graph can pop exactly this graph.
    self.partitionned_graphs[graph_id] = dsk
200208
201209 async def schedule_graph (self , graph_id : int ):
202- dsk , scheduling = self .partitionned_graphs .pop (graph_id )
210+ dsk = self .partitionned_graphs .pop (graph_id )
203211
204212 # Find the scheduling actors
205213 if not self .scheduling_actors :
@@ -210,22 +218,13 @@ async def schedule_graph(self, graph_id: int):
210218 self .new_graph_available .set ()
211219 self .new_graph_available .clear ()
212220
213- local_keys = {k for k in dsk if scheduling [k ] == self .actor_id }
214-
215- dependency_keys : set [str ] = {dep for k in local_keys for dep in get_dependencies (dsk , k )} # type: ignore[assignment]
216-
217- external_keys = dependency_keys - local_keys
218-
219- # Filter the dask array
220- dsk = {k : v for k , v in dsk .items () if k in local_keys }
221-
222- # Adapt external keys
223- for k in external_keys :
224- actor = self .scheduling_actors [scheduling [k ]]
225- dsk [k ] = actor .get_value .options (enable_task_events = False ).remote (graph_id , k )
226-
227- # Replace the false chunks by the real ObjectRefs
228221 for key , val in dsk .items ():
222+ # Adapt external keys
223+ if isinstance (val , ScheduledByOtherActor ):
224+ actor = self .scheduling_actors [val .actor_id ]
225+ dsk [key ] = actor .get_value .options (enable_task_events = False ).remote (graph_id , key )
226+
227+ # Replace the false chunks by the real ObjectRefs
229228 if isinstance (val , ChunkRef ):
230229 assert val .actor_id == self .actor_id
231230
@@ -234,7 +233,7 @@ async def schedule_graph(self, graph_id: int):
234233 dsk [key ] = pickle .loads (encoded_ref )
235234
236235 # We will need the ObjectRefs of these keys
237- keys_needed = list (local_keys - dependency_keys )
236+ keys_needed = list (dsk . keys () )
238237
239238 refs = await remote_ray_dask_get .remote (dsk , keys_needed )
240239
0 commit comments