Skip to content

Commit 5237e06

Browse files
Fix type annotation (#29)
* Fix type annotation * First refactoring * Fix bug * Refactor scheduling actor * Remove useless variable * Remove useless parameter * Add comment * Remove useless field * Bump to 0.2.0
1 parent 515a559 commit 5237e06

File tree

4 files changed

+134
-123
lines changed

4 files changed

+134
-123
lines changed

doreisa/_scheduling_actor.py

Lines changed: 55 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,6 @@ def __init__(self):
4848
self.refs: dict[str, ray.ObjectRef] = {}
4949

5050

51-
@dataclass
52-
class ChunkReadyInfo:
53-
# Information about the array
54-
array_name: str
55-
timestep: Timestep
56-
dtype: np.dtype
57-
nb_chunks_per_dim: tuple[int, ...]
58-
59-
# Information about the chunk
60-
position: tuple[int, ...]
61-
size: tuple[int, ...]
62-
63-
6451
@ray.remote(num_cpus=0, enable_task_events=False)
6552
def patched_dask_task_wrapper(func, repack, key, ray_pretask_cbs, ray_posttask_cbs, *args, first_call=True):
6653
"""
@@ -105,6 +92,27 @@ def remote_ray_dask_get(dsk, keys):
10592
return ray.util.dask.ray_dask_get(dsk, keys, ray_persist=True)
10693

10794

95+
class _ArrayTimestep:
96+
def __init__(self):
97+
# Triggered when all the chunks are ready
98+
self.chunks_ready_event: asyncio.Event = asyncio.Event()
99+
100+
# {position: chunk}
101+
self.local_chunks: dict[tuple[int, ...], ray.ObjectRef | bytes] = {}
102+
103+
104+
class _Array:
105+
def __init__(self):
106+
# Indicates if set_owned_chunks method has been called for this array.
107+
self.is_registered = False
108+
109+
# Chunks owned by this actor for this array.
110+
# {(chunk position, chunk size), ...}
111+
self.owned_chunks: set[tuple[tuple[int, ...], tuple[int, ...]]] = set()
112+
113+
self.timesteps: dict[Timestep, _ArrayTimestep] = {}
114+
115+
108116
@ray.remote
109117
class SchedulingActor:
110118
"""
@@ -119,15 +127,7 @@ def __init__(self, actor_id: int) -> None:
119127
self.scheduling_actors: list[ray.actor.ActorHandle] = []
120128

121129
# For collecting chunks
122-
123-
# Triggered when all the chunks are ready
124-
self.chunks_ready_event = asyncio.Event()
125-
126-
self.chunks_info: dict[str, list[ChunkReadyInfo]] = {}
127-
128-
# (dask_array_name, position) -> chunk
129-
# The Dask array name contains the timestep
130-
self.local_chunks: dict[tuple[str, Timestep, tuple[int, ...]], ray.ObjectRef | bytes] = {}
130+
self.arrays: dict[str, _Array] = {}
131131

132132
# For scheduling
133133
self.new_graph_available = asyncio.Event()
@@ -158,43 +158,48 @@ async def add_chunk(
158158
chunk: list[ray.ObjectRef],
159159
chunk_shape: tuple[int, ...],
160160
) -> None:
161-
assert (array_name, timestep, chunk_position) not in self.local_chunks
162-
163-
self.local_chunks[(array_name, timestep, chunk_position)] = self.actor_handle._pack_object_ref.remote(chunk)
164-
165-
if array_name not in self.chunks_info:
166-
self.chunks_info[array_name] = []
167-
chunks_info = self.chunks_info[array_name]
168-
169-
chunks_info.append(
170-
ChunkReadyInfo(
171-
array_name=array_name,
172-
timestep=timestep,
173-
dtype=dtype,
174-
nb_chunks_per_dim=nb_chunks_per_dim,
175-
position=chunk_position,
176-
size=chunk_shape,
177-
)
178-
)
161+
if array_name not in self.arrays:
162+
self.arrays[array_name] = _Array()
163+
array = self.arrays[array_name]
164+
165+
if timestep not in array.timesteps:
166+
array.timesteps[timestep] = _ArrayTimestep()
167+
array_timestep = array.timesteps[timestep]
168+
169+
assert chunk_position not in array_timestep.local_chunks
170+
array_timestep.local_chunks[chunk_position] = self.actor_handle._pack_object_ref.remote(chunk)
171+
172+
array.owned_chunks.add((chunk_position, chunk_shape))
173+
174+
if len(array_timestep.local_chunks) == nb_chunks_of_node:
175+
if not array.is_registered:
176+
# Register the array with the head node
177+
await self.head.set_owned_chunks.options(enable_task_events=False).remote(
178+
self.actor_id,
179+
array_name,
180+
dtype,
181+
nb_chunks_per_dim,
182+
list(array.owned_chunks),
183+
)
184+
array.is_registered = True
179185

180-
if len(chunks_info) == nb_chunks_of_node:
181186
chunks = []
182-
for info in chunks_info:
183-
c = self.local_chunks[(info.array_name, info.timestep, info.position)]
187+
for position, size in array.owned_chunks:
188+
c = array_timestep.local_chunks[position]
184189
assert isinstance(c, ray.ObjectRef)
185190
chunks.append(c)
186-
self.local_chunks[(info.array_name, info.timestep, info.position)] = pickle.dumps(c)
191+
array_timestep.local_chunks[position] = pickle.dumps(c)
187192

188193
all_chunks_ref = ray.put(chunks)
189194

190195
await self.head.chunks_ready.options(enable_task_events=False).remote(
191-
chunks_info, self.actor_id, [all_chunks_ref]
196+
array_name, timestep, [all_chunks_ref]
192197
)
193-
self.chunks_info[array_name] = []
194-
self.chunks_ready_event.set()
195-
self.chunks_ready_event.clear()
198+
199+
array_timestep.chunks_ready_event.set()
200+
array_timestep.chunks_ready_event.clear()
196201
else:
197-
await self.chunks_ready_event.wait()
202+
await array_timestep.chunks_ready_event.wait()
198203

199204
def store_graph(self, graph_id: int, dsk: dict) -> None:
200205
"""
@@ -228,7 +233,7 @@ async def schedule_graph(self, graph_id: int):
228233
if isinstance(val, ChunkRef):
229234
assert val.actor_id == self.actor_id
230235

231-
encoded_ref = self.local_chunks[(val.array_name, val.timestep, val.position)]
236+
encoded_ref = self.arrays[val.array_name].timesteps[val.timestep].local_chunks[val.position]
232237
assert isinstance(encoded_ref, bytes)
233238
dsk[key] = pickle.loads(encoded_ref)
234239

doreisa/head_node.py

Lines changed: 77 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from doreisa import Timestep
1616
from doreisa._scheduler import doreisa_get
17-
from doreisa._scheduling_actor import ChunkReadyInfo, ChunkRef, SchedulingActor
17+
from doreisa._scheduling_actor import ChunkRef, SchedulingActor
1818

1919

2020
def init():
@@ -39,9 +39,8 @@ class _DaskArrayData:
3939
Information about a Dask array being built.
4040
"""
4141

42-
def __init__(self, definition: ArrayDefinition, timestep: Timestep) -> None:
42+
def __init__(self, definition: ArrayDefinition) -> None:
4343
self.definition = definition
44-
self.timestep = timestep
4544

4645
# This will be set when the first chunk is added
4746
self.nb_chunks_per_dim: tuple[int, ...] | None = None
@@ -56,28 +55,25 @@ def __init__(self, definition: ArrayDefinition, timestep: Timestep) -> None:
5655
# ID of the scheduling actor in charge of the chunk at each position
5756
self.scheduling_actors_id: dict[tuple[int, ...], int] = {}
5857

58+
# Number of scheduling actors owning chunks of this array.
59+
self.nb_scheduling_actors: int | None = None
60+
5961
# Each reference comes from one scheduling actor. The reference is a list of
6062
# ObjectRefs, each ObjectRef corresponding to a chunk. These references
6163
# shouldn't be used directly. They exist only to release the memory
6264
# automatically.
6365
# When the array is built, these references are put in the object store, and the
6466
# global reference is added to the Dask graph. Then, the list is cleared.
65-
self.chunk_refs: list[ray.ObjectRef] = []
67+
self.chunk_refs: dict[Timestep, list[ray.ObjectRef]] = {}
6668

67-
def add_chunk(
69+
def set_chunk_owner(
6870
self,
69-
size: tuple[int, ...],
70-
position: tuple[int, ...],
71-
dtype: np.dtype,
7271
nb_chunks_per_dim: tuple[int, ...],
72+
dtype: np.dtype,
73+
position: tuple[int, ...],
74+
size: tuple[int, ...],
7375
scheduling_actor_id: int,
74-
) -> bool:
75-
"""
76-
Add a chunk to the array.
77-
78-
Return:
79-
True if the array is ready, False otherwise.
80-
"""
76+
) -> None:
8177
if self.nb_chunks_per_dim is None:
8278
self.nb_chunks_per_dim = nb_chunks_per_dim
8379
self.nb_chunks = math.prod(nb_chunks_per_dim)
@@ -100,31 +96,43 @@ def add_chunk(
10096
else:
10197
assert self.chunks_size[d][position[d]] == size[d]
10298

103-
if len(self.scheduling_actors_id) == self.nb_chunks: # The array is ready
104-
return True
105-
return False
99+
def add_chunk_ref(self, chunk_ref: ray.ObjectRef, timestep: Timestep) -> bool:
100+
"""
101+
Add a reference sent by a scheduling actor.
102+
103+
Return:
104+
True if all the chunks for this timestep are ready, False otherwise.
105+
"""
106+
self.chunk_refs[timestep].append(chunk_ref)
106107

107-
def add_chunk_ref(self, chunk_ref: ray.ObjectRef) -> None:
108-
self.chunk_refs.append(chunk_ref)
108+
# We don't know all the owners yet
109+
if len(self.scheduling_actors_id) != self.nb_chunks:
110+
return False
109111

110-
def get_full_array(self) -> da.Array:
112+
if self.nb_scheduling_actors is None:
113+
self.nb_scheduling_actors = len(set(self.scheduling_actors_id.values()))
114+
115+
return len(self.chunk_refs[timestep]) == self.nb_scheduling_actors
116+
117+
def get_full_array(self, timestep: Timestep) -> da.Array:
111118
"""
112119
Return the full Dask array.
113120
"""
114121
assert len(self.scheduling_actors_id) == self.nb_chunks
115122
assert self.nb_chunks is not None and self.nb_chunks_per_dim is not None
116123

117-
all_chunks = ray.put(self.chunk_refs)
124+
all_chunks = ray.put(self.chunk_refs[timestep])
125+
del self.chunk_refs[timestep]
118126

119127
# We need to add the timestep since the same name can be used several times for different
120128
# timesteps
121-
dask_name = f"{self.definition.name}_{self.timestep}"
129+
dask_name = f"{self.definition.name}_{timestep}"
122130

123131
graph = {
124132
# We need to repeat the name and position in the value since the key might be removed
125133
# by the Dask optimizer
126134
(dask_name,) + position: ChunkRef(
127-
actor_id, self.definition.name, self.timestep, position, _all_chunks=all_chunks if it == 0 else None
135+
actor_id, self.definition.name, timestep, position, _all_chunks=all_chunks if it == 0 else None
128136
)
129137
for it, (position, actor_id) in enumerate(self.scheduling_actors_id.items())
130138
}
@@ -177,21 +185,18 @@ def __init__(self, arrays_definitions: list[ArrayDefinition], max_pending_arrays
177185
# For each ID of a simulation node, the corresponding scheduling actor
178186
self.scheduling_actors: dict[str, ray.actor.ActorHandle] = {}
179187

180-
self.arrays_definition: dict[str, ArrayDefinition] = {
181-
definition.name: definition for definition in arrays_definitions
182-
}
183-
184-
# Must be used before creating a new array
188+
# Must be used before creating a new array, to prevent the simulation from being
189+
# too many iterations in advance of the analytics.
185190
self.new_pending_array_semaphore = asyncio.Semaphore(max_pending_arrays)
186191

187-
# Triggered when a new array is added to self.arrays
188192
self.new_array_created = asyncio.Event()
189193

190-
# Arrays beeing built
191-
self.arrays: dict[tuple[str, Timestep], _DaskArrayData] = {}
194+
self.arrays: dict[str, _DaskArrayData] = {
195+
definition.name: _DaskArrayData(definition) for definition in arrays_definitions
196+
}
192197

193198
# All the newly created arrays
194-
self.arrays_ready: asyncio.Queue[tuple[str, int, da.Array]] = asyncio.Queue()
199+
self.arrays_ready: asyncio.Queue[tuple[str, Timestep, da.Array]] = asyncio.Queue()
195200

196201
def list_scheduling_actors(self) -> list[ray.actor.ActorHandle]:
197202
"""
@@ -233,11 +238,22 @@ def preprocessing_callbacks(self) -> dict[str, Callable]:
233238
"""
234239
Return the preprocessing callbacks for each array.
235240
"""
236-
return {name: definition.preprocess for name, definition in self.arrays_definition.items()}
241+
return {name: array.definition.preprocess for name, array in self.arrays.items()}
237242

238-
async def chunks_ready(
239-
self, chunks: list[ChunkReadyInfo], scheduling_actor_id: int, all_chunks_ref: list[ray.ObjectRef]
240-
) -> None:
243+
def set_owned_chunks(
244+
self,
245+
scheduling_actor_id: int,
246+
array_name: str,
247+
dtype: np.dtype,
248+
nb_chunks_per_dim: tuple[int, ...],
249+
chunks: list[tuple[tuple[int, ...], tuple[int, ...]]], # [(chunk position, chunk size), ...]
250+
):
251+
array = self.arrays[array_name]
252+
253+
for position, size in chunks:
254+
array.set_chunk_owner(nb_chunks_per_dim, dtype, position, size, scheduling_actor_id)
255+
256+
async def chunks_ready(self, array_name: str, timestep: Timestep, all_chunks_ref: list[ray.ObjectRef]) -> None:
241257
"""
242258
Called by the scheduling actors to inform the head actor that the chunks are ready.
243259
The chunks are not sent.
@@ -246,49 +262,39 @@ async def chunks_ready(
246262
chunks: Information about the chunks that are ready.
247263
source_actor: Handle to the scheduling actor owning the chunks.
248264
"""
249-
for it, chunk in enumerate(chunks):
250-
while (chunk.array_name, chunk.timestep) not in self.arrays:
251-
t1 = asyncio.create_task(self.new_pending_array_semaphore.acquire())
252-
t2 = asyncio.create_task(self.new_array_created.wait())
265+
array = self.arrays[array_name]
253266

254-
done, pending = await asyncio.wait([t1, t2], return_when=asyncio.FIRST_COMPLETED)
267+
while timestep not in array.chunk_refs:
268+
t1 = asyncio.create_task(self.new_pending_array_semaphore.acquire())
269+
t2 = asyncio.create_task(self.new_array_created.wait())
255270

256-
for task in pending:
257-
task.cancel()
271+
done, pending = await asyncio.wait([t1, t2], return_when=asyncio.FIRST_COMPLETED)
258272

259-
if t1 in done:
260-
if (chunk.array_name, chunk.timestep) in self.arrays:
261-
# The array was already created by another scheduling actor
262-
self.new_pending_array_semaphore.release()
263-
else:
264-
self.arrays[(chunk.array_name, chunk.timestep)] = _DaskArrayData(
265-
self.arrays_definition[chunk.array_name], chunk.timestep
266-
)
273+
for task in pending:
274+
task.cancel()
267275

268-
self.new_array_created.set()
269-
self.new_array_created.clear()
276+
if t1 in done:
277+
if timestep in array.chunk_refs:
278+
# The array was already created by another scheduling actor
279+
self.new_pending_array_semaphore.release()
280+
else:
281+
array.chunk_refs[timestep] = []
270282

271-
array = self.arrays[(chunk.array_name, chunk.timestep)]
283+
self.new_array_created.set()
284+
self.new_array_created.clear()
272285

273-
# TODO refactor so that the function works with only one array
274-
if it == 0:
275-
array.add_chunk_ref(all_chunks_ref[0])
286+
is_ready = array.add_chunk_ref(all_chunks_ref[0], timestep)
276287

277-
is_ready = array.add_chunk(
278-
chunk.size, chunk.position, chunk.dtype, chunk.nb_chunks_per_dim, scheduling_actor_id
279-
)
280-
281-
if is_ready:
282-
self.arrays_ready.put_nowait(
283-
(
284-
chunk.array_name,
285-
array.timestep,
286-
array.get_full_array(),
287-
)
288+
if is_ready:
289+
self.arrays_ready.put_nowait(
290+
(
291+
array_name,
292+
timestep,
293+
array.get_full_array(timestep),
288294
)
289-
del self.arrays[(chunk.array_name, chunk.timestep)]
295+
)
290296

291-
async def get_next_array(self) -> tuple[str, int, da.Array]:
297+
async def get_next_array(self) -> tuple[str, Timestep, da.Array]:
292298
array = await self.arrays_ready.get()
293299
self.new_pending_array_semaphore.release()
294300
return array

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "doreisa"
3-
version = "0.1.6"
3+
version = "0.2.0"
44
description = ""
55
authors = [{ name = "Adrien Vannson", email = "adrien.vannson@protonmail.com" }]
66
requires-python = ">=3.12"

0 commit comments

Comments
 (0)