@@ -173,6 +173,11 @@ def reset(self):
173173class TaskDef (ty .Generic [OutputsType ]):
174174 """Base class for all task definitions"""
175175
176+ # Class attributes
177+ _xor : frozenset [frozenset [str | None ]] = (
178+ frozenset ()
179+ ) # overwritten in derived classes
180+
176181 # The following fields are used to store split/combine state information
177182 _splitter = attrs .field (default = None , init = False , repr = False )
178183 _combiner = attrs .field (default = None , init = False , repr = False )
@@ -501,20 +506,6 @@ def _rule_violations(self) -> list[str]:
501506 ):
502507 errors .append (f"Mandatory field { field .name !r} is not set" )
503508
504- # Collect alternative fields associated with this field.
505- if field .xor :
506- mutually_exclusive = {name : self [name ] for name in field .xor if name }
507- are_set = [f"{ n } ={ v !r} " for n , v in mutually_exclusive .items () if v ]
508- if len (are_set ) > 1 :
509- errors .append (
510- f"Mutually exclusive fields ({ ', ' .join (are_set )} ) are set together"
511- )
512- elif not are_set and None not in field .xor :
513- errors .append (
514- "At least one of the mutually exclusive fields should be set: "
515- + ", " .join (f"{ n } ={ v !r} " for n , v in mutually_exclusive .items ())
516- )
517-
518509 # Raise error if any required field is unset.
519510 if (
520511 not (
@@ -538,6 +529,19 @@ def _rule_violations(self) -> list[str]:
538529 errors .append (
539530 f"{ field .name !r} requires{ qualification } { [str (r ) for r in field .requires ]} "
540531 )
532+ # Collect alternative fields associated with this field.
533+ for xor_set in self ._xor :
534+ mutually_exclusive = {name : self [name ] for name in xor_set if name }
535+ are_set = [f"{ n } ={ v !r} " for n , v in mutually_exclusive .items () if v ]
536+ if len (are_set ) > 1 :
537+ errors .append (
538+ f"Mutually exclusive fields ({ ', ' .join (are_set )} ) are set together"
539+ )
540+ elif not are_set and None not in xor_set :
541+ errors .append (
542+ "At least one of the mutually exclusive fields should be set: "
543+ + ", " .join (f"{ n } ={ v !r} " for n , v in mutually_exclusive .items ())
544+ )
541545 return errors
542546
543547 def _check_rules (self ):
@@ -552,7 +556,12 @@ def _check_rules(self):
552556 )
553557
554558 @classmethod
555- def _check_arg_refs (cls , inputs : list [Arg ], outputs : list [Out ]) -> None :
559+ def _check_arg_refs (
560+ cls ,
561+ inputs : list [Arg ],
562+ outputs : list [Out ],
563+ xor : frozenset [frozenset [str | None ]],
564+ ) -> None :
556565 """
557566 Checks if all fields referenced in requirements and xor are present in the inputs
558567 are valid field names
@@ -567,12 +576,22 @@ def _check_arg_refs(cls, inputs: list[Arg], outputs: list[Out]) -> None:
567576 "'Unrecognised' field names in referenced in the requirements "
568577 f"of { field } " + str (list (unrecognised ))
569578 )
570- for inpt in inputs .values ():
571- if unrecognised := inpt .xor - (input_names | {None }):
579+
580+ for xor_set in xor :
581+ if unrecognised := xor_set - (input_names | {None }):
572582 raise ValueError (
573- "'Unrecognised' field names in referenced in the xor "
574- f"of { inpt } " + str (list (unrecognised ))
583+ f "'Unrecognised' field names in referenced in the xor { xor_set } "
584+ + str (list (unrecognised ))
575585 )
586+ for field_name in xor_set :
587+ if field_name is None : # i.e. none of the fields being set is valid
588+ continue
589+ type_ = inputs [field_name ].type
590+ if type_ not in (ty .Any , bool ) and not is_optional (type_ ):
591+ raise ValueError (
592+ f"Fields included in a 'xor' ({ field .name !r} ) must be of boolean "
593+ f"or optional types, not type { type_ } "
594+ )
576595
577596 def _check_resolved (self ):
578597 """Checks that all the fields in the definition have been resolved"""
@@ -762,6 +781,8 @@ def _from_task(cls, task: "Task[WorkflowDef]") -> Self:
762781 for name , lazy_field in attrs_values (workflow .outputs ).items ():
763782 try :
764783 val_out = lazy_field ._get_value (workflow = workflow , graph = exec_graph )
784+ if isinstance (val_out , StateArray ):
785+ val_out = list (val_out ) # implicitly combine state arrays
765786 output_wf [name ] = val_out
766787 except (ValueError , AttributeError ):
767788 output_wf [name ] = None
0 commit comments