diff --git a/src/jobflow/core/job.py b/src/jobflow/core/job.py index 3af7859c..437d06ca 100644 --- a/src/jobflow/core/job.py +++ b/src/jobflow/core/job.py @@ -4,7 +4,6 @@ import logging import typing -import warnings from dataclasses import dataclass, field from monty.json import MSONable, jsanitize @@ -317,8 +316,6 @@ def __init__( ): from copy import deepcopy - from jobflow.utils.find import contains_flow_or_job - function_args = () if function_args is None else function_args function_kwargs = {} if function_kwargs is None else function_kwargs uuid = suuid() if uuid is None else uuid @@ -351,16 +348,17 @@ def __init__( self.output = OutputReference(self.uuid, output_schema=self.output_schema) - # check to see if job or flow is included in the job args - # this is a possible situation but likely a mistake - all_args = tuple(self.function_args) + tuple(self.function_kwargs.values()) - if contains_flow_or_job(all_args): - warnings.warn( - f"Job '{self.name}' contains an Flow or Job as an input. " - f"Usually inputs should be the output of a Job or an Flow (e.g. " - f"job.output). If this message is unexpected then double check the " - f"inputs to your Job." - ) + # check to see if job is included in the job args + self.function_args = tuple( + [ + arg.output if isinstance(arg, Job) else arg + for arg in list(self.function_args) + ] + ) + self.function_kwargs = { + arg: v.output if isinstance(v, Job) else v + for arg, v in self.function_kwargs.items() + } def __repr__(self): """Get a string representation of the job.""" @@ -405,6 +403,44 @@ def __hash__(self) -> int: """Get the hash of the job.""" return hash(self.uuid) + def __getitem__(self, key: Any) -> OutputReference: + """ + Get the corresponding `OutputReference` for the `Job`. + + This is for when it is indexed like a dictionary or list. + + Parameters + ---------- + key + The index/key. + + Returns + ------- + OutputReference + The equivalent of `Job.output[k]` + """ + return self.output[key] + + def __getattr__(self, name: str) -> OutputReference: + """ + Get the corresponding `OutputReference` for the `Job`. + + This is for when it is indexed like a class attribute. + + Parameters + ---------- + name + The name of the attribute. + + Returns + ------- + OutputReference + The equivalent of `Job.output.name` + """ + if attr := getattr(self.output, name, None): + return attr + raise AttributeError(f"{type(self).__name__} has no attribute {name!r}") + @property def input_references(self) -> tuple[jobflow.OutputReference, ...]: """ diff --git a/src/jobflow/utils/find.py b/src/jobflow/utils/find.py index cb1a2b82..31ca4a5e 100644 --- a/src/jobflow/utils/find.py +++ b/src/jobflow/utils/find.py @@ -199,11 +199,11 @@ def contains_flow_or_job(obj: Any) -> bool: from jobflow.core.job import Job if isinstance(obj, (Flow, Job)): - # if the argument is an flow or job then stop there + # if the argument is a flow or job then stop there return True elif isinstance(obj, (float, int, str, bool)): - # argument is a primitive, we won't find an flow or job here + # argument is a primitive, we won't find a flow or job here return False obj = jsanitize(obj, strict=True, allow_bson=True) diff --git a/tests/core/test_flow.py b/tests/core/test_flow.py index d34fc2ae..4899137f 100644 --- a/tests/core/test_flow.py +++ b/tests/core/test_flow.py @@ -101,27 +101,17 @@ def test_flow_of_jobs_init(): flow = Flow([add_job], output=add_job.output) assert flow.output == add_job.output - # # test multi job and list multi outputs + # test multi job and list multi outputs add_job1 = get_test_job() add_job2 = get_test_job() flow = Flow([add_job1, add_job2], output=[add_job1.output, add_job2.output]) assert flow.output[1] == add_job2.output - # # test all jobs included needed to generate outputs + # test all jobs included needed to generate outputs add_job = get_test_job() with pytest.raises(ValueError): Flow([], output=add_job.output) - # test job given rather than outputs - add_job = get_test_job() - with pytest.warns(UserWarning): - Flow([add_job], output=add_job) - - # test complex object containing job given rather than outputs - add_job = get_test_job() - with pytest.warns(UserWarning): - Flow([add_job], output={1: [[{"a": add_job}]]}) - # test job already belongs to another flow add_job = get_test_job() Flow([add_job]) @@ -197,18 +187,6 @@ def test_flow_of_flows_init(): with pytest.raises(ValueError): Flow([], output=subflow.output) - # test flow given rather than outputs - add_job = get_test_job() - subflow = Flow([add_job], output=add_job.output) - with pytest.warns(UserWarning): - Flow([subflow], output=subflow) - - # test complex object containing job given rather than outputs - add_job = get_test_job() - subflow = Flow([add_job], output=add_job.output) - with pytest.warns(UserWarning): - Flow([subflow], output={1: [[{"a": subflow}]]}) - # test flow already belongs to another flow add_job = get_test_job() subflow = Flow([add_job], output=add_job.output) @@ -1012,3 +990,69 @@ def test_flow_repr(): assert len(lines) == len(flow_repr) for expected, line in zip(lines, flow_repr): assert line.startswith(expected), f"{line=} doesn't start with {expected=}" + + +def test_get_item(): + from jobflow import Flow, job, run_locally + + @job + def make_str(s): + return {"hello": s} + + @job + def capitalize(s): + return s.upper() + + job1 = make_str("world") + job2 = capitalize(job1["hello"]) + + flow = Flow([job1, job2]) + + responses = run_locally(flow, ensure_success=True) + assert responses[job2.uuid][1].output == "WORLD" + + +def test_get_item_job(): + from jobflow import Flow, job, run_locally + + @job + def make_str(s): + return s + + @job + def capitalize(s): + return s.upper() + + job1 = make_str("world") + job2 = capitalize(job1) + + flow = Flow([job1, job2]) + + responses = run_locally(flow, ensure_success=True) + assert responses[job2.uuid][1].output == "WORLD" + + +def test_get_attr(): + from dataclasses import dataclass + + from jobflow import Flow, job, run_locally + + @dataclass + class MyClass: + hello: str + + @job + def make_str(s): + return MyClass(hello=s) + + @job + def capitalize(s): + return s.upper() + + job1 = make_str("world") + job2 = capitalize(job1.hello) + + flow = Flow([job1, job2]) + + responses = run_locally(flow, ensure_success=True) + assert responses[job2.uuid][1].output == "WORLD" diff --git a/tests/core/test_job.py b/tests/core/test_job.py index 3d862565..77fae782 100644 --- a/tests/core/test_job.py +++ b/tests/core/test_job.py @@ -32,10 +32,6 @@ def test_job_init(): assert test_job.uuid is not None assert test_job.output.uuid == test_job.uuid - # test job as another job as input - with pytest.warns(UserWarning): - Job(function=add, function_args=(test_job,)) - # test init with kwargs test_job = Job(function=add, function_args=(1,), function_kwargs={"b": 2}) assert test_job @@ -1272,6 +1268,7 @@ def use_maker(maker): def test_job_magic_methods(): from jobflow import Job + from jobflow.core.reference import OutputReference # prepare test jobs job1 = Job(function=sum, function_args=([1, 2],)) @@ -1296,3 +1293,13 @@ def test_job_magic_methods(): # test __hash__ assert hash(job1) != hash(job2) != hash(job3) + + # test __getitem__ + assert isinstance(job1["test"], OutputReference) + assert isinstance(job1[1], OutputReference) + assert job1["test"].attributes == (("i", "test"),) + assert job1[1].attributes == (("i", 1),) + + # test __getattr__ + assert isinstance(job1.test, OutputReference) + assert job1.test.attributes == (("a", "test"),)