88from typing import Any , Iterator
99
1010import cascade .executor .serde as serde
11- from cascade .executor .msg import DatasetTransmitPayload
11+ from cascade .executor .msg import DatasetPersistSuccess , DatasetTransmitPayload
1212from cascade .low .core import DatasetId , HostId , TaskId
1313
1414logger = logging .getLogger (__name__ )
1515
1616
1717@dataclass
1818class State :
19- # key add by core.initialize , value add by notify.notify
19+ # key add by core.init_state , value add by notify.notify
2020 outputs : dict [DatasetId , Any ]
21+ # key add by core.init_state, value add by notify.notify
22+ to_persist : set [DatasetId ]
2123 # add by notify.notify, remove by act.flush_queues
2224 fetching_queue : dict [DatasetId , HostId ]
25+ # add by notify.notify, remove by act.flush_queues
26+ persist_queue : dict [DatasetId , HostId ]
2327 # add by notify.notify, removed by act.flush_queues
2428 purging_queue : list [DatasetId ]
2529 # add by core.init_state, remove by notify.notify
@@ -31,13 +35,16 @@ def has_awaitable(self) -> bool:
3135 for e in self .outputs .values ():
3236 if e is None :
3337 return True
38+ if self .to_persist :
39+ return True
3440 return False
3541
3642 def _consider_purge (self , dataset : DatasetId ) -> None :
3743 """If dataset not required anymore, add to purging_queue"""
3844 no_dependants = not self .purging_tracker .get (dataset , None )
3945 not_required_output = self .outputs .get (dataset , 1 ) is not None
40- if no_dependants and not_required_output :
46+ not_required_persist = not dataset in self .to_persist
47+ if all ((no_dependants , not_required_output , not_required_persist )):
4148 logger .debug (f"adding { dataset = } to purging queue" )
4249 if dataset in self .purging_tracker :
4350 self .purging_tracker .pop (dataset )
@@ -52,6 +59,14 @@ def consider_fetch(self, dataset: DatasetId, at: HostId) -> None:
5259 ):
5360 self .fetching_queue [dataset ] = at
5461
62+ def consider_persist (self , dataset : DatasetId , at : HostId ) -> None :
63+ """If required as persist and not yet acknowledged, add to persist queue"""
64+ if (
65+ dataset in self .to_persist
66+ and dataset not in self .persist_queue
67+ ):
68+ self .persist_queue [dataset ] = at
69+
5570 def receive_payload (self , payload : DatasetTransmitPayload ) -> None :
5671 """Stores deserialized value into outputs, considers purge"""
5772 # NOTE ifneedbe get annotation from job.tasks[event.ds.task].definition.output_schema[event.ds.output]
@@ -60,6 +75,11 @@ def receive_payload(self, payload: DatasetTransmitPayload) -> None:
6075 )
6176 self ._consider_purge (payload .header .ds )
6277
78+ def acknowledge_persist (self , payload : DatasetPersistSuccess ) -> None :
79+ """Marks acknowledged, considers purge"""
80+ self .to_persist .discard (payload .ds )
81+ self ._consider_purge (payload .ds )
82+
6383 def task_done (self , task : TaskId , inputs : set [DatasetId ]) -> None :
6484 """Marks that the inputs are not needed for this task anymore, considers purge of each"""
6585 for sourceDataset in inputs :
@@ -76,15 +96,22 @@ def drain_fetching_queue(self) -> Iterator[tuple[DatasetId, HostId]]:
7696 yield dataset , host
7797 self .fetching_queue = {}
7898
99+ def drain_persist_queue (self ) -> Iterator [tuple [DatasetId , HostId ]]:
100+ for dataset , host in self .persist_queue .items ():
101+ yield dataset , host
102+ self .persist_queue = {}
103+
79104
def init_state(outputs: set[DatasetId], to_persist: set[DatasetId], edge_o: dict[DatasetId, set[TaskId]]) -> State:
    """Build a fresh State for a job.

    Args:
        outputs: datasets whose values must be collected; stored as keys
            mapping to None until notify fills them in.
        to_persist: datasets that must be persisted before they may be purged.
        edge_o: dataset -> set of tasks that consume it, used to seed the
            purging tracker.

    Returns:
        A State with empty fetching/persist/purging queues.
    """
    # Copy each dependant set so later mutation of the tracker does not
    # alias the caller's edge_o sets. (set(...) replaces the redundant
    # identity comprehension {task for task in dependants}.)
    purging_tracker = {ds: set(dependants) for ds, dependants in edge_o.items()}

    return State(
        outputs={e: None for e in outputs},
        to_persist=set(to_persist),
        fetching_queue={},
        persist_queue={},
        purging_queue=[],
        purging_tracker=purging_tracker,
    )
0 commit comments