Skip to content

Commit b4687b6

Browse files
committed
made task_fields return a dict-like object
1 parent 16fa35a commit b4687b6

File tree

5 files changed

+39
-25
lines changed

5 files changed

+39
-25
lines changed

pydra/compose/base/helpers.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,6 @@ def ensure_field_objects(
7474
name=input_name,
7575
help=input_helps.get(input_name, ""),
7676
)
77-
if is_optional(arg):
78-
inputs[input_name].default = None
7977
elif isinstance(arg, dict):
8078
arg_kwds = copy(arg)
8179
if "help" not in arg_kwds:

pydra/compose/tests/test_workflow.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,14 @@ def MyTestWorkflow(a, b):
5454
# The constructor function is included as a part of the task so it is
5555
# included in the hash by default and can be overridden if needed. Not 100% sure
5656
# if this is a good idea or not
57-
assert task_fields(MyTestWorkflow) == [
57+
assert list(task_fields(MyTestWorkflow)) == [
5858
workflow.arg(name="a"),
5959
workflow.arg(name="b"),
6060
workflow.arg(
6161
name="constructor", type=ty.Callable, hash_eq=True, default=constructor
6262
),
6363
]
64-
assert task_fields(MyTestWorkflow.Outputs) == [
64+
assert list(task_fields(MyTestWorkflow.Outputs)) == [
6565
workflow.out(name="out"),
6666
]
6767
workflow_spec = MyTestWorkflow(a=1, b=2.0)
@@ -115,15 +115,15 @@ def MyTestShellWorkflow(
115115

116116
constructor = MyTestShellWorkflow().constructor
117117
assert constructor.__name__ == "MyTestShellWorkflow"
118-
assert task_fields(MyTestShellWorkflow) == [
118+
assert list(task_fields(MyTestShellWorkflow)) == [
119119
workflow.arg(name="input_video", type=video.Mp4),
120120
workflow.arg(name="watermark", type=image.Png),
121121
workflow.arg(name="watermark_dims", type=tuple[int, int], default=(10, 10)),
122122
workflow.arg(
123123
name="constructor", type=ty.Callable, hash_eq=True, default=constructor
124124
),
125125
]
126-
assert task_fields(MyTestShellWorkflow.Outputs) == [
126+
assert list(task_fields(MyTestShellWorkflow.Outputs)) == [
127127
workflow.out(name="output_video", type=video.Mp4),
128128
]
129129
input_video = video.Mp4.mock("input.mp4")
@@ -178,7 +178,7 @@ class Outputs(workflow.Outputs):
178178
name="constructor", type=ty.Callable, hash_eq=True, default=constructor
179179
),
180180
]
181-
assert task_fields(MyTestWorkflow.Outputs) == [
181+
assert list(task_fields(MyTestWorkflow.Outputs)) == [
182182
workflow.out(name="out", type=float),
183183
]
184184
workflow_spec = MyTestWorkflow(a=1, b=2.0)
@@ -301,7 +301,7 @@ def MyTestWorkflow(a: int, b: float) -> tuple[float, float]:
301301

302302
return mul.out, divide.divided
303303

304-
assert task_fields(MyTestWorkflow) == [
304+
assert list(task_fields(MyTestWorkflow)) == [
305305
workflow.arg(name="a", type=int, help="An integer input"),
306306
workflow.arg(name="b", type=float, help="A float input"),
307307
workflow.arg(
@@ -311,7 +311,7 @@ def MyTestWorkflow(a: int, b: float) -> tuple[float, float]:
311311
default=MyTestWorkflow().constructor,
312312
),
313313
]
314-
assert task_fields(MyTestWorkflow.Outputs) == [
314+
assert list(task_fields(MyTestWorkflow.Outputs)) == [
315315
workflow.out(name="out1", type=float, help="The first output"),
316316
workflow.out(name="out2", type=float, help="The second output"),
317317
]
@@ -344,7 +344,7 @@ def MyTestWorkflow(a: int, b: float):
344344

345345
# no return is used when the outputs are set directly
346346

347-
assert task_fields(MyTestWorkflow) == [
347+
assert list(task_fields(MyTestWorkflow)) == [
348348
workflow.arg(name="a", type=int),
349349
workflow.arg(name="b", type=float),
350350
workflow.arg(
@@ -354,7 +354,7 @@ def MyTestWorkflow(a: int, b: float):
354354
default=MyTestWorkflow().constructor,
355355
),
356356
]
357-
assert task_fields(MyTestWorkflow.Outputs) == [
357+
assert list(task_fields(MyTestWorkflow.Outputs)) == [
358358
workflow.out(name="out1", type=float),
359359
workflow.out(name="out2", type=float),
360360
]

pydra/engine/tests/test_functions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def Indirect(a):
8888
assert non_func_values(Direct(a=a)) == non_func_values(Indirect(a=a))
8989

9090
# checking if the annotation is properly converted to output_spec if used in task
91-
assert task_fields(Direct.Outputs)[0] == python.out(name="out", type=int)
91+
assert task_fields(Direct.Outputs).out == python.out(name="out", type=int)
9292

9393

9494
def test_annotation_equivalence_2():
@@ -117,7 +117,7 @@ def Indirect(a) -> tuple[int, float]:
117117
assert hashes(Direct(a=a)) == hashes(Partial(a=a)) == hashes(Indirect(a=a))
118118

119119
# checking if the annotation is properly converted to output_spec if used in task
120-
assert task_fields(Direct.Outputs) == [
120+
assert list(task_fields(Direct.Outputs)) == [
121121
python.out(name="out1", type=int),
122122
python.out(name="out2", type=float),
123123
]
@@ -149,7 +149,7 @@ def Indirect(a):
149149
assert hashes(Direct(a=a)) == hashes(Partial(a=a)) == hashes(Indirect(a=a))
150150

151151
# checking if the annotation is properly converted to output_spec if used in task
152-
assert task_fields(Direct.Outputs)[0] == python.out(name="out1", type=int)
152+
assert task_fields(Direct.Outputs).out1 == python.out(name="out1", type=int)
153153

154154

155155
def test_annotation_equivalence_4():
@@ -184,7 +184,7 @@ def Indirect(a):
184184
assert hashes(Direct(a=a)) == hashes(Partial(a=a)) == hashes(Indirect(a=a))
185185

186186
# checking if the annotation is properly converted to output_spec if used in task
187-
assert task_fields(Direct.Outputs) == [
187+
assert list(task_fields(Direct.Outputs)) == [
188188
python.out(name="sum", type=int),
189189
python.out(name="sub", type=int),
190190
]

pydra/engine/tests/test_shelltask.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -845,7 +845,7 @@ class Outputs(shell.Outputs):
845845

846846
shelly = Shelly(
847847
executable=cmd,
848-
newfile=File.mock("newfile_tmp.txt"),
848+
newfile=tmp_path / "newfile_tmp.txt",
849849
time="02121010",
850850
)
851851

@@ -884,7 +884,7 @@ class Outputs(shell.Outputs):
884884

885885
shelly = Shelly(
886886
executable=cmd,
887-
newfile=File.mock("newfile_tmp.txt"),
887+
newfile=tmp_path / "newfile_tmp.txt",
888888
time="02121010",
889889
)
890890

@@ -2682,7 +2682,7 @@ class Outputs(shell.Outputs):
26822682
shelly = Shelly(
26832683
executable=cmd,
26842684
)
2685-
shelly.file1 = File.mock("new_file_1.txt")
2685+
shelly.file1 = tmp_path / "new_file_1.txt"
26862686
assert get_output_names(shelly) == [
26872687
"newfile1",
26882688
"newfile2",

pydra/utils/general.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -320,17 +320,33 @@ def attrs_values(obj, **kwargs) -> dict[str, ty.Any]:
320320
}
321321

322322

323-
def task_fields(task: "type[Task] | Task") -> list["Field"]:
323+
class _TaskFieldsList(dict[str, "Field"]):
324+
"""A list of task fields. Acts like list in that you can iterate over the values
325+
but also access them like a dict or by attribute."""
326+
327+
def __iter__(self):
328+
return iter(self.values())
329+
330+
def __getattr__(self, name):
331+
return self[name]
332+
333+
def __dir__(self):
334+
return sorted(self.keys())
335+
336+
337+
def task_fields(task: "type[Task] | Task") -> _TaskFieldsList:
324338
"""List the fields of a task"""
325339
if not inspect.isclass(task):
326340
task = type(task)
327341
if not attrs.has(task):
328-
return []
329-
return [
330-
f.metadata[PYDRA_ATTR_METADATA]
331-
for f in attrs.fields(task)
332-
if PYDRA_ATTR_METADATA in f.metadata
333-
]
342+
return _TaskFieldsList()
343+
return _TaskFieldsList(
344+
**{
345+
f.name: f.metadata[PYDRA_ATTR_METADATA]
346+
for f in attrs.fields(task)
347+
if PYDRA_ATTR_METADATA in f.metadata
348+
}
349+
)
334350

335351

336352
def fields_values(obj, **kwargs) -> dict[str, ty.Any]:

0 commit comments

Comments
 (0)