Skip to content

Commit 3761a0d

Browse files
authored
Merge pull request #815 from vineetbansal/vb/flow_decorator
A flow decorator
2 parents d97a68a + c5c0b2e commit 3761a0d

File tree

5 files changed

+410
-2
lines changed

5 files changed

+410
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ keywords = ["high-throughput", "workflow"]
1010
license = { text = "modified BSD" }
1111
authors = [{ name = "Alex Ganose", email = "a.ganose@imperial.ac.uk" }]
1212
dynamic = ["version"]
13+
1314
classifiers = [
1415
"Development Status :: 5 - Production/Stable",
1516
"Intended Audience :: Information Technology",

src/jobflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Jobflow is a package for writing dynamic and connected workflows."""
22

33
from jobflow._version import __version__
4-
from jobflow.core.flow import Flow, JobOrder
4+
from jobflow.core.flow import Flow, JobOrder, flow
55
from jobflow.core.job import Job, JobConfig, Response, job
66
from jobflow.core.maker import Maker
77
from jobflow.core.reference import OnMissing, OutputReference

src/jobflow/core/flow.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import logging
66
import warnings
7+
from contextlib import contextmanager
8+
from contextvars import ContextVar
79
from copy import deepcopy
810
from typing import TYPE_CHECKING
911

@@ -155,6 +157,12 @@ def __init__(
155157
self.add_jobs(jobs)
156158
self.output = output
157159

160+
# If we're running inside a `DecoratedFlow`, add *this* Flow to the
161+
# context.
162+
current_flow_children_list = _current_flow_context.get()
163+
if current_flow_children_list is not None:
164+
current_flow_children_list.append(self)
165+
158166
def __len__(self) -> int:
159167
"""Get the number of jobs or subflows in the flow."""
160168
return len(self.jobs)
@@ -828,7 +836,7 @@ def add_jobs(self, jobs: Job | Flow | Sequence[Flow | Job]) -> None:
828836
if job.host is not None and job.host != self.uuid:
829837
raise ValueError(
830838
f"{type(job).__name__} {job.name} ({job.uuid}) already belongs "
831-
f"to another flow."
839+
f"to another flow: {job.host}."
832840
)
833841
if job.uuid in job_ids:
834842
raise ValueError(
@@ -921,3 +929,104 @@ def get_flow(
921929
)
922930

923931
return flow
932+
933+
934+
class DecoratedFlow(Flow):
935+
"""A DecoratedFlow is a Flow that is returned on using the @flow decorator."""
936+
937+
def __init__(self, fn, *args, **kwargs):
938+
from jobflow import Maker
939+
940+
self.fn = fn
941+
self.args = args
942+
self.kwargs = kwargs
943+
944+
# Collect the jobs and flows that are used in the function
945+
children_list = []
946+
with flow_build_context(children_list):
947+
output = self.fn(*self.args, **self.kwargs)
948+
949+
# From the collected items, remove those that have already been assigned
950+
# to another Flow during the call of the function.
951+
# This handles the case where a Flow object is instantiated inside
952+
# the decorated function
953+
children_list = [c for c in children_list if c.host is None]
954+
955+
name = getattr(self.fn, "__qualname__", self.fn.__name__)
956+
957+
# if decorates a make() in a Maker use that as a name
958+
if (
959+
len(self.args) > 0
960+
and name.split(".")[-1] == "make"
961+
and getattr(args[0], self.fn.__name__, None)
962+
and isinstance(args[0], Maker)
963+
):
964+
name = args[0].name
965+
966+
if isinstance(output, (jobflow.Job, jobflow.Flow)):
967+
warnings.warn(
968+
f"@flow decorated function '{name}' contains a Flow or"
969+
f"Job as an output. Usually the output should be the output of"
970+
f"a Job or another Flow (e.g. job.output). Replacing the"
971+
f"output of the @flow with the output of the Flow/Job."
972+
f"If this message is unexpected then double check the outputs"
973+
f"of your @flow decorated function.",
974+
stacklevel=2,
975+
)
976+
output = output.output
977+
978+
super().__init__(name=name, jobs=children_list, output=output)
979+
980+
981+
def flow(fn):
982+
"""
983+
Turn a function into a DecoratedFlow object.
984+
985+
Parameters
986+
----------
987+
fn (Callable): The function to be wrapped in a DecoratedFlow object.
988+
989+
Returns
990+
-------
991+
Callable: A wrapper function that, when called, creates and returns
992+
an instance of DecoratedFlow initialized with the provided function
993+
and its arguments.
994+
"""
995+
from functools import wraps
996+
997+
@wraps(fn)
998+
def wrapper(*args, **kwargs):
999+
return DecoratedFlow(fn, *args, **kwargs)
1000+
1001+
return wrapper
1002+
1003+
1004+
@contextmanager
1005+
def flow_build_context(children_list):
1006+
"""Provide a context manager for flows.
1007+
1008+
Provides a context manager for setting and resetting the `Job` and `Flow`
1009+
objects in the current flow context.
1010+
1011+
Parameters
1012+
----------
1013+
children_list: The `Job` or `Flow` objects that are part of the current
1014+
flow context.
1015+
1016+
Yields
1017+
------
1018+
None: Temporarily sets the provided `Job` or `Flow` objects as
1019+
belonging to the current flow context within the managed block.
1020+
1021+
Raises
1022+
------
1023+
None
1024+
"""
1025+
token = _current_flow_context.set(children_list)
1026+
try:
1027+
yield
1028+
finally:
1029+
_current_flow_context.reset(token)
1030+
1031+
1032+
_current_flow_context = ContextVar("current_flow_context", default=None)

src/jobflow/core/job.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from monty.json import MSONable, jsanitize
1212
from typing_extensions import Self
1313

14+
from jobflow.core.flow import _current_flow_context
1415
from jobflow.core.reference import OnMissing, OutputReference
1516
from jobflow.utils.uid import suid
1617

@@ -384,6 +385,12 @@ def __init__(
384385
stacklevel=2,
385386
)
386387

388+
# If we're running inside a `DecoratedFlow`, add *this* Job to the
389+
# context.
390+
current_flow_children_list = _current_flow_context.get()
391+
if current_flow_children_list is not None:
392+
current_flow_children_list.append(self)
393+
387394
def __repr__(self):
388395
"""Get a string representation of the job."""
389396
name, uuid = self.name, self.uuid

0 commit comments

Comments
 (0)