Skip to content

Commit bdee4c8

Browse files
committed
cleaned up specs.py so that it works with new syntax
1 parent 032fd4e commit bdee4c8

File tree

13 files changed

+204
-299
lines changed

13 files changed

+204
-299
lines changed

pydra/design/base.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@
3636
"make_task_spec",
3737
]
3838

39-
RESERVED_OUTPUT_NAMES = ("split", "combine")
40-
4139

4240
class _Empty(enum.Enum):
4341

@@ -58,6 +56,11 @@ def is_type(_, __, val: ty.Any) -> bool:
5856
return inspect.isclass(val) or ty.get_origin(val)
5957

6058

59+
def convert_default_value(value: ty.Any, self_: "Field") -> ty.Any:
60+
"""Ensure the default value has been coerced into the correct type"""
61+
return TypeParser[self_.type](self_.type, label=self_.name)(value)
62+
63+
6164
@attrs.define(kw_only=True)
6265
class Field:
6366
"""Base class for input and output fields to task specifications
@@ -66,9 +69,11 @@ class Field:
6669
----------
6770
name: str, optional
6871
The name of the field, used when specifying a list of fields instead of a mapping
69-
from name to field, by default it is None
7072
type: type, optional
7173
The type of the field, by default it is Any
74+
from name to field, by default it is None
75+
default : Any, optional
76+
the default value for the field, by default it is EMPTY
7277
help_string: str, optional
7378
A short description of the input field.
7479
requires: list, optional
@@ -83,6 +88,9 @@ class Field:
8388
type: ty.Type[ty.Any] = attrs.field(
8489
validator=is_type, default=ty.Any, converter=default_if_none(ty.Any)
8590
)
91+
default: ty.Any = attrs.field(
92+
default=EMPTY, converter=attrs.Converter(convert_default_value, with_self=True)
93+
)
8694
help_string: str = ""
8795
requires: list[str] | list[list[str]] = attrs.field(
8896
factory=list, converter=ensure_list
@@ -97,10 +105,15 @@ class Arg(Field):
97105
98106
Parameters
99107
----------
108+
name: str, optional
109+
The name of the field, used when specifying a list of fields instead of a mapping
110+
from name to field, by default it is None
111+
type: type, optional
112+
The type of the field, by default it is Any
113+
default : Any, optional
114+
the default value for the field, by default it is EMPTY
100115
help_string: str
101116
A short description of the input field.
102-
default : Any, optional
103-
the default value for the argument
104117
allowed_values: list, optional
105118
List of allowed values for the field.
106119
requires: list, optional
@@ -118,14 +131,8 @@ class Arg(Field):
118131
If True the input field can’t be provided by the user but it aggregates other
119132
input fields (for example the fields with argstr: -o {fldA} {fldB}), by default
120133
it is False
121-
type: type, optional
122-
The type of the field, by default it is Any
123-
name: str, optional
124-
The name of the field, used when specifying a list of fields instead of a mapping
125-
from name to field, by default it is None
126134
"""
127135

128-
default: ty.Any = EMPTY
129136
allowed_values: list | None = None
130137
xor: list | None = None
131138
copy_mode: File.CopyMode = File.CopyMode.any
@@ -145,6 +152,8 @@ class Out(Field):
145152
from name to field, by default it is None
146153
type: type, optional
147154
The type of the field, by default it is Any
155+
default : Any, optional
156+
the default value for the field, by default it is EMPTY
148157
help_string: str, optional
149158
A short description of the input field.
150159
requires: list, optional
@@ -385,7 +394,7 @@ def make_outputs_spec(
385394
f"Cannot make {spec_type} output spec from {out_spec_bases} bases"
386395
)
387396
outputs_bases = bases + (spec_type,)
388-
if reserved_names := [n for n in outputs if n in RESERVED_OUTPUT_NAMES]:
397+
if reserved_names := [n for n in outputs if n in spec_type.RESERVED_FIELD_NAMES]:
389398
raise ValueError(
390399
f"{reserved_names} are reserved and cannot be used for output field names"
391400
)

pydra/design/shell.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from fileformats import generic
1313
from fileformats.core.exceptions import FormatRecognitionError
1414
from pydra.engine.specs import ShellSpec, ShellOutSpec
15+
from pydra.engine.helpers import attrs_values
1516
from .base import (
1617
Arg,
1718
Out,
@@ -470,7 +471,7 @@ def add_arg(name, field_type, kwds, is_option=False):
470471
kwds["type"] = field
471472
field = field_type(name=name, **kwds)
472473
elif not isinstance(field, field_type): # If field type is outarg not out
473-
field = field_type(**attrs.asdict(field, recurse=False))
474+
field = field_type(**attrs_values(field))
474475
field.name = name
475476
type_ = kwds.pop("type", field.type)
476477
if field.type is ty.Any:

pydra/engine/audit.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import os
44
import json
5-
import attr
65
from pydra.utils.messenger import send_message, make_message, gen_uuid, now, AuditFlag
6+
from pydra.engine.helpers import attrs_values
77
from fileformats.core import FileSet
88
from pydra.utils.hash import hash_function
99

@@ -104,7 +104,7 @@ def finalize_audit(self, result):
104104
)
105105
# audit resources/runtime information
106106
self.eid = f"uid:{gen_uuid()}"
107-
entity = attr.asdict(result.runtime, recurse=False)
107+
entity = attrs_values(result.runtime)
108108
entity.update(
109109
**{
110110
"@id": self.eid,
@@ -180,12 +180,12 @@ def audit_check(self, flag):
180180

181181
def audit_task(self, task):
182182
import subprocess as sp
183-
from .helpers import attr_fields
183+
from .helpers import attrs_fields
184184

185185
label = task.name
186186

187187
command = task.cmdline if hasattr(task.inputs, "executable") else None
188-
attr_list = attr_fields(task.inputs)
188+
attr_list = attrs_fields(task.inputs)
189189
for attrs in attr_list:
190190
input_name = attrs.name
191191
value = getattr(task.inputs, input_name)

pydra/engine/boutiques.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from pydra.utils.messenger import AuditFlag
99
from pydra.engine.task import ShellCommandTask
10-
from pydra.engine.specs import SpecInfo, ShellSpec, ShellOutSpec, File, attr_fields
10+
from pydra.engine.specs import SpecInfo, ShellSpec, ShellOutSpec, File, attrs_fields
1111
from .helpers_file import is_local_file
1212

1313

@@ -192,7 +192,7 @@ def _command_args_single(self, state_ind=None, index=None):
192192
def _bosh_invocation_file(self, state_ind=None, index=None):
193193
"""creating bosh invocation file - json file with inputs values"""
194194
input_json = {}
195-
for f in attr_fields(self.inputs, exclude_names=("executable", "args")):
195+
for f in attrs_fields(self.inputs, exclude_names=("executable", "args")):
196196
if self.state and f"{self.name}.{f.name}" in state_ind:
197197
value = getattr(self.inputs, f.name)[state_ind[f"{self.name}.{f.name}"]]
198198
else:

pydra/engine/core.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
from .workflow.lazy import is_lazy
2727
from .helpers import (
2828
create_checksum,
29-
attr_fields,
29+
attrs_fields,
30+
attrs_values,
3031
print_help,
3132
load_result,
3233
save,
@@ -342,7 +343,7 @@ def generated_output_names(self):
342343
The results depends on the input provided to the task
343344
"""
344345
output_klass = self.interface.Outputs
345-
if hasattr(output_klass, "generated_output_names"):
346+
if hasattr(output_klass, "_generated_output_names"):
346347
output = output_klass(
347348
**{f.name: attr.NOTHING for f in attr.fields(output_klass)}
348349
)
@@ -352,7 +353,7 @@ def generated_output_names(self):
352353
if modified_inputs:
353354
_inputs = attr.evolve(_inputs, **modified_inputs)
354355

355-
return output.generated_output_names(
356+
return output._generated_output_names(
356357
inputs=_inputs, output_dir=self.output_dir
357358
)
358359
else:
@@ -461,9 +462,7 @@ def _modify_inputs(self):
461462
from pydra.utils.typing import TypeParser
462463

463464
orig_inputs = {
464-
k: v
465-
for k, v in attr.asdict(self.inputs, recurse=False).items()
466-
if not k.startswith("_")
465+
k: v for k, v in attrs_values(self.inputs).items() if not k.startswith("_")
467466
}
468467
map_copyfiles = {}
469468
input_fields = attr.fields(type(self.inputs))
@@ -754,7 +753,7 @@ def result(self, state_index=None, return_inputs=False):
754753

755754
def _reset(self):
756755
"""Reset the connections between inputs and LazyFields."""
757-
for field in attr_fields(self.inputs):
756+
for field in attrs_fields(self.inputs):
758757
if field.name in self.inp_lf:
759758
setattr(self.inputs, field.name, self.inp_lf[field.name])
760759
if is_workflow(self):
@@ -979,7 +978,7 @@ def create_connections(self, task, detailed=False):
979978
"""
980979
# TODO: create connection is run twice
981980
other_states = {}
982-
for field in attr_fields(task.inputs):
981+
for field in attrs_fields(task.inputs):
983982
val = getattr(task.inputs, field.name)
984983
if is_lazy(val):
985984
# saving all connections with LazyFields

pydra/engine/helpers.py

Lines changed: 12 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,18 @@
2626
PYDRA_ATTR_METADATA = "__PYDRA_METADATA__"
2727

2828

29-
def attr_fields(spec, exclude_names=()):
29+
def attrs_fields(spec, exclude_names=()) -> list[attrs.Attribute]:
30+
"""Get the fields of a spec, excluding some names."""
3031
return [field for field in spec.__attrs_attrs__ if field.name not in exclude_names]
3132

3233

34+
def attrs_values(obj, **kwargs) -> dict[str, ty.Any]:
35+
"""Get the values of an attrs object."""
36+
return attrs.asdict(obj, recurse=False, **kwargs)
37+
38+
3339
def list_fields(interface: "TaskSpec") -> list["Field"]:
40+
"""List the fields of a task specification"""
3441
if not attrs.has(interface):
3542
return []
3643
return [
@@ -43,7 +50,7 @@ def list_fields(interface: "TaskSpec") -> list["Field"]:
4350
# from .specs import MultiInputFile, MultiInputObj, MultiOutputObj, MultiOutputFile
4451

4552

46-
def from_list_if_single(obj):
53+
def from_list_if_single(obj: ty.Any) -> ty.Any:
4754
"""Converts a list to a single item if it is of length == 1"""
4855

4956
if obj is attrs.NOTHING:
@@ -109,7 +116,7 @@ def load_result(checksum, cache_locations):
109116
return None
110117

111118

112-
def save(task_path: Path, result=None, task=None, name_prefix=None):
119+
def save(task_path: Path, result=None, task=None, name_prefix=None) -> None:
113120
"""
114121
Save a :class:`~pydra.engine.core.TaskBase` object and/or results.
115122
@@ -147,7 +154,7 @@ def save(task_path: Path, result=None, task=None, name_prefix=None):
147154

148155
def copyfile_workflow(wf_path: os.PathLike, result):
149156
"""if file in the wf results, the file will be copied to the workflow directory"""
150-
for field in attr_fields(result.output):
157+
for field in attrs_fields(result.output):
151158
value = getattr(result.output, field.name)
152159
# if the field is a path or it can contain a path _copyfile_single_value is run
153160
# to move all files and directories to the workflow directory
@@ -375,38 +382,6 @@ def get_open_loop():
375382
return loop
376383

377384

378-
# def output_from_inputfields(interface: "Interface"):
379-
# """
380-
# Collect values from output from input fields.
381-
# If names_only is False, the output_spec is updated,
382-
# if names_only is True only the names are returned
383-
384-
# Parameters
385-
# ----------
386-
# output_spec :
387-
# TODO
388-
# input_spec :
389-
# TODO
390-
391-
# """
392-
# current_output_spec_names = [f.name for f in attrs.fields(interface.Outputs)]
393-
# new_fields = []
394-
# for fld in attrs.fields(interface):
395-
# if "output_file_template" in fld.metadata:
396-
# if "output_field_name" in fld.metadata:
397-
# field_name = fld.metadata["output_field_name"]
398-
# else:
399-
# field_name = fld.name
400-
# # not adding if the field already in the output_spec
401-
# if field_name not in current_output_spec_names:
402-
# # TODO: should probably remove some of the keys
403-
# new_fields.append(
404-
# (field_name, attrs.field(type=File, metadata=fld.metadata))
405-
# )
406-
# output_spec.fields += new_fields
407-
# return output_spec
408-
409-
410385
def get_available_cpus():
411386
"""
412387
Return the number of CPUs available to the current process or, if that is not
@@ -658,7 +633,7 @@ def is_lazy(obj):
658633
if is_lazy(obj):
659634
return True
660635

661-
for f in attr_fields(obj):
636+
for f in attrs_fields(obj):
662637
if isinstance(getattr(obj, f.name), LazyField):
663638
return True
664639
return False

pydra/engine/helpers_file.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from contextlib import contextmanager
1111
import attr
1212
from fileformats.core import FileSet
13-
from pydra.engine.helpers import is_lazy
13+
from pydra.engine.helpers import is_lazy, attrs_values
1414

1515

1616
logger = logging.getLogger("pydra")
@@ -105,7 +105,7 @@ def template_update(inputs, output_dir, state_ind=None, map_copyfiles=None):
105105
106106
"""
107107

108-
inputs_dict_st = attr.asdict(inputs, recurse=False)
108+
inputs_dict_st = attrs_values(inputs)
109109
if map_copyfiles is not None:
110110
inputs_dict_st.update(map_copyfiles)
111111

@@ -114,12 +114,12 @@ def template_update(inputs, output_dir, state_ind=None, map_copyfiles=None):
114114
k = k.split(".")[1]
115115
inputs_dict_st[k] = inputs_dict_st[k][v]
116116

117-
from .specs import attr_fields
117+
from .specs import attrs_fields
118118

119119
# Collect templated inputs for which all requirements are satisfied.
120120
fields_templ = [
121121
field
122-
for field in attr_fields(inputs)
122+
for field in attrs_fields(inputs)
123123
if field.metadata.get("output_file_template")
124124
and getattr(inputs, field.name) is not False
125125
and all(
@@ -155,7 +155,7 @@ def template_update_single(
155155
from pydra.engine.specs import OUTPUT_TEMPLATE_TYPES
156156

157157
if inputs_dict_st is None:
158-
inputs_dict_st = attr.asdict(inputs, recurse=False)
158+
inputs_dict_st = attrs_values(inputs)
159159

160160
if spec_type == "input":
161161
inp_val_set = inputs_dict_st[field.name]

pydra/engine/helpers_state.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
"""Additional functions used mostly by the State class."""
22

3-
import attr
43
import itertools
54
from copy import deepcopy
65
import logging
76
import typing as ty
8-
from .helpers import ensure_list
7+
from .helpers import ensure_list, attrs_values
98

109
logger = logging.getLogger("pydra")
1110

@@ -622,9 +621,7 @@ def map_splits(split_iter, inputs, cont_dim=None):
622621
def inputs_types_to_dict(name, inputs):
623622
"""Convert type.Inputs to dictionary."""
624623
# dj: any better option?
625-
input_names = [
626-
field for field in attr.asdict(inputs, recurse=False) if field != "_func"
627-
]
624+
input_names = [field for field in attrs_values(inputs) if field != "_func"]
628625
inputs_dict = {}
629626
for field in input_names:
630627
inputs_dict[f"{name}.{field}"] = getattr(inputs, field)

0 commit comments

Comments
 (0)