
Commit 33e4074

[cascade] add checkpoint persistence implementation

1 parent 72d31d8

15 files changed: +211, -53 lines

src/cascade/benchmarks/util.py
Lines changed: 8 additions & 14 deletions

@@ -118,7 +118,7 @@ def get_gpu_count(host_idx: int, worker_count: int) -> int:


 def launch_executor(
-    job_instance: JobInstance,
+    job: JobInstanceRich,
     controller_address: BackboneAddress,
     workers_per_host: int,
     portBase: int,
@@ -138,7 +138,7 @@ def launch_executor(
         logger.info(f"will set {gpu_count} gpus on host {i}")
         os.environ["CASCADE_GPU_COUNT"] = str(gpu_count)
         executor = Executor(
-            job_instance,
+            job.jobInstance,
             controller_address,
             workers_per_host,
             f"h{i}",
@@ -156,7 +156,7 @@ def launch_executor(


 def run_locally(
-    job: JobInstance,
+    job: JobInstanceRich,
     hosts: int,
     workers: int,
     portBase: int = 12345,
@@ -197,7 +197,7 @@ def run_locally(
         ps.append(p)

     # compute preschedule
-    preschedule = precompute(job)
+    preschedule = precompute(job.jobInstance)

     # check processes started healthy
     for i, p in enumerate(ps):
@@ -243,11 +243,8 @@ def main_local(
     log_base: str | None = None,
 ) -> None:
     jobInstanceRich = get_job(job, instance)
-    if jobInstanceRich.checkpointSpec is not None:
-        raise NotImplementedError
-    jobInstance = jobInstanceRich.jobInstance
     run_locally(
-        jobInstance,
+        jobInstanceRich,
         hosts,
         workers_per_host,
         report_address=report_address,
@@ -272,27 +269,24 @@ def main_dist(
     launch = perf_counter_ns()

     jobInstanceRich = get_job(job, instance)
-    if jobInstanceRich.checkpointSpec is not None:
-        raise NotImplementedError
-    jobInstance = jobInstanceRich.jobInstance

     if idx == 0:
         logging.config.dictConfig(logging_config)
         tp = ThreadPoolExecutor(max_workers=1)
-        preschedule_fut = tp.submit(precompute, jobInstance)
+        preschedule_fut = tp.submit(precompute, jobInstanceRich.jobInstance)
         b = Bridge(controller_url, hosts)
         preschedule = preschedule_fut.result()
         tp.shutdown()
         start = perf_counter_ns()
-        run(jobInstance, b, preschedule, report_address=report_address)
+        run(jobInstanceRich, b, preschedule, report_address=report_address)
         end = perf_counter_ns()
         print(
             f"compute took {(end-start)/1e9:.3f}s, including startup {(end-launch)/1e9:.3f}s"
         )
     else:
         gpu_count = get_gpu_count(0, workers_per_host)
         launch_executor(
-            jobInstance,
+            jobInstanceRich,
             controller_url,
             workers_per_host,
             12345,
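Every call site above now threads the rich wrapper through and unwraps .jobInstance only where a plain JobInstance is needed. A minimal sketch of the shapes involved, inferred from how the fields are used across this commit (the real definitions live in cascade.low.core and may carry more fields):

# Sketch only: shapes inferred from usage in this commit, not the real classes.
from dataclasses import dataclass
from typing import Any, Literal

CheckpointStorageType = Literal["fs"]  # the only storage handled in this commit


@dataclass
class CheckpointSpec:
    to_persist: set[str]                  # DatasetIds to checkpoint
    storage_type: CheckpointStorageType
    storage_params: Any                   # for "fs": a root directory as str
    persist_id: str | None                # per-run subdirectory under the root


@dataclass
class JobInstanceRich:
    jobInstance: Any                      # the plain JobInstance
    checkpointSpec: CheckpointSpec | None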

src/cascade/controller/act.py
Lines changed: 7 additions & 0 deletions

@@ -10,6 +10,7 @@

 import logging

+import cascade.executor.checkpoints as checkpoints
 from cascade.controller.core import State
 from cascade.executor.bridge import Bridge
 from cascade.executor.msg import TaskSequence
@@ -76,6 +77,12 @@ def flush_queues(bridge: Bridge, state: State, context: JobExecutionContext):
     for dataset, host in state.drain_fetching_queue():
         bridge.fetch(dataset, host)

+    for dataset, host in state.drain_persist_queue():
+        if context.checkpoint_spec is None:
+            raise TypeError(f"unexpected persist need when checkpoint storage not configured")
+        persist_params = checkpoints.serialize_persist_params(context.checkpoint_spec)
+        bridge.persist(dataset, host, context.checkpoint_spec.storage_type, persist_params)
+
     for ds in state.drain_purging_queue():
         for host in context.purge_dataset(ds):
             logger.debug(f"issuing purge of {ds=} to {host=}")

src/cascade/controller/core.py
Lines changed: 31 additions & 4 deletions

@@ -8,18 +8,22 @@
 from typing import Any, Iterator

 import cascade.executor.serde as serde
-from cascade.executor.msg import DatasetTransmitPayload
+from cascade.executor.msg import DatasetPersistSuccess, DatasetTransmitPayload
 from cascade.low.core import DatasetId, HostId, TaskId

 logger = logging.getLogger(__name__)


 @dataclass
 class State:
-    # key add by core.initialize, value add by notify.notify
+    # key add by core.init_state, value add by notify.notify
     outputs: dict[DatasetId, Any]
+    # key add by core.init_state, value add by notify.notify
+    to_persist: set[DatasetId]
     # add by notify.notify, remove by act.flush_queues
     fetching_queue: dict[DatasetId, HostId]
+    # add by notify.notify, remove by act.flush_queues
+    persist_queue: dict[DatasetId, HostId]
     # add by notify.notify, removed by act.flush_queues
     purging_queue: list[DatasetId]
     # add by core.init_state, remove by notify.notify
@@ -31,13 +35,16 @@ def has_awaitable(self) -> bool:
         for e in self.outputs.values():
             if e is None:
                 return True
+        if self.to_persist:
+            return True
         return False

     def _consider_purge(self, dataset: DatasetId) -> None:
         """If dataset not required anymore, add to purging_queue"""
         no_dependants = not self.purging_tracker.get(dataset, None)
         not_required_output = self.outputs.get(dataset, 1) is not None
-        if no_dependants and not_required_output:
+        not_required_persist = not dataset in self.to_persist
+        if all((no_dependants, not_required_output, not_required_persist)):
             logger.debug(f"adding {dataset=} to purging queue")
             if dataset in self.purging_tracker:
                 self.purging_tracker.pop(dataset)
@@ -52,6 +59,14 @@ def consider_fetch(self, dataset: DatasetId, at: HostId) -> None:
         ):
             self.fetching_queue[dataset] = at

+    def consider_persist(self, dataset: DatasetId, at: HostId) -> None:
+        """If required as persist and not yet acknowledged, add to persist queue"""
+        if (
+            dataset in self.to_persist
+            and dataset not in self.persist_queue
+        ):
+            self.persist_queue[dataset] = at
+
     def receive_payload(self, payload: DatasetTransmitPayload) -> None:
         """Stores deserialized value into outputs, considers purge"""
         # NOTE ifneedbe get annotation from job.tasks[event.ds.task].definition.output_schema[event.ds.output]
@@ -60,6 +75,11 @@ def receive_payload(self, payload: DatasetTransmitPayload) -> None:
         )
         self._consider_purge(payload.header.ds)

+    def acknowledge_persist(self, payload: DatasetPersistSuccess) -> None:
+        """Marks acknowledged, considers purge"""
+        self.to_persist.discard(payload.ds)
+        self._consider_purge(payload.ds)
+
     def task_done(self, task: TaskId, inputs: set[DatasetId]) -> None:
         """Marks that the inputs are not needed for this task anymore, considers purge of each"""
         for sourceDataset in inputs:
@@ -76,15 +96,22 @@ def drain_fetching_queue(self) -> Iterator[tuple[DatasetId, HostId]]:
             yield dataset, host
         self.fetching_queue = {}

+    def drain_persist_queue(self) -> Iterator[tuple[DatasetId, HostId]]:
+        for dataset, host in self.persist_queue.items():
+            yield dataset, host
+        self.persist_queue = {}
+

-def init_state(outputs: set[DatasetId], edge_o: dict[DatasetId, set[TaskId]]) -> State:
+def init_state(outputs: set[DatasetId], to_persist: set[DatasetId], edge_o: dict[DatasetId, set[TaskId]]) -> State:
     purging_tracker = {
         ds: {task for task in dependants} for ds, dependants in edge_o.items()
     }

     return State(
         outputs={e: None for e in outputs},
+        to_persist={e for e in to_persist},
         fetching_queue={},
         purging_queue=[],
         purging_tracker=purging_tracker,
+        persist_queue={},
     )
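The new fields give State a second completion condition: the run stays awaitable while any requested persist is unacknowledged. A minimal walkthrough of the lifecycle, using plain strings in place of real DatasetId/HostId values and assuming DatasetPersistSuccess can be constructed with just ds (illustration only):

from cascade.controller.core import init_state
from cascade.executor.msg import DatasetPersistSuccess

state = init_state(outputs=set(), to_persist={"ds1"}, edge_o={})
assert state.has_awaitable()              # pending persist keeps the loop alive

state.consider_persist("ds1", "h0")       # queued once the dataset is published
assert dict(state.drain_persist_queue()) == {"ds1": "h0"}

# once the executor confirms the write:
state.acknowledge_persist(DatasetPersistSuccess(ds="ds1"))  # assumed ctor shape
assert not state.has_awaitable()          # "ds1" is now also eligible for purge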

src/cascade/controller/impl.py
Lines changed: 5 additions & 4 deletions

@@ -14,7 +14,7 @@
 from cascade.controller.notify import notify
 from cascade.controller.report import Reporter
 from cascade.executor.bridge import Bridge, Event
-from cascade.low.core import JobInstance, type_dec
+from cascade.low.core import JobInstance, JobInstanceRich, type_dec
 from cascade.low.execution_context import init_context
 from cascade.low.tracing import ControllerPhases, Microtrace, label, mark, timer
 from cascade.scheduler.api import assign, init_schedule, plan
@@ -24,7 +24,7 @@


 def run(
-    job: JobInstance,
+    job: JobInstanceRich,
     bridge: Bridge,
     preschedule: Preschedule,
     report_address: str | None = None,
@@ -34,7 +34,8 @@ def run(
     outputs = set(context.job_instance.ext_outputs)
     logger.debug(f"starting with {env=} and {report_address=}")
     schedule = timer(init_schedule, Microtrace.ctrl_init)(preschedule, context)
-    state = init_state(outputs, context.edge_o)
+    to_persist = set(job.checkpointSpec.to_persist) if job.checkpointSpec is not None else set()
+    state = init_state(outputs, to_persist, context.edge_o)

     label("host", "controller")
     events: list[Event] = []
@@ -44,7 +45,7 @@ def run(

     try:
         total_gpus = sum(worker.gpu for worker in env.workers.values())
-        needs_gpus = any(task.definition.needs_gpu for task in job.tasks.values())
+        needs_gpus = any(task.definition.needs_gpu for task in job.jobInstance.tasks.values())
         if needs_gpus and total_gpus == 0:
             raise ValueError("environment contains no gpu yet job demands one")

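run() now takes the rich wrapper unconditionally; callers without checkpointing wrap their plain instance and pass checkpointSpec=None, in which case to_persist resolves to the empty set and the controller behaves exactly as before. A hedged sketch of such a call, with job_instance, bridge, and preschedule taken from the surrounding setup and the JobInstanceRich constructor shape assumed:

# Assumed constructor shape; behaviour is unchanged when checkpointSpec is None.
job_rich = JobInstanceRich(jobInstance=job_instance, checkpointSpec=None)
run(job_rich, bridge, preschedule, report_address=None)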

src/cascade/controller/notify.py
Lines changed: 4 additions & 1 deletion

@@ -17,7 +17,7 @@
 from cascade.controller.core import State
 from cascade.controller.report import Reporter
 from cascade.executor.bridge import Event
-from cascade.executor.msg import DatasetPublished, DatasetTransmitPayload
+from cascade.executor.msg import DatasetPersistSuccess, DatasetPublished, DatasetTransmitPayload
 from cascade.low.core import DatasetId, HostId, WorkerId
 from cascade.low.execution_context import DatasetStatus, JobExecutionContext
 from cascade.low.func import assert_never
@@ -89,6 +89,7 @@ def notify(
         context.host2ds[host][event.ds] = DatasetStatus.available
         context.ds2host[event.ds][host] = DatasetStatus.available
         state.consider_fetch(event.ds, host)
+        state.consider_persist(event.ds, host)
         consider_computable(schedule, state, context, event.ds, host)
         if event.transmit_idx is not None:
             mark(
@@ -121,5 +122,7 @@ def notify(
     elif isinstance(event, DatasetTransmitPayload):
         state.receive_payload(event)
         reporter.send_result(event.header.ds, event.value)
+    elif isinstance(event, DatasetPersistSuccess):
+        state.acknowledge_persist(event)
     else:
         assert_never(event)
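Putting the controller pieces together, the happy path for one checkpointed dataset threads through this commit as follows (a comment-only trace; names as used in the diffs):

# 1. executor publishes ds          -> notify(): state.consider_persist(ds, host)
# 2. next controller tick           -> act.flush_queues(): bridge.persist(ds, host, "fs", params)
# 3. executor receives the command  -> checkpoints.persist_dataset(command, buf) writes the bytes
# 4. executor reports success       -> notify(): state.acknowledge_persist(event)
# 5. state._consider_purge(ds)      -> ds is freed once no task or output still needs it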

src/cascade/executor/bridge.py
Lines changed: 17 additions & 4 deletions

@@ -16,6 +16,9 @@
 from cascade.executor.executor import heartbeat_grace_ms as executor_heartbeat_grace_ms
 from cascade.executor.msg import (
     Ack,
+    DatasetPersistCommand,
+    DatasetPersistFailure,
+    DatasetPersistSuccess,
     DatasetPublished,
     DatasetPurge,
     DatasetTransmitCommand,
@@ -29,14 +32,15 @@
     TaskFailure,
     TaskSequence,
 )
-from cascade.low.core import DatasetId, Environment, HostId, Worker, WorkerId
+from cascade.low.core import CheckpointStorageType, DatasetId, Environment, HostId, Worker, WorkerId
 from cascade.low.func import assert_never

 logger = logging.getLogger(__name__)

-Event = DatasetPublished | DatasetTransmitPayload
-ToShutdown = TaskFailure | ExecutorFailure | DatasetTransmitFailure | ExecutorExit
-Unsupported = TaskSequence | DatasetPurge | DatasetTransmitCommand | ExecutorShutdown
+Event = DatasetPublished | DatasetTransmitPayload | DatasetPersistSuccess
+# TODO consider retries here, esp on the PersistFailure
+ToShutdown = TaskFailure | ExecutorFailure | DatasetTransmitFailure | DatasetPersistFailure | ExecutorExit
+Unsupported = TaskSequence | DatasetPurge | DatasetTransmitCommand | DatasetPersistCommand | ExecutorShutdown


 class Bridge:
@@ -158,6 +162,15 @@ def transmit(self, ds: DatasetId, source: HostId, target: HostId) -> None:
         self.transmit_idx_counter += 1
         self.sender.send("data." + source, m)

+    def persist(self, ds: DatasetId, source: HostId, storage_type: CheckpointStorageType, persist_params: str) -> None:
+        m = DatasetPersistCommand(
+            source=source,
+            ds=ds,
+            storage_type=storage_type,
+            persist_params=persist_params,
+        )
+        self.sender.send("data." + source, m)
+
     def fetch(self, ds: DatasetId, source: HostId) -> None:
         m = DatasetTransmitCommand(
             source=source,
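The persist message types are used here but defined elsewhere. A sketch of their shapes as inferred from the call sites in this commit; the real definitions live in cascade.executor.msg and may carry more fields:

# Inferred shapes only, for orientation; not the real message classes.
from dataclasses import dataclass


@dataclass
class DatasetPersistCommand:
    source: str           # HostId that holds the dataset
    ds: object            # DatasetId to persist
    storage_type: str     # CheckpointStorageType, e.g. "fs"
    persist_params: str   # opaque, storage-specific (a path for "fs")


@dataclass
class DatasetPersistSuccess:
    ds: object            # DatasetId that was durably written


@dataclass
class DatasetPersistFailure:
    ds: object            # currently routes the run to shutdown (no retry yet)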
src/cascade/executor/checkpoints.py (new file)
Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+# (C) Copyright 2025- ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+"""Handles the checkpoint management: storage, retrieval"""
+
+import pathlib
+
+from cascade.executor.msg import DatasetPersistCommand
+from cascade.low.core import CheckpointSpec
+from cascade.low.func import assert_never
+from cascade.shm.client import AllocatedBuffer
+
+
+def persist_dataset(command: DatasetPersistCommand, buf: AllocatedBuffer) -> None:
+    match command.storage_type:
+        case "fs":
+            root = pathlib.Path(command.persist_params)
+            root.mkdir(parents=True, exist_ok=True)
+            file = root / repr(command.ds)
+            # TODO what about overwrites / concurrent writes? Append uuid?
+            file.write_bytes(buf.view())
+        case s:
+            assert_never(s)
+
+def serialize_persist_params(spec: CheckpointSpec) -> str:
+    # NOTE we call this every time we store, ideally call this once when building `low.execution_context`
+    match spec.storage_type:
+        case "fs":
+            if not isinstance(spec.storage_params, str):
+                raise TypeError(f"expected checkpoint storage params to be str, gotten {spec.storage_params.__class__}")
+            if spec.persist_id is None:
+                raise TypeError(f"serialize_persist_params called, but persist_id is None")
+            root = pathlib.Path(spec.storage_params)
+            return str(root / spec.persist_id)
+        case s:
+            assert_never(s)
+
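A short illustration of the round trip, assuming CheckpointSpec can be constructed with the fields referenced above (constructor shape and values are placeholders):

from cascade.executor.checkpoints import serialize_persist_params
from cascade.low.core import CheckpointSpec

# Assumed constructor shape; see the CheckpointSpec sketch earlier in this page.
spec = CheckpointSpec(
    to_persist={"ds1"},
    storage_type="fs",
    storage_params="/tmp/ckpts",   # root directory for "fs" storage
    persist_id="run-42",           # per-run subdirectory
)
params = serialize_persist_params(spec)   # "/tmp/ckpts/run-42"
# The controller ships params inside a DatasetPersistCommand; on the executor,
# persist_dataset() then writes the shm buffer to /tmp/ckpts/run-42/<repr(ds)>.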
