Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 49 additions & 13 deletions src/jobflow/core/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import logging
import typing
import warnings
from dataclasses import dataclass, field

from monty.json import MSONable, jsanitize
Expand Down Expand Up @@ -317,8 +316,6 @@ def __init__(
):
from copy import deepcopy

from jobflow.utils.find import contains_flow_or_job

function_args = () if function_args is None else function_args
function_kwargs = {} if function_kwargs is None else function_kwargs
uuid = suuid() if uuid is None else uuid
Expand Down Expand Up @@ -351,16 +348,17 @@ def __init__(

self.output = OutputReference(self.uuid, output_schema=self.output_schema)

# check to see if job or flow is included in the job args
# this is a possible situation but likely a mistake
all_args = tuple(self.function_args) + tuple(self.function_kwargs.values())
if contains_flow_or_job(all_args):
warnings.warn(
f"Job '{self.name}' contains an Flow or Job as an input. "
f"Usually inputs should be the output of a Job or an Flow (e.g. "
f"job.output). If this message is unexpected then double check the "
f"inputs to your Job."
)
# check to see if job is included in the job args
self.function_args = tuple(
[
arg.output if isinstance(arg, Job) else arg
for arg in list(self.function_args)
]
)
self.function_kwargs = {
arg: v.output if isinstance(v, Job) else v
for arg, v in self.function_kwargs.items()
}

def __repr__(self):
"""Get a string representation of the job."""
Expand Down Expand Up @@ -405,6 +403,44 @@ def __hash__(self) -> int:
"""Get the hash of the job."""
return hash(self.uuid)

def __getitem__(self, key: Any) -> OutputReference:
"""
Get the corresponding `OutputReference` for the `Job`.

This is for when it is indexed like a dictionary or list.

Parameters
----------
key
The index/key.

Returns
-------
OutputReference
The equivalent of `Job.output[k]`
"""
return self.output[key]

def __getattr__(self, name: str) -> OutputReference:
"""
Get the corresponding `OutputReference` for the `Job`.

This is for when it is indexed like a class attribute.

Parameters
----------
name
The name of the attribute.

Returns
-------
OutputReference
The equivalent of `Job.output.name`
"""
if attr := getattr(self.output, name, None):
return attr
raise AttributeError(f"{type(self).__name__} has no attribute {name!r}")

@property
def input_references(self) -> tuple[jobflow.OutputReference, ...]:
"""
Expand Down
4 changes: 2 additions & 2 deletions src/jobflow/utils/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,11 @@ def contains_flow_or_job(obj: Any) -> bool:
from jobflow.core.job import Job

if isinstance(obj, (Flow, Job)):
# if the argument is an flow or job then stop there
# if the argument is a flow or job then stop there
return True

elif isinstance(obj, (float, int, str, bool)):
# argument is a primitive, we won't find an flow or job here
# argument is a primitive, we won't find a flow or job here
return False

obj = jsanitize(obj, strict=True, allow_bson=True)
Expand Down
92 changes: 68 additions & 24 deletions tests/core/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,27 +101,17 @@ def test_flow_of_jobs_init():
flow = Flow([add_job], output=add_job.output)
assert flow.output == add_job.output

# # test multi job and list multi outputs
# test multi job and list multi outputs
add_job1 = get_test_job()
add_job2 = get_test_job()
flow = Flow([add_job1, add_job2], output=[add_job1.output, add_job2.output])
assert flow.output[1] == add_job2.output

# # test all jobs included needed to generate outputs
# test all jobs included needed to generate outputs
add_job = get_test_job()
with pytest.raises(ValueError):
Flow([], output=add_job.output)

# test job given rather than outputs
add_job = get_test_job()
with pytest.warns(UserWarning):
Flow([add_job], output=add_job)

# test complex object containing job given rather than outputs
add_job = get_test_job()
with pytest.warns(UserWarning):
Flow([add_job], output={1: [[{"a": add_job}]]})

# test job already belongs to another flow
add_job = get_test_job()
Flow([add_job])
Expand Down Expand Up @@ -197,18 +187,6 @@ def test_flow_of_flows_init():
with pytest.raises(ValueError):
Flow([], output=subflow.output)

# test flow given rather than outputs
add_job = get_test_job()
subflow = Flow([add_job], output=add_job.output)
with pytest.warns(UserWarning):
Flow([subflow], output=subflow)

# test complex object containing job given rather than outputs
add_job = get_test_job()
subflow = Flow([add_job], output=add_job.output)
with pytest.warns(UserWarning):
Flow([subflow], output={1: [[{"a": subflow}]]})

# test flow already belongs to another flow
add_job = get_test_job()
subflow = Flow([add_job], output=add_job.output)
Expand Down Expand Up @@ -1012,3 +990,69 @@ def test_flow_repr():
assert len(lines) == len(flow_repr)
for expected, line in zip(lines, flow_repr):
assert line.startswith(expected), f"{line=} doesn't start with {expected=}"


def test_get_item():
from jobflow import Flow, job, run_locally

@job
def make_str(s):
return {"hello": s}

@job
def capitalize(s):
return s.upper()

job1 = make_str("world")
job2 = capitalize(job1["hello"])

flow = Flow([job1, job2])

responses = run_locally(flow, ensure_success=True)
assert responses[job2.uuid][1].output == "WORLD"


def test_get_item_job():
from jobflow import Flow, job, run_locally

@job
def make_str(s):
return s

@job
def capitalize(s):
return s.upper()

job1 = make_str("world")
job2 = capitalize(job1)

flow = Flow([job1, job2])

responses = run_locally(flow, ensure_success=True)
assert responses[job2.uuid][1].output == "WORLD"


def test_get_attr():
from dataclasses import dataclass

from jobflow import Flow, job, run_locally

@dataclass
class MyClass:
hello: str

@job
def make_str(s):
return MyClass(hello=s)

@job
def capitalize(s):
return s.upper()

job1 = make_str("world")
job2 = capitalize(job1.hello)

flow = Flow([job1, job2])

responses = run_locally(flow, ensure_success=True)
assert responses[job2.uuid][1].output == "WORLD"
15 changes: 11 additions & 4 deletions tests/core/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@ def test_job_init():
assert test_job.uuid is not None
assert test_job.output.uuid == test_job.uuid

# test job as another job as input
with pytest.warns(UserWarning):
Job(function=add, function_args=(test_job,))

# test init with kwargs
test_job = Job(function=add, function_args=(1,), function_kwargs={"b": 2})
assert test_job
Expand Down Expand Up @@ -1272,6 +1268,7 @@ def use_maker(maker):

def test_job_magic_methods():
from jobflow import Job
from jobflow.core.reference import OutputReference

# prepare test jobs
job1 = Job(function=sum, function_args=([1, 2],))
Expand All @@ -1296,3 +1293,13 @@ def test_job_magic_methods():

# test __hash__
assert hash(job1) != hash(job2) != hash(job3)

# test __getitem__
assert isinstance(job1["test"], OutputReference)
assert isinstance(job1[1], OutputReference)
assert job1["test"].attributes == (("i", "test"),)
assert job1[1].attributes == (("i", 1),)

# test __getattr__
assert isinstance(job1.test, OutputReference)
assert job1.test.attributes == (("a", "test"),)