
Commit 7f85ef4

This commit improves the synchronization of multiple recovery workflows
1 parent 9a93be6 commit 7f85ef4

9 files changed: +273 additions, -196 deletions


streamflow/core/recovery.py

Lines changed: 9 additions & 1 deletion
@@ -113,14 +113,22 @@ async def recover(self, failed_job: Job, failed_step: Step) -> None: ...
 
 
 class RetryRequest:
-    __slots__ = ("job_token", "lock", "output_tokens", "version", "workflow")
+    __slots__ = (
+        "job_token",
+        "lock",
+        "output_tokens",
+        "version",
+        "workflow",
+        "workflow_ready",
+    )
 
     def __init__(self) -> None:
         self.job_token: JobToken | None = None
         self.lock: asyncio.Lock = asyncio.Lock()
         self.output_tokens: MutableMapping[str, Token] = {}
         self.version: int = 1
         self.workflow: Workflow | None = None
+        self.workflow_ready: asyncio.Event = asyncio.Event()
 
 
 class TokenAvailability(IntEnum):
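
The new workflow_ready event is the core of the synchronization described in the commit message: the rollback that owns a retried job builds its recovery workflow first and then signals readiness, while rollbacks that merely depend on that job wait on the event before attaching to the workflow. A minimal sketch of that pattern, using simplified stand-in names rather than StreamFlow's real call sites:

import asyncio


class RetryRequestSketch:
    """Simplified stand-in for RetryRequest: one instance per job being retried."""

    def __init__(self) -> None:
        self.workflow = None
        self.workflow_ready = asyncio.Event()


async def owner_rollback(request: RetryRequestSketch) -> None:
    # The rollback that owns the retried job: build the recovery workflow,
    # then signal that it is safe for other rollbacks to attach to it.
    request.workflow_ready.clear()
    request.workflow = "recovery-workflow"  # placeholder for the rebuilt Workflow
    await asyncio.sleep(0.1)                # stands in for _populate_workflow(...)
    request.workflow_ready.set()


async def dependent_rollback(request: RetryRequestSketch) -> None:
    # A rollback that reuses the already-running job: wait until the owner's
    # workflow is ready before wiring inter-workflow ports to it.
    await request.workflow_ready.wait()
    print(f"attaching ports to {request.workflow}")


async def main() -> None:
    request = RetryRequestSketch()
    await asyncio.gather(owner_rollback(request), dependent_rollback(request))


asyncio.run(main())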

streamflow/core/utils.py

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ def create_command(
 
 
 def get_job_step_name(job_name: str) -> str:
-    return PurePosixPath(job_name).parent.name
+    return PurePosixPath(job_name).parent.as_posix()
 
 
 def get_job_tag(job_name: str) -> str:
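
The change from .name to .as_posix() makes get_job_step_name return the full step path instead of only its last component, which matters for nested steps. A quick illustration with a hypothetical job name:

from pathlib import PurePosixPath

job_name = "/compound/compute/0.1"  # hypothetical job name: <step path>/<tag>
print(PurePosixPath(job_name).parent.name)        # "compute"           (old behaviour)
print(PurePosixPath(job_name).parent.as_posix())  # "/compound/compute" (new behaviour)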

streamflow/persistence/sqlite.py

Lines changed: 3 additions & 1 deletion
@@ -470,7 +470,9 @@ async def get_token(self, token_id: int) -> MutableMapping[str, Any]:
                 "WHERE id =:id",
                 {"id": token_id},
             ) as cursor:
-                return _load_keys(dict(await cursor.fetchone()), keys=["value"])
+                row = _load_keys(dict(await cursor.fetchone()), keys=["value"])
+                row["recoverable"] = bool(row["recoverable"])
+                return row
 
     async def get_workflow(self, workflow_id: int) -> MutableMapping[str, Any]:
         async with self.connection as db:
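
SQLite has no native boolean type, so the recoverable column comes back from the driver as an integer; the added cast normalizes it to a Python bool before the row is returned. A small standalone illustration of the same coercion (plain sqlite3 here, not StreamFlow's aiosqlite-based connection):

import sqlite3

db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE token (id INTEGER PRIMARY KEY, recoverable BOOLEAN)")
db.execute("INSERT INTO token (recoverable) VALUES (1)")

row = dict(zip(("id", "recoverable"), db.execute("SELECT * FROM token").fetchone()))
print(type(row["recoverable"]))                # <class 'int'>: stored as 0/1
row["recoverable"] = bool(row["recoverable"])  # same normalization as get_token()
print(type(row["recoverable"]))                # <class 'bool'>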

streamflow/recovery/policy/recovery.py

Lines changed: 44 additions & 21 deletions
@@ -8,8 +8,8 @@
 
 from streamflow.core.exception import FailureHandlingException
 from streamflow.core.recovery import RecoveryPolicy
-from streamflow.core.utils import get_tag
-from streamflow.core.workflow import Job, Step, Token, Workflow
+from streamflow.core.utils import get_job_tag, get_tag
+from streamflow.core.workflow import Job, Status, Step, Token, Workflow
 from streamflow.log_handler import logger
 from streamflow.persistence.loading_context import WorkflowBuilder
 from streamflow.recovery.utils import (
@@ -93,7 +93,7 @@ async def _inject_tokens(mapper: GraphMapper, new_workflow: Workflow) -> None:
         ):
             if logger.isEnabledFor(logging.DEBUG):
                 logger.debug(f"Injecting termination token on port {port.name}")
-            port.put(TerminationToken())
+            port.put(TerminationToken(Status.SKIPPED))
 
 
 async def _populate_workflow(
@@ -111,14 +111,16 @@ async def _populate_workflow(
             for step_id in step_ids
         )
     )
-    # Add failed step into new_workflow
+    # Add the failed step to the new workflow
     await workflow_builder.load_step(
         new_workflow.context,
         failed_step.persistent_id,
     )
-    # Instantiate ports capable of moving tokens across workflows
+    # Instantiate ports that can transfer tokens between workflows
    for port in new_workflow.ports.values():
-        if not isinstance(port, ConnectorPort):
+        if not isinstance(
+            port, (ConnectorPort, InterWorkflowJobPort, InterWorkflowPort)
+        ):
             new_workflow.create_port(
                 (
                     InterWorkflowJobPort
@@ -129,7 +131,7 @@
         )
     for port in failed_step.get_output_ports().values():
         cast(InterWorkflowPort, new_workflow.ports[port.name]).add_inter_port(
-            port, border_tag=get_tag(failed_job.inputs.values())
+            port, boundary_tag=get_tag(failed_job.inputs.values()), terminate=False
         )
 
 
@@ -175,21 +177,23 @@ async def _recover_workflow(self, failed_job: Job, failed_step: Step) -> Workflow:
             ]
         )
         mapper = await create_graph_mapper(self.context, provenance)
-        # Synchronize across multiple recovery workflows
+        # Synchronize between multiple recovery workflows
         job_tokens = list(
             filter(lambda t: isinstance(t, JobToken), mapper.token_instances.values())
         )
-        await self._sync_workflows(
-            {*(t.value.name for t in job_tokens), failed_job.name},
-            job_tokens,
-            mapper,
-            new_workflow,
+        job_names = await self._sync_workflows(
+            job_names={*(t.value.name for t in job_tokens), failed_job.name},
+            job_tokens=job_tokens,
+            mapper=mapper,
+            workflow=new_workflow,
         )
         # Populate new workflow
         steps = await mapper.get_port_and_step_ids(failed_step.output_ports.values())
         await _populate_workflow(
             steps, failed_step, new_workflow, workflow_builder, failed_job
         )
+        for job_name in job_names:
+            self.context.failure_manager.get_request(job_name).workflow_ready.set()
         await _inject_tokens(mapper, new_workflow)
         await _set_step_states(mapper, new_workflow)
         return new_workflow
@@ -200,7 +204,8 @@ async def _sync_workflows(
         job_tokens: MutableSequence[Token],
         mapper: GraphMapper,
         workflow: Workflow,
-    ) -> None:
+    ) -> MutableSequence[str]:
+        new_job_names = []
         for job_name in job_names:
             retry_request = self.context.failure_manager.get_request(job_name)
             if (
@@ -209,20 +214,35 @@
                 )
             ) == TokenAvailability.FutureAvailable:
                 job_token = get_job_token(job_name, job_tokens)
-                # The `retry_request` is the current job running, instead
-                # the `job_token` is the token to remove in the graph because
-                # the workflow will depend on the already running job
+                # `retry_request` represents the currently running job.
+                # `job_token` refers to the token that needs to be removed from the graph,
+                # as the workflow depends on the already running job.
                 if logger.isEnabledFor(logging.DEBUG):
-                    logger.debug(f"Synchronize rollbacks: job {job_name} is running")
-                # todo: create a unit test for this case
+                    if not (is_wf_ready := retry_request.workflow_ready.is_set()):
+                        logger.debug(
+                            f"Synchronizing rollbacks: Job {job_name} is waiting for the rollback workflow to be ready."
+                        )
+                    else:
+                        logger.debug(
+                            f"Synchronizing rollbacks: Job {job_name} is currently executing."
+                        )
+                else:
+                    is_wf_ready = True
+                await retry_request.workflow_ready.wait()
+                if logger.isEnabledFor(logging.DEBUG) and not is_wf_ready:
+                    logger.debug(
+                        f"Synchronizing rollbacks: Job {job_name} has resumed after the rollback workflow is ready."
+                    )
                 for port_name in await mapper.get_output_ports(job_token):
                     if port_name in retry_request.workflow.ports.keys():
                         cast(
                             InterWorkflowPort, retry_request.workflow.ports[port_name]
                         ).add_inter_port(
-                            workflow.create_port(cls=InterWorkflowPort, name=port_name)
+                            workflow.create_port(cls=InterWorkflowPort, name=port_name),
+                            boundary_tag=get_job_tag(job_token.value.name),
+                            terminate=True,
                         )
-                # Remove tokens recovered in other workflows
+                # Remove tokens that will be recovered in other workflows
                 for token_id in await mapper.get_output_tokens(job_token.persistent_id):
                     mapper.remove_token(token_id, preserve_token=True)
             elif is_available == TokenAvailability.Available:
@@ -247,6 +267,9 @@ async def _sync_workflows(
             else:
                 await self.context.failure_manager.update_request(job_name)
                 retry_request.workflow = workflow
+                retry_request.workflow_ready.clear()
+                new_job_names.append(job_name)
+        return new_job_names
 
     async def recover(self, failed_job: Job, failed_step: Step) -> None:
         # Create recover workflow
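
Note the ordering in _recover_workflow: _sync_workflows clears workflow_ready on every request it takes ownership of and returns their job names, and the events are set only after _populate_workflow has finished, so a concurrent rollback can never attach inter-workflow ports to a half-built workflow. A condensed sketch of that ordering (hypothetical names, not the actual method bodies):

import asyncio


async def recover_workflow_ordering(requests: dict[str, asyncio.Event]) -> None:
    owned = []
    for job_name, ready in requests.items():  # stands in for _sync_workflows(...)
        ready.clear()                         # this rollback now owns the request
        owned.append(job_name)
    await asyncio.sleep(0.1)                  # stands in for _populate_workflow(...)
    for job_name in owned:                    # wake dependent rollbacks only once the
        requests[job_name].set()              # new workflow is fully populated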

streamflow/recovery/utils.py

Lines changed: 29 additions & 17 deletions
@@ -68,7 +68,7 @@ def prev(self, vertex: Any) -> MutableSet[Any]:
     def remove(self, vertex: Any) -> MutableSequence[Any]:
         self.graph.pop(vertex, None)
         removed = [vertex]
-        # Delete nodes which are not connected to the leaves nodes
+        # Delete the nodes that are not connected to any leaf nodes
         dead_end_nodes = set()
         for node, values in self.graph.items():
             if vertex in values:
@@ -78,7 +78,7 @@ def remove(self, vertex: Any) -> MutableSequence[Any]:
         for node in dead_end_nodes:
             removed.extend(self.remove(node))
 
-        # Assign the root node to vertices without parent
+        # Assign the root node to the vertices that do not have a parent
         orphan_nodes = set()
         for node in self.keys():
             if node != DirectGraph.ROOT and not self.prev(node):
@@ -266,8 +266,8 @@ async def get_port_and_step_ids(
             )
             for dependency_row in dependency_rows
         }
-        # Remove steps with some missing input ports
-        # A port can have multiple input steps. It is necessary to load only the needed steps
+        # Remove steps with missing input ports
+        # A port may have multiple input steps, so it is important to load only the necessary steps.
         step_to_remove = set()
         for step_id, dependency_rows in zip(
             step_ids,
@@ -306,7 +306,7 @@ def remove_port(self, port_name: str) -> None:
         for token_id in orphan_tokens:
             self.remove_token(token_id)
 
-    def remove_token(self, token_id: int, preserve_token: bool = True):
+    def remove_token(self, token_id: int, preserve_token: bool = True) -> None:
         if logger.isEnabledFor(logging.INFO):
             logger.info(f"Remove token id {token_id}")
         if token_id == DirectGraph.ROOT:
@@ -325,13 +325,13 @@ def remove_token(self, token_id: int, preserve_token: bool = True):
                 token_leaves.add(prev_token_id)
         # Delete end-road branches
         for leaf_id in token_leaves:
-            self.remove_token(leaf_id)
+            self.remove_token(leaf_id, preserve_token=False)
         # Delete token (if needed)
         if not preserve_token:
             self.token_available.pop(token_id, None)
             self.token_instances.pop(token_id, None)
             self.dag_tokens.remove(token_id)
-        if not preserve_token:
+            # Remove ports
             empty_ports = set()
             for port_name, token_list in self.port_tokens.items():
                 if token_id in token_list:
@@ -342,10 +342,22 @@ def remove_token(self, token_id: int, preserve_token: bool = True):
                 self.remove_port(port_name)
 
     def replace_token(self, port_name: str, token: Token, is_available: bool) -> None:
-        old_token_id = self.get_equal_token(port_name, token)
-        if old_token_id is None:
-            raise FailureHandlingException("Impossible replace token")
-        if logger.isEnabledFor(logging.INFO):
+        if (old_token_id := self.get_equal_token(port_name, token)) is None:
+            raise FailureHandlingException(
+                f"Unable to find a token for replacement with {token.persistent_id}."
+            )
+        if old_token_id == token.persistent_id:
+            if self.token_available[old_token_id] != is_available:
+                raise FailureHandlingException(
+                    f"Availability mismatch for token {old_token_id}. "
+                    f"Expected: {self.token_available[old_token_id]}, Got: {is_available}."
+                )
+            elif logger.isEnabledFor(logging.INFO):
+                logger.info(
+                    f"Token {old_token_id} is already in desired state. Skipping replacement."
+                )
+            return
+        elif logger.isEnabledFor(logging.INFO):
             logger.info(f"Replacing {old_token_id} with {token.persistent_id}")
         # Replace
         self.dag_tokens.replace(old_token_id, token.persistent_id)
@@ -373,13 +385,13 @@ def add(self, src_token: Token | None, dst_token: Token | None) -> None:
             dst_token.persistent_id if dst_token is not None else dst_token,
         )
 
-    async def build_graph(self, inputs: Iterable[Token]):
+    async def build_graph(self, inputs: Iterable[Token]) -> None:
         """
-        The provenance graph represent the execution, and is always a DAG.
-        Visit the provenance graph with a breadth-first search and is done
-        backward starting from the input tokens. At the end of the search,
-        we have a tree where at root there are token which data are available
-        in some location and leaves will be the input tokens.
+        The provenance graph represents the execution and is always a DAG.
+        To traverse the provenance graph, a breadth-first search is performed
+        starting from the input tokens and moving backward. At the end of the search,
+        we obtain a tree where the root node represents the tokens whose data are available
+        in a specific location, and the leaves correspond to the input tokens.
         """
         token_frontier = deque(inputs)
         loading_context = DefaultDatabaseLoadingContext()
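
The rewritten build_graph docstring describes a backward breadth-first visit of the provenance DAG, starting from the input tokens and following producer links. A toy, self-contained sketch of that traversal over a made-up provenance mapping (not StreamFlow's actual data structures):

from collections import deque

# Toy provenance: token id -> ids of the tokens that produced it
provenance = {5: [3, 4], 4: [2], 3: [1], 2: [], 1: []}


def backward_bfs(inputs: list[int]) -> list[int]:
    frontier, visited = deque(inputs), set(inputs)
    order = []
    while frontier:
        token_id = frontier.popleft()
        order.append(token_id)
        for prev_id in provenance.get(token_id, []):
            if prev_id not in visited:  # each provenance token is visited once
                visited.add(prev_id)
                frontier.append(prev_id)
    return order


print(backward_bfs([5]))  # [5, 3, 4, 1, 2]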

streamflow/workflow/port.py

Lines changed: 13 additions & 6 deletions
@@ -5,7 +5,7 @@
 from typing import Callable
 
 from streamflow.core.deployment import Connector
-from streamflow.core.workflow import Job, Port, Token, Workflow
+from streamflow.core.workflow import Job, Port, Status, Token, Workflow
 from streamflow.log_handler import logger
 from streamflow.workflow.token import TerminationToken
 
@@ -53,16 +53,23 @@ def put(self, token: Token) -> None:
 class InterWorkflowPort(Port):
     def __init__(self, workflow: Workflow, name: str):
         super().__init__(workflow, name)
-        self.inter_ports: MutableSequence[tuple[Port, str | None]] = []
+        self.inter_ports: MutableSequence[tuple[Port, str, bool]] = []
 
-    def add_inter_port(self, port: Port, border_tag: str | None = None) -> None:
-        self.inter_ports.append((port, border_tag))
+    def add_inter_port(self, port: Port, boundary_tag: str, terminate: bool) -> None:
+        self.inter_ports.append((port, boundary_tag, terminate))
+        for token in self.token_list:
+            if boundary_tag == token.tag:
+                port.put(token)
+                if terminate:
+                    port.put(TerminationToken(Status.SKIPPED))
 
     def put(self, token: Token) -> None:
         if not isinstance(token, TerminationToken):
-            for port, border_tag in self.inter_ports:
-                if border_tag is None or border_tag == token.tag:
+            for port, boundary_tag, terminate in self.inter_ports:
+                if boundary_tag == token.tag:
                     port.put(token)
+                    if terminate:
+                        port.put(TerminationToken(Status.SKIPPED))
         super().put(token)
 
 
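add_inter_port now replays tokens that were already buffered on the port before the inter-port was attached, and can append a termination token for the matching tag, so a rollback workflow that attaches late does not miss data that arrived earlier. A hypothetical mini-model of that behaviour (plain strings instead of real Token objects):

class MiniPort:
    """Toy stand-in for InterWorkflowPort, tracking (tag, value) pairs."""

    def __init__(self) -> None:
        self.token_list: list[tuple[str, str]] = []
        self.inter_ports: list[tuple["MiniPort", str, bool]] = []

    def add_inter_port(self, port: "MiniPort", boundary_tag: str, terminate: bool) -> None:
        self.inter_ports.append((port, boundary_tag, terminate))
        for tag, value in self.token_list:  # replay tokens received before attachment
            if tag == boundary_tag:
                port.put(tag, value)
                if terminate:
                    port.put(tag, "TERMINATION")  # stands in for TerminationToken

    def put(self, tag: str, value: str) -> None:
        for port, boundary_tag, terminate in self.inter_ports:
            if tag == boundary_tag:
                port.put(tag, value)
                if terminate:
                    port.put(tag, "TERMINATION")
        self.token_list.append((tag, value))


source, target = MiniPort(), MiniPort()
source.put("0.1", "early result")  # arrives before the rollback workflow attaches
source.add_inter_port(target, boundary_tag="0.1", terminate=True)
print(target.token_list)  # [('0.1', 'early result'), ('0.1', 'TERMINATION')]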

streamflow/workflow/step.py

Lines changed: 7 additions & 3 deletions
@@ -1044,7 +1044,9 @@ async def run(self) -> None:
                         logger.debug(f"Step {self.name} forces gather on key {key}")
 
                     # Update size_map with the current size
-                    self.size_map[key] = Token(value=len(self.token_map[key]), tag=key)
+                    self.size_map[key] = Token(
+                        value=len(self.token_map[key]), tag=key, recoverable=True
+                    )
                     await self.size_map[key].save(
                         self.workflow.context, size_port.persistent_id
                     )
@@ -1218,7 +1220,9 @@
                     )
 
                     async for schema in self.combinator.combine(task_name, token):
-                        ins = [id for t in schema.values() for id in t["input_ids"]]
+                        ins = [
+                            id_ for t in schema.values() for id_ in t["input_ids"]
+                        ]
                         for port_name, token in schema.items():
                             self.get_output_port(port_name).put(
                                 await self._persist_token(
@@ -1652,7 +1656,7 @@ async def _save_additional_params(
             "size_port": self.get_size_port().persistent_id
         }
 
-    async def _scatter(self, token: Token) -> Token:
+    async def _scatter(self, token: Token) -> None:
         if isinstance(token, ListToken):
             output_port = self.get_output_port()
             for i, t in enumerate(token.value):
