Skip to content

Commit 9df1191

Browse files
committed
moved xor into *.define decorators from *.arg fields
1 parent a52b748 commit 9df1191

File tree

9 files changed

+90
-90
lines changed

9 files changed

+90
-90
lines changed

pydra/design/base.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -248,10 +248,6 @@ class Arg(Field):
248248
Names of the inputs that are required together with the field.
249249
allowed_values: Sequence, optional
250250
List of allowed values for the field.
251-
xor: Sequence[str | None], optional
252-
Names of args that are exclusive mutually exclusive, which must include
253-
the name of the current field. If this list includes None, then none of the
254-
fields need to be set.
255251
copy_mode: File.CopyMode, optional
256252
The mode of copying the file, by default it is File.CopyMode.any
257253
copy_collation: File.CopyCollation, optional
@@ -266,25 +262,11 @@ class Arg(Field):
266262
"""
267263

268264
allowed_values: frozenset = attrs.field(default=(), converter=frozenset)
269-
xor: frozenset[str | None] = attrs.field(default=(), converter=frozenset)
270265
copy_mode: File.CopyMode = File.CopyMode.any
271266
copy_collation: File.CopyCollation = File.CopyCollation.any
272267
copy_ext_decomp: File.ExtensionDecomposition = File.ExtensionDecomposition.single
273268
readonly: bool = False
274269

275-
@xor.validator
276-
def _xor_validator(self, _, value):
277-
for v in value:
278-
if not isinstance(v, (str, type(None))):
279-
raise ValueError(
280-
f"xor values must be strings or None, not {v} ({self!r})"
281-
)
282-
if value and self.type not in (ty.Any, bool) and not is_optional(self.type):
283-
raise ValueError(
284-
f"Fields that have 'xor' must be of boolean or optional type, "
285-
f"not type {self.type} ({self!r})"
286-
)
287-
288270

289271
@attrs.define(kw_only=True, slots=False)
290272
class Out(Field):
@@ -418,6 +400,7 @@ def make_task_def(
418400
name: str | None = None,
419401
bases: ty.Sequence[type] = (),
420402
outputs_bases: ty.Sequence[type] = (),
403+
xor: ty.Sequence[str | None] | ty.Sequence[ty.Sequence[str | None]] = (),
421404
):
422405
"""Create a task definition class and its outputs definition class from the
423406
input and output fields provided to the decorator/function.
@@ -442,14 +425,26 @@ def make_task_def(
442425
The base classes for the task definition class, by default ()
443426
outputs_bases : ty.Sequence[type], optional
444427
The base classes for the outputs definition class, by default ()
428+
xor: Sequence[str | None] | Sequence[Sequence[str | None]], optional
429+
Names of args that are exclusive mutually exclusive, which must include
430+
the name of the current field. If this list includes None, then none of the
431+
fields need to be set.
445432
446433
Returns
447434
-------
448435
klass : type
449436
The class created using the attrs package
450437
"""
451438

452-
spec_type._check_arg_refs(inputs, outputs)
439+
# Convert a single xor set into a set of xor sets
440+
if not xor:
441+
xor = frozenset()
442+
elif all(isinstance(x, str) or x is None for x in xor):
443+
xor = frozenset([frozenset(xor)])
444+
else:
445+
xor = frozenset(frozenset(x) for x in xor)
446+
447+
spec_type._check_arg_refs(inputs, outputs, xor)
453448

454449
# Check that the field attributes are valid after all fields have been set
455450
# (especially the type)
@@ -521,6 +516,8 @@ def make_task_def(
521516
**attrs_kwargs,
522517
),
523518
)
519+
# Store the xor sets for the class
520+
klass._xor = xor
524521
klass.__annotations__[arg.name] = field_type
525522

526523
# Create class using attrs package, will create attributes for all columns and

pydra/design/boutiques.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,6 @@ class arg(shell.arg):
2828
List of allowed values for the field.
2929
requires: list, optional
3030
Names of the inputs that are required together with the field.
31-
xor: list[str | None], optional
32-
Names of args that are exclusive mutually exclusive, which must include
33-
the name of the current field. If this list includes None, then none of the
34-
fields need to be set.
3531
copy_mode: File.CopyMode, optional
3632
The mode of copying the file, by default it is File.CopyMode.any
3733
copy_collation: File.CopyCollation, optional

pydra/design/python.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,6 @@ class arg(Arg):
3333
List of allowed values for the field.
3434
requires: list, optional
3535
Names of the inputs that are required together with the field.
36-
xor: list[str | None], optional
37-
Names of args that are exclusive mutually exclusive, which must include
38-
the name of the current field. If this list includes None, then none of the
39-
fields need to be set.
4036
copy_mode: File.CopyMode, optional
4137
The mode of copying the file, by default it is File.CopyMode.any
4238
copy_collation: File.CopyCollation, optional
@@ -106,6 +102,7 @@ def define(
106102
bases: ty.Sequence[type] = (),
107103
outputs_bases: ty.Sequence[type] = (),
108104
auto_attribs: bool = True,
105+
xor: ty.Sequence[str | None] | ty.Sequence[ty.Sequence[str | None]] = (),
109106
) -> "PythonDef":
110107
"""
111108
Create an interface for a function or a class.
@@ -120,6 +117,10 @@ def define(
120117
The outputs of the function or class.
121118
auto_attribs : bool
122119
Whether to use auto_attribs mode when creating the class.
120+
xor: Sequence[str | None] | Sequence[Sequence[str | None]], optional
121+
Names of args that are exclusive mutually exclusive, which must include
122+
the name of the current field. If this list includes None, then none of the
123+
fields need to be set.
123124
124125
Returns
125126
-------
@@ -172,6 +173,7 @@ def make(wrapped: ty.Callable | type) -> PythonDef:
172173
klass=klass,
173174
bases=bases,
174175
outputs_bases=outputs_bases,
176+
xor=xor,
175177
)
176178

177179
return defn

pydra/design/shell.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,6 @@ class arg(Arg):
5656
List of allowed values for the field.
5757
requires: list, optional
5858
List of field names that are required together with the field.
59-
xor: list[str | None], optional
60-
Names of args that are exclusive mutually exclusive, which must include
61-
the name of the current field. If this list includes None, then none of the
62-
fields need to be set.
6359
copy_mode: File.CopyMode, optional
6460
The mode of copying the file, by default it is File.CopyMode.any
6561
copy_collation: File.CopyCollation, optional
@@ -194,10 +190,6 @@ class outarg(arg, Out):
194190
List of allowed values for the field.
195191
requires: list, optional
196192
List of field names that are required together with the field.
197-
xor: list[str | None], optional
198-
Names of args that are exclusive mutually exclusive, which must include
199-
the name of the current field. If this list includes None, then none of the
200-
fields need to be set.
201193
copy_mode: File.CopyMode, optional
202194
The mode of copying the file, by default it is File.CopyMode.any
203195
copy_collation: File.CopyCollation, optional
@@ -291,6 +283,7 @@ def define(
291283
outputs_bases: ty.Sequence[type] = (),
292284
auto_attribs: bool = True,
293285
name: str | None = None,
286+
xor: ty.Sequence[str | None] | ty.Sequence[ty.Sequence[str | None]] = (),
294287
) -> "ShellDef":
295288
"""Create a task definition for a shell command. Can be used either as a decorator on
296289
the "canonical" dataclass-form of a task definition or as a function that takes a
@@ -337,6 +330,10 @@ def define(
337330
as they appear in the template
338331
name: str | None
339332
The name of the returned class
333+
xor: Sequence[str | None] | Sequence[Sequence[str | None]], optional
334+
Names of args that are exclusive mutually exclusive, which must include
335+
the name of the current field. If this list includes None, then none of the
336+
fields need to be set.
340337
341338
Returns
342339
-------
@@ -446,6 +443,7 @@ def make(
446443
klass=klass,
447444
bases=bases,
448445
outputs_bases=outputs_bases,
446+
xor=xor,
449447
)
450448
return defn
451449

pydra/design/tests/test_shell.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def test_interface_template_with_type_overrides():
469469
def Ls(request):
470470
if request.param == "static":
471471

472-
@shell.define
472+
@shell.define(xor=["complete_date", "date_format_str", None])
473473
class Ls(ShellDef["Ls.Outputs"]):
474474
executable = "ls"
475475

@@ -502,14 +502,12 @@ class Ls(ShellDef["Ls.Outputs"]):
502502
argstr="-T",
503503
default=False,
504504
requires=["long_format"],
505-
xor=["complete_date", "date_format_str", None],
506505
)
507506
date_format_str: str | None = shell.arg(
508507
help="format string for ",
509508
argstr="-D",
510509
default=None,
511510
requires=["long_format"],
512-
xor=["complete_date", "date_format_str", None],
513511
)
514512

515513
@shell.outputs

pydra/design/workflow.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,6 @@ class arg(Arg):
3737
List of allowed values for the field.
3838
requires: list, optional
3939
Names of the inputs that are required together with the field.
40-
xor: list[str | None], optional
41-
Names of args that are exclusive mutually exclusive, which must include
42-
the name of the current field. If this list includes None, then none of the
43-
fields need to be set.
4440
copy_mode: File.CopyMode, optional
4541
The mode of copying the file, by default it is File.CopyMode.any
4642
copy_collation: File.CopyCollation, optional
@@ -111,6 +107,7 @@ def define(
111107
outputs_bases: ty.Sequence[type] = (),
112108
lazy: list[str] | None = None,
113109
auto_attribs: bool = True,
110+
xor: ty.Sequence[str | None] | ty.Sequence[ty.Sequence[str | None]] = (),
114111
) -> "WorkflowDef":
115112
"""
116113
Create an interface for a function or a class. Can be used either as a decorator on
@@ -126,6 +123,10 @@ def define(
126123
The outputs of the function or class.
127124
auto_attribs : bool
128125
Whether to use auto_attribs mode when creating the class.
126+
xor: Sequence[str | None] | Sequence[Sequence[str | None]], optional
127+
Names of args that are exclusive mutually exclusive, which must include
128+
the name of the current field. If this list includes None, then none of the
129+
fields need to be set.
129130
130131
Returns
131132
-------
@@ -183,6 +184,7 @@ def make(wrapped: ty.Callable | type) -> TaskDef:
183184
klass=klass,
184185
bases=bases,
185186
outputs_bases=outputs_bases,
187+
xor=xor,
186188
)
187189

188190
return defn

pydra/engine/specs.py

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,11 @@ def reset(self):
173173
class TaskDef(ty.Generic[OutputsType]):
174174
"""Base class for all task definitions"""
175175

176+
# Class attributes
177+
_xor: frozenset[frozenset[str | None]] = (
178+
frozenset()
179+
) # overwritten in derived classes
180+
176181
# The following fields are used to store split/combine state information
177182
_splitter = attrs.field(default=None, init=False, repr=False)
178183
_combiner = attrs.field(default=None, init=False, repr=False)
@@ -501,20 +506,6 @@ def _rule_violations(self) -> list[str]:
501506
):
502507
errors.append(f"Mandatory field {field.name!r} is not set")
503508

504-
# Collect alternative fields associated with this field.
505-
if field.xor:
506-
mutually_exclusive = {name: self[name] for name in field.xor if name}
507-
are_set = [f"{n}={v!r}" for n, v in mutually_exclusive.items() if v]
508-
if len(are_set) > 1:
509-
errors.append(
510-
f"Mutually exclusive fields ({', '.join(are_set)}) are set together"
511-
)
512-
elif not are_set and None not in field.xor:
513-
errors.append(
514-
"At least one of the mutually exclusive fields should be set: "
515-
+ ", ".join(f"{n}={v!r}" for n, v in mutually_exclusive.items())
516-
)
517-
518509
# Raise error if any required field is unset.
519510
if (
520511
not (
@@ -538,6 +529,19 @@ def _rule_violations(self) -> list[str]:
538529
errors.append(
539530
f"{field.name!r} requires{qualification} {[str(r) for r in field.requires]}"
540531
)
532+
# Collect alternative fields associated with this field.
533+
for xor_set in self._xor:
534+
mutually_exclusive = {name: self[name] for name in xor_set if name}
535+
are_set = [f"{n}={v!r}" for n, v in mutually_exclusive.items() if v]
536+
if len(are_set) > 1:
537+
errors.append(
538+
f"Mutually exclusive fields ({', '.join(are_set)}) are set together"
539+
)
540+
elif not are_set and None not in xor_set:
541+
errors.append(
542+
"At least one of the mutually exclusive fields should be set: "
543+
+ ", ".join(f"{n}={v!r}" for n, v in mutually_exclusive.items())
544+
)
541545
return errors
542546

543547
def _check_rules(self):
@@ -552,7 +556,12 @@ def _check_rules(self):
552556
)
553557

554558
@classmethod
555-
def _check_arg_refs(cls, inputs: list[Arg], outputs: list[Out]) -> None:
559+
def _check_arg_refs(
560+
cls,
561+
inputs: list[Arg],
562+
outputs: list[Out],
563+
xor: frozenset[frozenset[str | None]],
564+
) -> None:
556565
"""
557566
Checks if all fields referenced in requirements and xor are present in the inputs
558567
are valid field names
@@ -567,12 +576,22 @@ def _check_arg_refs(cls, inputs: list[Arg], outputs: list[Out]) -> None:
567576
"'Unrecognised' field names in referenced in the requirements "
568577
f"of {field} " + str(list(unrecognised))
569578
)
570-
for inpt in inputs.values():
571-
if unrecognised := inpt.xor - (input_names | {None}):
579+
580+
for xor_set in xor:
581+
if unrecognised := xor_set - (input_names | {None}):
572582
raise ValueError(
573-
"'Unrecognised' field names in referenced in the xor "
574-
f"of {inpt} " + str(list(unrecognised))
583+
f"'Unrecognised' field names in referenced in the xor {xor_set} "
584+
+ str(list(unrecognised))
575585
)
586+
for field_name in xor_set:
587+
if field_name is None: # i.e. none of the fields being set is valid
588+
continue
589+
type_ = inputs[field_name].type
590+
if type_ not in (ty.Any, bool) and not is_optional(type_):
591+
raise ValueError(
592+
f"Fields included in a 'xor' ({field.name!r}) must be of boolean "
593+
f"or optional types, not type {type_}"
594+
)
576595

577596
def _check_resolved(self):
578597
"""Checks that all the fields in the definition have been resolved"""
@@ -762,6 +781,8 @@ def _from_task(cls, task: "Task[WorkflowDef]") -> Self:
762781
for name, lazy_field in attrs_values(workflow.outputs).items():
763782
try:
764783
val_out = lazy_field._get_value(workflow=workflow, graph=exec_graph)
784+
if isinstance(val_out, StateArray):
785+
val_out = list(val_out) # implicitly combine state arrays
765786
output_wf[name] = val_out
766787
except (ValueError, AttributeError):
767788
output_wf[name] = None

0 commit comments

Comments
 (0)