Skip to content

Commit 7f1b259

Browse files
committed
fixed up issue with optional xor
1 parent 4a77081 commit 7f1b259

File tree

10 files changed

+48
-30
lines changed

10 files changed

+48
-30
lines changed

pydra/design/base.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,12 +244,14 @@ class Arg(Field):
244244
the default value for the field, by default it is NO_DEFAULT
245245
help: str
246246
A short description of the input field.
247-
allowed_values: list, optional
248-
List of allowed values for the field.
249247
requires: list, optional
250248
Names of the inputs that are required together with the field.
251-
xor: list[str], optional
252-
Names of the inputs that are mutually exclusive with the field.
249+
allowed_values: Sequence, optional
250+
List of allowed values for the field.
251+
xor: Sequence[str | None], optional
252+
Names of args that are exclusive mutually exclusive, which must include
253+
the name of the current field. If this list includes None, then none of the
254+
fields need to be set.
253255
copy_mode: File.CopyMode, optional
254256
The mode of copying the file, by default it is File.CopyMode.any
255257
copy_collation: File.CopyCollation, optional
@@ -263,15 +265,20 @@ class Arg(Field):
263265
it is False
264266
"""
265267

266-
allowed_values: tuple = attrs.field(default=(), converter=tuple)
267-
xor: tuple[str] = attrs.field(default=(), converter=tuple)
268+
allowed_values: frozenset = attrs.field(default=(), converter=frozenset)
269+
xor: frozenset[str | None] = attrs.field(default=(), converter=frozenset)
268270
copy_mode: File.CopyMode = File.CopyMode.any
269271
copy_collation: File.CopyCollation = File.CopyCollation.any
270272
copy_ext_decomp: File.ExtensionDecomposition = File.ExtensionDecomposition.single
271273
readonly: bool = False
272274

273275
@xor.validator
274276
def _xor_validator(self, _, value):
277+
for v in value:
278+
if not isinstance(v, (str, type(None))):
279+
raise ValueError(
280+
f"xor values must be strings or None, not {v} ({self!r})"
281+
)
275282
if value and self.type not in (ty.Any, bool) and not is_optional(self.type):
276283
raise ValueError(
277284
f"Fields that have 'xor' must be of boolean or optional type, "

pydra/design/boutiques.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@ class arg(shell.arg):
2828
List of allowed values for the field.
2929
requires: list, optional
3030
Names of the inputs that are required together with the field.
31-
xor: list[str], optional
32-
Names of the inputs that are mutually exclusive with the field.
31+
xor: list[str | None], optional
32+
Names of args that are exclusive mutually exclusive, which must include
33+
the name of the current field. If this list includes None, then none of the
34+
fields need to be set.
3335
copy_mode: File.CopyMode, optional
3436
The mode of copying the file, by default it is File.CopyMode.any
3537
copy_collation: File.CopyCollation, optional

pydra/design/python.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ class arg(Arg):
3333
List of allowed values for the field.
3434
requires: list, optional
3535
Names of the inputs that are required together with the field.
36-
xor: list, optional
37-
Names of the inputs that are mutually exclusive with the field.
36+
xor: list[str | None], optional
37+
Names of args that are exclusive mutually exclusive, which must include
38+
the name of the current field. If this list includes None, then none of the
39+
fields need to be set.
3840
copy_mode: File.CopyMode, optional
3941
The mode of copying the file, by default it is File.CopyMode.any
4042
copy_collation: File.CopyCollation, optional

pydra/design/shell.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,10 @@ class arg(Arg):
5656
List of allowed values for the field.
5757
requires: list, optional
5858
List of field names that are required together with the field.
59-
xor: list, optional
60-
List of field names that are mutually exclusive with the field.
59+
xor: list[str | None], optional
60+
Names of args that are exclusive mutually exclusive, which must include
61+
the name of the current field. If this list includes None, then none of the
62+
fields need to be set.
6163
copy_mode: File.CopyMode, optional
6264
The mode of copying the file, by default it is File.CopyMode.any
6365
copy_collation: File.CopyCollation, optional
@@ -192,8 +194,10 @@ class outarg(arg, Out):
192194
List of allowed values for the field.
193195
requires: list, optional
194196
List of field names that are required together with the field.
195-
xor: list, optional
196-
List of field names that are mutually exclusive with the field.
197+
xor: list[str | None], optional
198+
Names of args that are exclusive mutually exclusive, which must include
199+
the name of the current field. If this list includes None, then none of the
200+
fields need to be set.
197201
copy_mode: File.CopyMode, optional
198202
The mode of copying the file, by default it is File.CopyMode.any
199203
copy_collation: File.CopyCollation, optional

pydra/design/tests/test_shell.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -502,14 +502,14 @@ class Ls(ShellDef["Ls.Outputs"]):
502502
argstr="-T",
503503
default=False,
504504
requires=["long_format"],
505-
xor=["date_format_str"],
505+
xor=["complete_date", "date_format_str", None],
506506
)
507507
date_format_str: str | None = shell.arg(
508508
help="format string for ",
509509
argstr="-D",
510510
default=None,
511511
requires=["long_format"],
512-
xor=["complete_date"],
512+
xor=["complete_date", "date_format_str", None],
513513
)
514514

515515
@shell.outputs
@@ -557,15 +557,15 @@ class Outputs(ShellOutputs):
557557
argstr="-T",
558558
default=False,
559559
requires=["long_format"],
560-
xor=["date_format_str"],
560+
xor=["complete_date", "date_format_str", None],
561561
),
562562
"date_format_str": shell.arg(
563563
type=str | None,
564564
help="format string for ",
565565
default=None,
566566
argstr="-D",
567567
requires=["long_format"],
568-
xor=["complete_date"],
568+
xor=["date_format_str", "complete_date", None],
569569
),
570570
},
571571
outputs={

pydra/design/workflow.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@ class arg(Arg):
3737
List of allowed values for the field.
3838
requires: list, optional
3939
Names of the inputs that are required together with the field.
40-
xor: list, optional
41-
Names of the inputs that are mutually exclusive with the field.
40+
xor: list[str | None], optional
41+
Names of args that are exclusive mutually exclusive, which must include
42+
the name of the current field. If this list includes None, then none of the
43+
fields need to be set.
4244
copy_mode: File.CopyMode, optional
4345
The mode of copying the file, by default it is File.CopyMode.any
4446
copy_collation: File.CopyCollation, optional

pydra/engine/specs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -503,13 +503,13 @@ def _rule_violations(self) -> list[str]:
503503

504504
# Collect alternative fields associated with this field.
505505
if field.xor:
506-
mutually_exclusive = {name: self[name] for name in field.xor}
506+
mutually_exclusive = {name: self[name] for name in field.xor if name}
507507
are_set = [f"{n}={v!r}" for n, v in mutually_exclusive.items() if v]
508508
if len(are_set) > 1:
509509
errors.append(
510510
f"Mutually exclusive fields ({', '.join(are_set)}) are set together"
511511
)
512-
elif not are_set:
512+
elif not are_set and None not in field.xor:
513513
errors.append(
514514
"At least one of the mutually exclusive fields should be set: "
515515
+ ", ".join(f"{n}={v!r}" for n, v in mutually_exclusive.items())
@@ -568,7 +568,7 @@ def _check_arg_refs(cls, inputs: list[Arg], outputs: list[Out]) -> None:
568568
f"of {field} " + str(list(unrecognised))
569569
)
570570
for inpt in inputs.values():
571-
if unrecognised := set(inpt.xor) - input_names:
571+
if unrecognised := inpt.xor - (input_names | {None}):
572572
raise ValueError(
573573
"'Unrecognised' field names in referenced in the xor "
574574
f"of {inpt} " + str(list(unrecognised))

pydra/engine/tests/test_helpers_file.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -393,13 +393,14 @@ def test_template_formatting(tmp_path: Path):
393393
field.name = "grad"
394394
field.argstr = "--grad"
395395
field.path_template = ("{in_file}.bvec", "{in_file}.bval")
396-
inputs = Mock()
397-
inputs_dict = {"in_file": "/a/b/c/file.txt", "grad": True}
396+
field.keep_extension = False
397+
definition = Mock()
398+
values = {"in_file": Path("/a/b/c/file.txt"), "grad": True}
398399

399400
assert template_update_single(
400401
field,
401-
inputs,
402-
input_values=inputs_dict,
402+
definition,
403+
values=values,
403404
output_dir=tmp_path,
404405
spec_type="input",
405406
) == [tmp_path / "file.bvec", tmp_path / "file.bval"]

pydra/engine/tests/test_shelltask_inputspec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1125,7 +1125,7 @@ class Outputs(ShellOutputs):
11251125
help="inpA",
11261126
argstr="",
11271127
)
1128-
inpStr: str = shell.arg(
1128+
inpStr: Path = shell.arg(
11291129
position=2,
11301130
help="inp str with extension",
11311131
argstr="-i",

pydra/engine/tests/test_task.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1449,11 +1449,11 @@ class Defn(ShellDef["Defn.Outputs"]):
14491449
class Outputs(ShellOutputs):
14501450
pass
14511451

1452-
inputs = Defn(a1_field="1", b2_field=2.0, c3_field={"c": "3"}, d4_field=["4"])
1452+
values = dict(a1_field="1", b2_field=2.0, c3_field={"c": "3"}, d4_field=["4"])
14531453
assert (
14541454
argstr_formatting(
14551455
"{a1_field} {b2_field:02f} -test {c3_field[c]} -me {d4_field[0]}",
1456-
inputs,
1456+
values,
14571457
)
14581458
== "1 2.000000 -test 3 -me 4"
14591459
)

0 commit comments

Comments
 (0)