Skip to content

Commit 2192c45

Browse files
committed
relaxed task definition validation to accept use cases found in ANTs, FSL, Freesurfer, etc... task packages
1 parent 275127e commit 2192c45

File tree

7 files changed

+56
-21
lines changed

7 files changed

+56
-21
lines changed

pydra/compose/base/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
check_explicit_fields_are_none,
77
extract_fields_from_class,
88
is_set,
9+
sanitize_xor,
910
)
1011
from .task import Task, Outputs
1112
from .builder import build_task_class

pydra/compose/base/builder.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
is_lazy,
2020
)
2121
from .field import Field, Arg, Out
22+
from .helpers import sanitize_xor
2223

2324

2425
def build_task_class(
@@ -65,15 +66,7 @@ def build_task_class(
6566
klass : type
6667
The class created using the attrs package
6768
"""
68-
69-
# Convert a single xor set into a set of xor sets
70-
if not xor:
71-
xor = frozenset()
72-
elif all(isinstance(x, str) or x is None for x in xor):
73-
xor = frozenset([frozenset(xor)])
74-
else:
75-
xor = frozenset(frozenset(x) for x in xor)
76-
69+
xor = sanitize_xor(xor)
7770
spec_type._check_arg_refs(inputs, outputs, xor)
7871

7972
# Check that the field attributes are valid after all fields have been set

pydra/compose/base/field.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from attrs.converters import default_if_none
66
from fileformats.core import to_mime
77
from fileformats.generic import File, FileSet
8-
from pydra.utils.typing import TypeParser, is_optional, is_type, is_union
8+
from pydra.utils.typing import TypeParser, is_optional, is_truthy_falsy, is_type, is_union
99
from pydra.utils.general import get_fields, wrap_text
1010
import attrs
1111

@@ -229,10 +229,10 @@ def mandatory(self):
229229

230230
@requires.validator
231231
def _requires_validator(self, _, value):
232-
if value and self.type not in (ty.Any, bool) and not is_optional(self.type):
232+
if value and not is_truthy_falsy(self.type):
233233
raise ValueError(
234-
f"Fields with requirements must be of optional type (i.e. in union "
235-
f"with None) or boolean, not type {self.type} ({self!r})"
234+
f"Fields with requirements must be of optional (i.e. in union "
235+
f"with None) or truthy/falsy type, not type {self.type} ({self!r})"
236236
)
237237

238238
def markdown_listing(

pydra/compose/base/helpers.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,20 @@ def check_explicit_fields_are_none(klass, inputs, outputs):
379379
)
380380

381381

382+
def sanitize_xor(
383+
xor: ty.Sequence[str | None] | ty.Sequence[ty.Sequence[str | None]],
384+
) -> set[frozenset[str]]:
385+
"""Convert a list of xor sets into a set of frozensets"""
386+
# Convert a single xor set into a set of xor sets
387+
if not xor:
388+
xor = frozenset()
389+
elif all(isinstance(x, str) or x is None for x in xor):
390+
xor = frozenset([frozenset(xor)])
391+
else:
392+
xor = frozenset(frozenset(x) for x in xor)
393+
return xor
394+
395+
382396
def extract_fields_from_class(
383397
spec_type: type["Task"],
384398
outputs_type: type["Outputs"],

pydra/compose/base/task.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from copy import copy
55
from typing import Self
66
import attrs.validators
7-
from pydra.utils.typing import is_optional, is_fileset_or_union
7+
from pydra.utils.typing import is_optional, is_fileset_or_union, is_truthy_falsy
88
from pydra.utils.general import get_fields
99
from pydra.utils.typing import StateArray, is_lazy
1010
from pydra.utils.hash import hash_function
@@ -595,17 +595,17 @@ def _check_arg_refs(
595595
for xor_set in xor:
596596
if unrecognised := xor_set - (input_names | {None}):
597597
raise ValueError(
598-
f"'Unrecognised' field names in referenced in the xor {xor_set} "
598+
f"Unrecognised field names in referenced in the xor {xor_set}: "
599599
+ str(list(unrecognised))
600600
)
601601
for field_name in xor_set:
602602
if field_name is None: # i.e. none of the fields being set is valid
603603
continue
604604
type_ = inputs[field_name].type
605-
if type_ not in (ty.Any, bool) and not is_optional(type_):
605+
if not is_truthy_falsy(type_):
606606
raise ValueError(
607-
f"Fields included in a 'xor' ({field_name!r}) must be of boolean "
608-
f"or optional types, not type {type_}"
607+
f"Fields included in a 'xor' ({field_name!r}) must be an optional type or a"
608+
f"truthy/falsy type, not type {type_}"
609609
)
610610

611611
def _check_resolved(self):

pydra/compose/shell/builder.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
extract_fields_from_class,
2222
ensure_field_objects,
2323
build_task_class,
24+
sanitize_xor,
2425
NO_DEFAULT,
2526
)
2627
from pydra.utils.typing import (
@@ -208,7 +209,7 @@ def make(
208209
)
209210

210211
# Set positions for the remaining inputs that don't have an explicit position
211-
position_stack = remaining_positions(list(parsed_inputs.values()))
212+
position_stack = remaining_positions(list(parsed_inputs.values()), xor=xor)
212213
for inpt in parsed_inputs.values():
213214
if inpt.name == "append_args":
214215
continue
@@ -526,7 +527,10 @@ def from_type_str(type_str) -> type:
526527

527528

528529
def remaining_positions(
529-
args: list[Arg], num_args: int | None = None, start: int = 0
530+
args: list[Arg],
531+
num_args: int | None = None,
532+
start: int = 0,
533+
xor: set[frozenset[str]] | None = None,
530534
) -> ty.List[int]:
531535
"""Get the remaining positions for input fields
532536
@@ -536,6 +540,10 @@ def remaining_positions(
536540
The list of input fields
537541
num_args : int, optional
538542
The number of arguments, by default it is the length of the args
543+
start : int, optional
544+
The starting position, by default 0
545+
xor : set[frozenset[str]], optional
546+
A set of mutually exclusive fields, by default None
539547
540548
Returns
541549
-------
@@ -547,6 +555,7 @@ def remaining_positions(
547555
ValueError
548556
If multiple fields have the same position
549557
"""
558+
xor = sanitize_xor(xor)
550559
if num_args is None:
551560
num_args = len(args) - 1 # Subtract 1 for the 'append_args' field
552561
# Check for multiple positions
@@ -562,7 +571,7 @@ def remaining_positions(
562571
if multiple_positions := {
563572
k: [f"{a.name}({a.position})" for a in v]
564573
for k, v in positions.items()
565-
if len(v) > 1
574+
if len(v) > 1 and frozenset(a.name for a in v) not in xor
566575
}:
567576
raise ValueError(
568577
f"Multiple fields have the overlapping positions: {multiple_positions}"

pydra/utils/typing.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,24 @@ def is_optional(type_: type) -> bool:
10701070
return False
10711071

10721072

1073+
def is_container(type_: type) -> bool:
1074+
"""Check if the type is a container, i.e. a list, tuple, or MultiOutputObj"""
1075+
origin = ty.get_origin(type_)
1076+
tp = origin if origin else type_
1077+
return inspect.isclass(tp) and issubclass(tp, ty.Container)
1078+
1079+
1080+
def is_truthy_falsy(type_: type) -> bool:
1081+
"""Check if the type is a truthy type, i.e. not None, bool, or typing.Any"""
1082+
return (
1083+
type_ in (ty.Any, bool, int, str)
1084+
or is_optional(type_)
1085+
or is_container(type_)
1086+
or hasattr(type_, "__bool__")
1087+
or hasattr(type_, "__len__")
1088+
)
1089+
1090+
10731091
def optional_type(type_: type) -> type:
10741092
"""Gets the non-None args of an optional type (i.e. a union with a None arg)"""
10751093
if is_optional(type_):

0 commit comments

Comments
 (0)