Skip to content

Commit 88c98aa

Browse files
Optimize graph sending (#27)
* Optimize graph sending * Fix code quality * Bump to 0.1.6 * Update lockfile
1 parent e1124c1 commit 88c98aa

File tree

4 files changed

+45
-35
lines changed

4 files changed

+45
-35
lines changed

doreisa/_scheduler.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import ray
66
from dask.core import get_dependencies
77

8-
from doreisa._scheduling_actor import ChunkRef
8+
from doreisa._scheduling_actor import ChunkRef, ScheduledByOtherActor
99

1010

1111
def doreisa_get(dsk, keys, **kwargs):
@@ -32,7 +32,7 @@ def log(message: str, debug_logs_path: str | None) -> None:
3232

3333
# Find a not too bad scheduling strategy
3434
# Good scheduling in a tree
35-
scheduling = {k: -1 for k in dsk.keys()}
35+
partition = {k: -1 for k in dsk.keys()}
3636

3737
# def explore(key, v: int):
3838
# # Only works for trees for now
@@ -76,32 +76,43 @@ def explore(k) -> int:
7676
val = dsk[k]
7777

7878
if isinstance(val, ChunkRef):
79-
scheduling[k] = val.actor_id
79+
partition[k] = val.actor_id
8080
else:
8181
res = [explore(dep) for dep in get_dependencies(dsk, k)]
82-
scheduling[k] = Counter(res).most_common(1)[0][0]
82+
partition[k] = Counter(res).most_common(1)[0][0]
8383

84-
return scheduling[k]
84+
return partition[k]
8585

8686
explore(key)
8787

8888
log("2. Graph partitionning done", debug_logs_path)
8989

90-
# Pass the scheduling to the scheduling actors
91-
dsk_ref, scheduling_ref = ray.put(dsk), ray.put(scheduling) # noqa: F841
90+
partitionned_graphs: dict[int, dict] = {}
9291

93-
log("3. Graph put in object store", debug_logs_path)
92+
for k, v in dsk.items():
93+
actor_id = partition[k]
94+
95+
if actor_id not in partitionned_graphs:
96+
partitionned_graphs[actor_id] = {}
97+
98+
partitionned_graphs[actor_id][k] = v
99+
100+
for dep in get_dependencies(dsk, k):
101+
if partition[dep] != actor_id:
102+
partitionned_graphs[actor_id][dep] = ScheduledByOtherActor(partition[dep])
103+
104+
log("3. Partitionned graphs created", debug_logs_path)
94105

95106
graph_id = random.randint(0, 2**128 - 1)
96107

97108
ray.get(
98109
[
99-
scheduling_actors[i].store_graph.options(enable_task_events=False).remote(graph_id, dsk_ref, scheduling_ref)
100-
for i in range(len(scheduling_actors))
110+
actor.store_graph.options(enable_task_events=False).remote(graph_id, partitionned_graphs[id])
111+
for id, actor in enumerate(scheduling_actors)
101112
]
102113
)
103114

104-
log("4. Partitionned graph sent", debug_logs_path)
115+
log("4. Partitionned graphs sent", debug_logs_path)
105116

106117
ray.get(
107118
[
@@ -112,7 +123,7 @@ def explore(k) -> int:
112123

113124
log("5. Graph scheduled", debug_logs_path)
114125

115-
res = ray.get(ray.get(scheduling_actors[scheduling[key]].get_value.remote(graph_id, key)))
126+
res = ray.get(ray.get(scheduling_actors[partition[key]].get_value.remote(graph_id, key)))
116127

117128
log("6. End Doreisa scheduler", debug_logs_path)
118129

doreisa/_scheduling_actor.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import ray
77
import ray.actor
88
import ray.util.dask.scheduler
9-
from dask.core import get_dependencies
109

1110
from doreisa import Timestep
1211

@@ -30,6 +29,15 @@ class ChunkRef:
3029
_all_chunks: ray.ObjectRef | None = None
3130

3231

32+
@dataclass
33+
class ScheduledByOtherActor:
34+
"""
35+
Represents a task that is scheduled by another actor in the part of the task graph sent to an actor.
36+
"""
37+
38+
actor_id: int
39+
40+
3341
class GraphInfo:
3442
"""
3543
Information about graphs and their scheduling.
@@ -124,7 +132,7 @@ def __init__(self, actor_id: int) -> None:
124132
# For scheduling
125133
self.new_graph_available = asyncio.Event()
126134
self.graph_infos: dict[int, GraphInfo] = {}
127-
self.partitionned_graphs: dict[int, tuple[dict, dict[str, int]]] = {}
135+
self.partitionned_graphs: dict[int, dict] = {}
128136

129137
def ready(self) -> None:
130138
pass
@@ -188,18 +196,18 @@ async def add_chunk(
188196
else:
189197
await self.chunks_ready_event.wait()
190198

191-
def store_graph(self, graph_id: int, dsk: dict, scheduling: dict[str, int]) -> None:
199+
def store_graph(self, graph_id: int, dsk: dict) -> None:
192200
"""
193201
Store the given graph in the actor until `schedule_graph` is called.
194202
195203
This allows measuring precisely the time it takes to send the graph to all the
196204
actors. If needed, this will be optimized using an efficient communication
197205
method.
198206
"""
199-
self.partitionned_graphs[graph_id] = (dsk, scheduling)
207+
self.partitionned_graphs[graph_id] = dsk
200208

201209
async def schedule_graph(self, graph_id: int):
202-
dsk, scheduling = self.partitionned_graphs.pop(graph_id)
210+
dsk = self.partitionned_graphs.pop(graph_id)
203211

204212
# Find the scheduling actors
205213
if not self.scheduling_actors:
@@ -210,22 +218,13 @@ async def schedule_graph(self, graph_id: int):
210218
self.new_graph_available.set()
211219
self.new_graph_available.clear()
212220

213-
local_keys = {k for k in dsk if scheduling[k] == self.actor_id}
214-
215-
dependency_keys: set[str] = {dep for k in local_keys for dep in get_dependencies(dsk, k)} # type: ignore[assignment]
216-
217-
external_keys = dependency_keys - local_keys
218-
219-
# Filter the dask array
220-
dsk = {k: v for k, v in dsk.items() if k in local_keys}
221-
222-
# Adapt external keys
223-
for k in external_keys:
224-
actor = self.scheduling_actors[scheduling[k]]
225-
dsk[k] = actor.get_value.options(enable_task_events=False).remote(graph_id, k)
226-
227-
# Replace the false chunks by the real ObjectRefs
228221
for key, val in dsk.items():
222+
# Adapt external keys
223+
if isinstance(val, ScheduledByOtherActor):
224+
actor = self.scheduling_actors[val.actor_id]
225+
dsk[key] = actor.get_value.options(enable_task_events=False).remote(graph_id, key)
226+
227+
# Replace the false chunks by the real ObjectRefs
229228
if isinstance(val, ChunkRef):
230229
assert val.actor_id == self.actor_id
231230

@@ -234,7 +233,7 @@ async def schedule_graph(self, graph_id: int):
234233
dsk[key] = pickle.loads(encoded_ref)
235234

236235
# We will need the ObjectRefs of these keys
237-
keys_needed = list(local_keys - dependency_keys)
236+
keys_needed = list(dsk.keys())
238237

239238
refs = await remote_ray_dask_get.remote(dsk, keys_needed)
240239

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "doreisa"
3-
version = "0.1.5"
3+
version = "0.1.6"
44
description = ""
55
authors = [{ name = "Adrien Vannson", email = "adrien.vannson@protonmail.com" }]
66
requires-python = ">=3.12"

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)