1414
1515from doreisa import Timestep
1616from doreisa ._scheduler import doreisa_get
17- from doreisa ._scheduling_actor import ChunkReadyInfo , ChunkRef , SchedulingActor
17+ from doreisa ._scheduling_actor import ChunkRef , SchedulingActor
1818
1919
2020def init ():
@@ -39,9 +39,8 @@ class _DaskArrayData:
3939 Information about a Dask array being built.
4040 """
4141
42- def __init__ (self , definition : ArrayDefinition , timestep : Timestep ) -> None :
42+ def __init__ (self , definition : ArrayDefinition ) -> None :
4343 self .definition = definition
44- self .timestep = timestep
4544
4645 # This will be set when the first chunk is added
4746 self .nb_chunks_per_dim : tuple [int , ...] | None = None
@@ -56,28 +55,25 @@ def __init__(self, definition: ArrayDefinition, timestep: Timestep) -> None:
5655 # ID of the scheduling actor in charge of the chunk at each position
5756 self .scheduling_actors_id : dict [tuple [int , ...], int ] = {}
5857
58+ # Number of scheduling actors owning chunks of this array.
59+ self .nb_scheduling_actors : int | None = None
60+
5961 # Each reference comes from one scheduling actor. The reference is a list of
6062 # ObjectRefs, each ObjectRef corresponding to a chunk. These references
6163 # shouldn't be used directly. They exist only to release the memory
6264 # automatically.
6466 # When the array is built, these references are put in the object store, and the
6466 # global reference is added to the Dask graph. Then, the list is cleared.
65- self .chunk_refs : list [ray .ObjectRef ] = []
67+ self .chunk_refs : dict [ Timestep , list [ray .ObjectRef ]] = {}
6668
67- def add_chunk (
69+ def set_chunk_owner (
6870 self ,
69- size : tuple [int , ...],
70- position : tuple [int , ...],
71- dtype : np .dtype ,
7271 nb_chunks_per_dim : tuple [int , ...],
72+ dtype : np .dtype ,
73+ position : tuple [int , ...],
74+ size : tuple [int , ...],
7375 scheduling_actor_id : int ,
74- ) -> bool :
75- """
76- Add a chunk to the array.
77-
78- Return:
79- True if the array is ready, False otherwise.
80- """
76+ ) -> None :
8177 if self .nb_chunks_per_dim is None :
8278 self .nb_chunks_per_dim = nb_chunks_per_dim
8379 self .nb_chunks = math .prod (nb_chunks_per_dim )
@@ -100,31 +96,43 @@ def add_chunk(
10096 else :
10197 assert self .chunks_size [d ][position [d ]] == size [d ]
10298
103- if len (self .scheduling_actors_id ) == self .nb_chunks : # The array is ready
104- return True
105- return False
99+ def add_chunk_ref (self , chunk_ref : ray .ObjectRef , timestep : Timestep ) -> bool :
100+ """
101+ Add a reference sent by a scheduling actor.
102+
103+ Return:
104+ True if all the chunks for this timestep are ready, False otherwise.
105+ """
106+ self .chunk_refs [timestep ].append (chunk_ref )
106107
107- def add_chunk_ref (self , chunk_ref : ray .ObjectRef ) -> None :
108- self .chunk_refs .append (chunk_ref )
108+ # We don't know all the owners yet
109+ if len (self .scheduling_actors_id ) != self .nb_chunks :
110+ return False
109111
110- def get_full_array (self ) -> da .Array :
112+ if self .nb_scheduling_actors is None :
113+ self .nb_scheduling_actors = len (set (self .scheduling_actors_id .values ()))
114+
115+ return len (self .chunk_refs [timestep ]) == self .nb_scheduling_actors
116+
117+ def get_full_array (self , timestep : Timestep ) -> da .Array :
111118 """
112119 Return the full Dask array.
113120 """
114121 assert len (self .scheduling_actors_id ) == self .nb_chunks
115122 assert self .nb_chunks is not None and self .nb_chunks_per_dim is not None
116123
117- all_chunks = ray .put (self .chunk_refs )
124+ all_chunks = ray .put (self .chunk_refs [timestep ])
125+ del self .chunk_refs [timestep ]
118126
119127 # We need to add the timestep since the same name can be used several times for different
120128 # timesteps
121- dask_name = f"{ self .definition .name } _{ self . timestep } "
129+ dask_name = f"{ self .definition .name } _{ timestep } "
122130
123131 graph = {
124132 # We need to repeat the name and position in the value since the key might be removed
125133 # by the Dask optimizer
126134 (dask_name ,) + position : ChunkRef (
127- actor_id , self .definition .name , self . timestep , position , _all_chunks = all_chunks if it == 0 else None
135+ actor_id , self .definition .name , timestep , position , _all_chunks = all_chunks if it == 0 else None
128136 )
129137 for it , (position , actor_id ) in enumerate (self .scheduling_actors_id .items ())
130138 }
@@ -177,21 +185,18 @@ def __init__(self, arrays_definitions: list[ArrayDefinition], max_pending_arrays
177185 # For each ID of a simulation node, the corresponding scheduling actor
178186 self .scheduling_actors : dict [str , ray .actor .ActorHandle ] = {}
179187
180- self .arrays_definition : dict [str , ArrayDefinition ] = {
181- definition .name : definition for definition in arrays_definitions
182- }
183-
184- # Must be used before creating a new array
188+ # Must be used before creating a new array, to prevent the simulation from being
189+ # too many iterations in advance of the analytics.
185190 self .new_pending_array_semaphore = asyncio .Semaphore (max_pending_arrays )
186191
187- # Triggered when a new array is added to self.arrays
188192 self .new_array_created = asyncio .Event ()
189193
190- # Arrays beeing built
191- self .arrays : dict [tuple [str , Timestep ], _DaskArrayData ] = {}
194+ self .arrays : dict [str , _DaskArrayData ] = {
195+ definition .name : _DaskArrayData (definition ) for definition in arrays_definitions
196+ }
192197
193198 # All the newly created arrays
194- self .arrays_ready : asyncio .Queue [tuple [str , int , da .Array ]] = asyncio .Queue ()
199+ self .arrays_ready : asyncio .Queue [tuple [str , Timestep , da .Array ]] = asyncio .Queue ()
195200
196201 def list_scheduling_actors (self ) -> list [ray .actor .ActorHandle ]:
197202 """
@@ -233,11 +238,22 @@ def preprocessing_callbacks(self) -> dict[str, Callable]:
233238 """
234239 Return the preprocessing callbacks for each array.
235240 """
236- return {name : definition .preprocess for name , definition in self .arrays_definition .items ()}
241+ return {name : array . definition .preprocess for name , array in self .arrays .items ()}
237242
238- async def chunks_ready (
239- self , chunks : list [ChunkReadyInfo ], scheduling_actor_id : int , all_chunks_ref : list [ray .ObjectRef ]
240- ) -> None :
243+ def set_owned_chunks (
244+ self ,
245+ scheduling_actor_id : int ,
246+ array_name : str ,
247+ dtype : np .dtype ,
248+ nb_chunks_per_dim : tuple [int , ...],
249+ chunks : list [tuple [tuple [int , ...], tuple [int , ...]]], # [(chunk position, chunk size), ...]
250+ ):
251+ array = self .arrays [array_name ]
252+
253+ for position , size in chunks :
254+ array .set_chunk_owner (nb_chunks_per_dim , dtype , position , size , scheduling_actor_id )
255+
256+ async def chunks_ready (self , array_name : str , timestep : Timestep , all_chunks_ref : list [ray .ObjectRef ]) -> None :
241257 """
242258 Called by the scheduling actors to inform the head actor that the chunks are ready.
243259 The chunks are not sent.
@@ -246,49 +262,39 @@ async def chunks_ready(
246262 chunks: Information about the chunks that are ready.
247263 source_actor: Handle to the scheduling actor owning the chunks.
248264 """
249- for it , chunk in enumerate (chunks ):
250- while (chunk .array_name , chunk .timestep ) not in self .arrays :
251- t1 = asyncio .create_task (self .new_pending_array_semaphore .acquire ())
252- t2 = asyncio .create_task (self .new_array_created .wait ())
265+ array = self .arrays [array_name ]
253266
254- done , pending = await asyncio .wait ([t1 , t2 ], return_when = asyncio .FIRST_COMPLETED )
267+ while timestep not in array .chunk_refs :
268+ t1 = asyncio .create_task (self .new_pending_array_semaphore .acquire ())
269+ t2 = asyncio .create_task (self .new_array_created .wait ())
255270
256- for task in pending :
257- task .cancel ()
271+ done , pending = await asyncio .wait ([t1 , t2 ], return_when = asyncio .FIRST_COMPLETED )
258272
259- if t1 in done :
260- if (chunk .array_name , chunk .timestep ) in self .arrays :
261- # The array was already created by another scheduling actor
262- self .new_pending_array_semaphore .release ()
263- else :
264- self .arrays [(chunk .array_name , chunk .timestep )] = _DaskArrayData (
265- self .arrays_definition [chunk .array_name ], chunk .timestep
266- )
273+ for task in pending :
274+ task .cancel ()
267275
268- self .new_array_created .set ()
269- self .new_array_created .clear ()
276+ if t1 in done :
277+ if timestep in array .chunk_refs :
278+ # The array was already created by another scheduling actor
279+ self .new_pending_array_semaphore .release ()
280+ else :
281+ array .chunk_refs [timestep ] = []
270282
271- array = self .arrays [(chunk .array_name , chunk .timestep )]
283+ self .new_array_created .set ()
284+ self .new_array_created .clear ()
272285
273- # TODO refactor so that the function works with only one array
274- if it == 0 :
275- array .add_chunk_ref (all_chunks_ref [0 ])
286+ is_ready = array .add_chunk_ref (all_chunks_ref [0 ], timestep )
276287
277- is_ready = array .add_chunk (
278- chunk .size , chunk .position , chunk .dtype , chunk .nb_chunks_per_dim , scheduling_actor_id
279- )
280-
281- if is_ready :
282- self .arrays_ready .put_nowait (
283- (
284- chunk .array_name ,
285- array .timestep ,
286- array .get_full_array (),
287- )
288+ if is_ready :
289+ self .arrays_ready .put_nowait (
290+ (
291+ array_name ,
292+ timestep ,
293+ array .get_full_array (timestep ),
288294 )
289- del self . arrays [( chunk . array_name , chunk . timestep )]
295+ )
290296
291- async def get_next_array (self ) -> tuple [str , int , da .Array ]:
297+ async def get_next_array (self ) -> tuple [str , Timestep , da .Array ]:
292298 array = await self .arrays_ready .get ()
293299 self .new_pending_array_semaphore .release ()
294300 return array
0 commit comments