|
4 | 4 |
|
5 | 5 | import logging |
6 | 6 | import warnings |
| 7 | +from contextlib import contextmanager |
| 8 | +from contextvars import ContextVar |
7 | 9 | from copy import deepcopy |
8 | 10 | from typing import TYPE_CHECKING |
9 | 11 |
|
@@ -155,6 +157,12 @@ def __init__( |
155 | 157 | self.add_jobs(jobs) |
156 | 158 | self.output = output |
157 | 159 |
|
| 160 | + # If we're running inside a `DecoratedFlow`, add *this* Flow to the |
| 161 | + # context. |
| 162 | + current_flow_children_list = _current_flow_context.get() |
| 163 | + if current_flow_children_list is not None: |
| 164 | + current_flow_children_list.append(self) |
| 165 | + |
158 | 166 | def __len__(self) -> int: |
159 | 167 | """Get the number of jobs or subflows in the flow.""" |
160 | 168 | return len(self.jobs) |
@@ -828,7 +836,7 @@ def add_jobs(self, jobs: Job | Flow | Sequence[Flow | Job]) -> None: |
828 | 836 | if job.host is not None and job.host != self.uuid: |
829 | 837 | raise ValueError( |
830 | 838 | f"{type(job).__name__} {job.name} ({job.uuid}) already belongs " |
831 | | - f"to another flow." |
| 839 | + f"to another flow: {job.host}." |
832 | 840 | ) |
833 | 841 | if job.uuid in job_ids: |
834 | 842 | raise ValueError( |
@@ -921,3 +929,104 @@ def get_flow( |
921 | 929 | ) |
922 | 930 |
|
923 | 931 | return flow |
| 932 | + |
| 933 | + |
| 934 | +class DecoratedFlow(Flow): |
| 935 | + """A DecoratedFlow is a Flow that is returned on using the @flow decorator.""" |
| 936 | + |
| 937 | + def __init__(self, fn, *args, **kwargs): |
| 938 | + from jobflow import Maker |
| 939 | + |
| 940 | + self.fn = fn |
| 941 | + self.args = args |
| 942 | + self.kwargs = kwargs |
| 943 | + |
| 944 | + # Collect the jobs and flows that are used in the function |
| 945 | + children_list = [] |
| 946 | + with flow_build_context(children_list): |
| 947 | + output = self.fn(*self.args, **self.kwargs) |
| 948 | + |
| 949 | + # From the collected items, remove those that have already been assigned |
| 950 | + # to another Flow during the call of the function. |
| 951 | + # This handles the case where a Flow object is instantiated inside |
| 952 | + # the decorated function |
| 953 | + children_list = [c for c in children_list if c.host is None] |
| 954 | + |
| 955 | + name = getattr(self.fn, "__qualname__", self.fn.__name__) |
| 956 | + |
| 957 | + # if decorates a make() in a Maker use that as a name |
| 958 | + if ( |
| 959 | + len(self.args) > 0 |
| 960 | + and name.split(".")[-1] == "make" |
| 961 | + and getattr(args[0], self.fn.__name__, None) |
| 962 | + and isinstance(args[0], Maker) |
| 963 | + ): |
| 964 | + name = args[0].name |
| 965 | + |
| 966 | + if isinstance(output, (jobflow.Job, jobflow.Flow)): |
| 967 | + warnings.warn( |
| 968 | + f"@flow decorated function '{name}' contains a Flow or" |
| 969 | + f"Job as an output. Usually the output should be the output of" |
| 970 | + f"a Job or another Flow (e.g. job.output). Replacing the" |
| 971 | + f"output of the @flow with the output of the Flow/Job." |
| 972 | + f"If this message is unexpected then double check the outputs" |
| 973 | + f"of your @flow decorated function.", |
| 974 | + stacklevel=2, |
| 975 | + ) |
| 976 | + output = output.output |
| 977 | + |
| 978 | + super().__init__(name=name, jobs=children_list, output=output) |
| 979 | + |
| 980 | + |
| 981 | +def flow(fn): |
| 982 | + """ |
| 983 | + Turn a function into a DecoratedFlow object. |
| 984 | +
|
| 985 | + Parameters |
| 986 | + ---------- |
| 987 | + fn (Callable): The function to be wrapped in a DecoratedFlow object. |
| 988 | +
|
| 989 | + Returns |
| 990 | + ------- |
| 991 | + Callable: A wrapper function that, when called, creates and returns |
| 992 | + an instance of DecoratedFlow initialized with the provided function |
| 993 | + and its arguments. |
| 994 | + """ |
| 995 | + from functools import wraps |
| 996 | + |
| 997 | + @wraps(fn) |
| 998 | + def wrapper(*args, **kwargs): |
| 999 | + return DecoratedFlow(fn, *args, **kwargs) |
| 1000 | + |
| 1001 | + return wrapper |
| 1002 | + |
| 1003 | + |
| 1004 | +@contextmanager |
| 1005 | +def flow_build_context(children_list): |
| 1006 | + """Provide a context manager for flows. |
| 1007 | +
|
| 1008 | + Provides a context manager for setting and resetting the `Job` and `Flow` |
| 1009 | + objects in the current flow context. |
| 1010 | +
|
| 1011 | + Parameters |
| 1012 | + ---------- |
| 1013 | + children_list: The `Job` or `Flow` objects that are part of the current |
| 1014 | + flow context. |
| 1015 | +
|
| 1016 | + Yields |
| 1017 | + ------ |
| 1018 | + None: Temporarily sets the provided `Job` or `Flow` objects as |
| 1019 | + belonging to the current flow context within the managed block. |
| 1020 | +
|
| 1021 | + Raises |
| 1022 | + ------ |
| 1023 | + None |
| 1024 | + """ |
| 1025 | + token = _current_flow_context.set(children_list) |
| 1026 | + try: |
| 1027 | + yield |
| 1028 | + finally: |
| 1029 | + _current_flow_context.reset(token) |
| 1030 | + |
| 1031 | + |
| 1032 | +_current_flow_context = ContextVar("current_flow_context", default=None) |
0 commit comments