Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 8 additions & 14 deletions src/cascade/benchmarks/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_gpu_count(host_idx: int, worker_count: int) -> int:


def launch_executor(
job_instance: JobInstance,
job: JobInstanceRich,
controller_address: BackboneAddress,
workers_per_host: int,
portBase: int,
Expand All @@ -138,7 +138,7 @@ def launch_executor(
logger.info(f"will set {gpu_count} gpus on host {i}")
os.environ["CASCADE_GPU_COUNT"] = str(gpu_count)
executor = Executor(
job_instance,
job.jobInstance,
controller_address,
workers_per_host,
f"h{i}",
Expand All @@ -156,7 +156,7 @@ def launch_executor(


def run_locally(
job: JobInstance,
job: JobInstanceRich,
hosts: int,
workers: int,
portBase: int = 12345,
Expand Down Expand Up @@ -197,7 +197,7 @@ def run_locally(
ps.append(p)

# compute preschedule
preschedule = precompute(job)
preschedule = precompute(job.jobInstance)

# check processes started healthy
for i, p in enumerate(ps):
Expand Down Expand Up @@ -243,11 +243,8 @@ def main_local(
log_base: str | None = None,
) -> None:
jobInstanceRich = get_job(job, instance)
if jobInstanceRich.checkpointSpec is not None:
raise NotImplementedError
jobInstance = jobInstanceRich.jobInstance
run_locally(
jobInstance,
jobInstanceRich,
hosts,
workers_per_host,
report_address=report_address,
Expand All @@ -272,27 +269,24 @@ def main_dist(
launch = perf_counter_ns()

jobInstanceRich = get_job(job, instance)
if jobInstanceRich.checkpointSpec is not None:
raise NotImplementedError
jobInstance = jobInstanceRich.jobInstance

if idx == 0:
logging.config.dictConfig(logging_config)
tp = ThreadPoolExecutor(max_workers=1)
preschedule_fut = tp.submit(precompute, jobInstance)
preschedule_fut = tp.submit(precompute, jobInstanceRich.jobInstance)
b = Bridge(controller_url, hosts)
preschedule = preschedule_fut.result()
tp.shutdown()
start = perf_counter_ns()
run(jobInstance, b, preschedule, report_address=report_address)
run(jobInstanceRich, b, preschedule, report_address=report_address)
end = perf_counter_ns()
print(
f"compute took {(end-start)/1e9:.3f}s, including startup {(end-launch)/1e9:.3f}s"
)
else:
gpu_count = get_gpu_count(0, workers_per_host)
launch_executor(
jobInstance,
jobInstanceRich,
controller_url,
workers_per_host,
12345,
Expand Down
7 changes: 7 additions & 0 deletions src/cascade/controller/act.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import logging

import cascade.executor.checkpoints as checkpoints
from cascade.controller.core import State
from cascade.executor.bridge import Bridge
from cascade.executor.msg import TaskSequence
Expand Down Expand Up @@ -76,6 +77,12 @@ def flush_queues(bridge: Bridge, state: State, context: JobExecutionContext):
for dataset, host in state.drain_fetching_queue():
bridge.fetch(dataset, host)

for dataset, host in state.drain_persist_queue():
if context.checkpoint_spec is None:
raise TypeError(f"unexpected persist need when checkpoint storage not configured")
persist_params = checkpoints.serialize_persist_params(context.checkpoint_spec)
bridge.persist(dataset, host, context.checkpoint_spec.storage_type, persist_params)

for ds in state.drain_purging_queue():
for host in context.purge_dataset(ds):
logger.debug(f"issuing purge of {ds=} to {host=}")
Expand Down
35 changes: 31 additions & 4 deletions src/cascade/controller/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,22 @@
from typing import Any, Iterator

import cascade.executor.serde as serde
from cascade.executor.msg import DatasetTransmitPayload
from cascade.executor.msg import DatasetPersistSuccess, DatasetTransmitPayload
from cascade.low.core import DatasetId, HostId, TaskId

logger = logging.getLogger(__name__)


@dataclass
class State:
# key add by core.initialize, value add by notify.notify
# key add by core.init_state, value add by notify.notify
outputs: dict[DatasetId, Any]
# key add by core.init_state, value add by notify.notify
to_persist: set[DatasetId]
# add by notify.notify, remove by act.flush_queues
fetching_queue: dict[DatasetId, HostId]
# add by notify.notify, remove by act.flush_queues
persist_queue: dict[DatasetId, HostId]
# add by notify.notify, removed by act.flush_queues
purging_queue: list[DatasetId]
# add by core.init_state, remove by notify.notify
Expand All @@ -31,13 +35,16 @@ def has_awaitable(self) -> bool:
for e in self.outputs.values():
if e is None:
return True
if self.to_persist:
return True
return False

def _consider_purge(self, dataset: DatasetId) -> None:
"""If dataset not required anymore, add to purging_queue"""
no_dependants = not self.purging_tracker.get(dataset, None)
not_required_output = self.outputs.get(dataset, 1) is not None
if no_dependants and not_required_output:
not_required_persist = not dataset in self.to_persist
if all((no_dependants, not_required_output, not_required_persist)):
logger.debug(f"adding {dataset=} to purging queue")
if dataset in self.purging_tracker:
self.purging_tracker.pop(dataset)
Expand All @@ -52,6 +59,14 @@ def consider_fetch(self, dataset: DatasetId, at: HostId) -> None:
):
self.fetching_queue[dataset] = at

def consider_persist(self, dataset: DatasetId, at: HostId) -> None:
"""If required as persist and not yet acknowledged, add to persist queue"""
if (
dataset in self.to_persist
and dataset not in self.persist_queue
):
self.persist_queue[dataset] = at

def receive_payload(self, payload: DatasetTransmitPayload) -> None:
"""Stores deserialized value into outputs, considers purge"""
# NOTE ifneedbe get annotation from job.tasks[event.ds.task].definition.output_schema[event.ds.output]
Expand All @@ -60,6 +75,11 @@ def receive_payload(self, payload: DatasetTransmitPayload) -> None:
)
self._consider_purge(payload.header.ds)

def acknowledge_persist(self, payload: DatasetPersistSuccess) -> None:
    """Marks acknowledged, considers purge"""
    acked = payload.ds
    # once persisted, the dataset no longer blocks shutdown; it may now
    # also be eligible for purging if nothing else needs it
    self.to_persist.discard(acked)
    self._consider_purge(acked)

def task_done(self, task: TaskId, inputs: set[DatasetId]) -> None:
"""Marks that the inputs are not needed for this task anymore, considers purge of each"""
for sourceDataset in inputs:
Expand All @@ -76,15 +96,22 @@ def drain_fetching_queue(self) -> Iterator[tuple[DatasetId, HostId]]:
yield dataset, host
self.fetching_queue = {}

def drain_persist_queue(self) -> Iterator[tuple[DatasetId, HostId]]:
for dataset, host in self.persist_queue.items():
yield dataset, host
self.persist_queue = {}


def init_state(outputs: set[DatasetId], edge_o: dict[DatasetId, set[TaskId]]) -> State:
def init_state(outputs: set[DatasetId], to_persist: set[DatasetId], edge_o: dict[DatasetId, set[TaskId]]) -> State:
    """Build the controller's initial State: awaited outputs (value filled in by
    notify), datasets to checkpoint, empty work queues, and a per-dataset
    tracker of the tasks that still depend on it."""
    return State(
        outputs=dict.fromkeys(outputs),
        to_persist=set(to_persist),
        fetching_queue={},
        purging_queue=[],
        purging_tracker={ds: set(dependants) for ds, dependants in edge_o.items()},
        persist_queue={},
    )
9 changes: 5 additions & 4 deletions src/cascade/controller/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from cascade.controller.notify import notify
from cascade.controller.report import Reporter
from cascade.executor.bridge import Bridge, Event
from cascade.low.core import JobInstance, type_dec
from cascade.low.core import JobInstance, JobInstanceRich, type_dec
from cascade.low.execution_context import init_context
from cascade.low.tracing import ControllerPhases, Microtrace, label, mark, timer
from cascade.scheduler.api import assign, init_schedule, plan
Expand All @@ -24,7 +24,7 @@


def run(
job: JobInstance,
job: JobInstanceRich,
bridge: Bridge,
preschedule: Preschedule,
report_address: str | None = None,
Expand All @@ -34,7 +34,8 @@ def run(
outputs = set(context.job_instance.ext_outputs)
logger.debug(f"starting with {env=} and {report_address=}")
schedule = timer(init_schedule, Microtrace.ctrl_init)(preschedule, context)
state = init_state(outputs, context.edge_o)
to_persist = set(job.checkpointSpec.to_persist) if job.checkpointSpec is not None else set()
state = init_state(outputs, to_persist, context.edge_o)

label("host", "controller")
events: list[Event] = []
Expand All @@ -44,7 +45,7 @@ def run(

try:
total_gpus = sum(worker.gpu for worker in env.workers.values())
needs_gpus = any(task.definition.needs_gpu for task in job.tasks.values())
needs_gpus = any(task.definition.needs_gpu for task in job.jobInstance.tasks.values())
if needs_gpus and total_gpus == 0:
raise ValueError("environment contains no gpu yet job demands one")

Expand Down
5 changes: 4 additions & 1 deletion src/cascade/controller/notify.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from cascade.controller.core import State
from cascade.controller.report import Reporter
from cascade.executor.bridge import Event
from cascade.executor.msg import DatasetPublished, DatasetTransmitPayload
from cascade.executor.msg import DatasetPersistSuccess, DatasetPublished, DatasetTransmitPayload
from cascade.low.core import DatasetId, HostId, WorkerId
from cascade.low.execution_context import DatasetStatus, JobExecutionContext
from cascade.low.func import assert_never
Expand Down Expand Up @@ -89,6 +89,7 @@ def notify(
context.host2ds[host][event.ds] = DatasetStatus.available
context.ds2host[event.ds][host] = DatasetStatus.available
state.consider_fetch(event.ds, host)
state.consider_persist(event.ds, host)
consider_computable(schedule, state, context, event.ds, host)
if event.transmit_idx is not None:
mark(
Expand Down Expand Up @@ -121,5 +122,7 @@ def notify(
elif isinstance(event, DatasetTransmitPayload):
state.receive_payload(event)
reporter.send_result(event.header.ds, event.value)
elif isinstance(event, DatasetPersistSuccess):
state.acknowledge_persist(event)
else:
assert_never(event)
21 changes: 17 additions & 4 deletions src/cascade/executor/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from cascade.executor.executor import heartbeat_grace_ms as executor_heartbeat_grace_ms
from cascade.executor.msg import (
Ack,
DatasetPersistCommand,
DatasetPersistFailure,
DatasetPersistSuccess,
DatasetPublished,
DatasetPurge,
DatasetTransmitCommand,
Expand All @@ -29,14 +32,15 @@
TaskFailure,
TaskSequence,
)
from cascade.low.core import DatasetId, Environment, HostId, Worker, WorkerId
from cascade.low.core import CheckpointStorageType, DatasetId, Environment, HostId, Worker, WorkerId
from cascade.low.func import assert_never

logger = logging.getLogger(__name__)

Event = DatasetPublished | DatasetTransmitPayload
ToShutdown = TaskFailure | ExecutorFailure | DatasetTransmitFailure | ExecutorExit
Unsupported = TaskSequence | DatasetPurge | DatasetTransmitCommand | ExecutorShutdown
Event = DatasetPublished | DatasetTransmitPayload | DatasetPersistSuccess
# TODO consider retries here, esp on the PersistFailure
ToShutdown = TaskFailure | ExecutorFailure | DatasetTransmitFailure | DatasetPersistFailure | ExecutorExit
Unsupported = TaskSequence | DatasetPurge | DatasetTransmitCommand | DatasetPersistCommand | ExecutorShutdown


class Bridge:
Expand Down Expand Up @@ -158,6 +162,15 @@ def transmit(self, ds: DatasetId, source: HostId, target: HostId) -> None:
self.transmit_idx_counter += 1
self.sender.send("data." + source, m)

def persist(self, ds: DatasetId, source: HostId, storage_type: CheckpointStorageType, persist_params: str) -> None:
    """Instruct host `source` to persist dataset `ds` into the given checkpoint storage."""
    command = DatasetPersistCommand(
        ds=ds,
        source=source,
        storage_type=storage_type,
        persist_params=persist_params,
    )
    # routed over the same data channel as transmit/fetch commands
    self.sender.send("data." + source, command)

def fetch(self, ds: DatasetId, source: HostId) -> None:
m = DatasetTransmitCommand(
source=source,
Expand Down
42 changes: 42 additions & 0 deletions src/cascade/executor/checkpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# (C) Copyright 2025- ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""Handles the checkpoint management: storage, retrieval"""

import pathlib

from cascade.executor.msg import DatasetPersistCommand
from cascade.low.core import CheckpointSpec
from cascade.low.func import assert_never
from cascade.shm.client import AllocatedBuffer


def persist_dataset(command: DatasetPersistCommand, buf: AllocatedBuffer) -> None:
match command.storage_type:
case "fs":
root = pathlib.Path(command.persist_params)
root.mkdir(parents=True, exist_ok=True)
file = root / repr(command.ds)
# TODO what about overwrites / concurrent writes? Append uuid?
file.write_bytes(buf.view())
case s:
assert_never(s)

def serialize_persist_params(spec: CheckpointSpec) -> str:
# NOTE we call this every time we store, ideally call this once when building `low.execution_context`
match spec.storage_type:
case "fs":
if not isinstance(spec.storage_params, str):
raise TypeError(f"expected checkpoint storage params to be str, gotten {spec.storage_params.__class__}")
if spec.persist_id is None:
raise TypeError(f"serialize_persist_params called, but persist_id is None")
root = pathlib.Path(spec.storage_params)
return str(root / spec.persist_id)
case s:
assert_never(s)

Loading
Loading