@@ -62,6 +62,13 @@ def convert_default_value(value: ty.Any, self_: "Field") -> ty.Any:
6262 return TypeParser [self_ .type ](self_ .type , label = self_ .name )(value )
6363
6464
65+ def allowed_values_converter (value : ty .Iterable [str ] | None ) -> list [str ] | None :
66+ """Ensure the allowed_values field is a list of strings or None"""
67+ if value is None :
68+ return None
69+ return list (value )
70+
71+
6572@attrs .define
6673class Requirement :
6774 """Define a requirement for a task input field
@@ -76,14 +83,19 @@ class Requirement:
7683 """
7784
7885 name : str
79- allowed_values : list [str ] = attrs .field (factory = list , converter = list )
86+ allowed_values : list [str ] | None = attrs .field (
87+ default = None , converter = allowed_values_converter
88+ )
8089
8190 def satisfied (self , inputs : "TaskDef" ) -> bool :
8291 """Check if the requirement is satisfied by the inputs"""
8392 value = getattr (inputs , self .name )
84- if value is attrs .NOTHING :
93+ field = {f .name : f for f in list_fields (inputs )}[self .name ]
94+ if value is attrs .NOTHING or field .type is bool and value is False :
8595 return False
86- return not self .allowed_values or value in self .allowed_values
96+ if self .allowed_values is None :
97+ return True
98+ return value in self .allowed_values
8799
88100 @classmethod
89101 def parse (cls , value : ty .Any ) -> Self :
@@ -350,8 +362,8 @@ def get_fields(klass, field_type, auto_attribs, helps) -> dict[str, Field]:
350362
351363 if not issubclass (klass , spec_type ):
352364 raise ValueError (
353- f"The canonical form of { spec_type .__module__ .split ('.' )[- 1 ]} task definitions, "
354- f"{ klass } , must inherit from { spec_type } "
365+ f"When using the canonical form for { spec_type .__module__ .split ('.' )[- 1 ]} "
366+ f"tasks, { klass } must inherit from { spec_type } "
355367 )
356368
357369 inputs = get_fields (klass , arg_type , auto_attribs , input_helps )
@@ -364,8 +376,8 @@ def get_fields(klass, field_type, auto_attribs, helps) -> dict[str, Field]:
364376 ) from None
365377 if not issubclass (outputs_klass , outputs_type ):
366378 raise ValueError (
367- f"The canonical form of { spec_type .__module__ .split ('.' )[- 1 ]} task definitions, "
368- f"{ klass } , must inherit from { spec_type } "
379+ f"When using the canonical form for { outputs_type .__module__ .split ('.' )[- 1 ]} "
380+ f"task outputs { outputs_klass } , you must inherit from { outputs_type } "
369381 )
370382
371383 output_helps , _ = parse_doc_string (outputs_klass .__doc__ )
@@ -416,10 +428,12 @@ def make_task_def(
416428
417429 spec_type ._check_arg_refs (inputs , outputs )
418430
431+ # Check that the field attributes are valid after all fields have been set
432+ # (especially the type)
419433 for inpt in inputs .values ():
420- set_none_default_if_optional (inpt )
421- for outpt in inputs .values ():
422- set_none_default_if_optional (outpt )
434+ attrs . validate (inpt )
435+ for outpt in outputs .values ():
436+ attrs . validate (outpt )
423437
424438 if name is None and klass is not None :
425439 name = klass .__name__
@@ -459,10 +473,10 @@ def make_task_def(
459473 if getattr (arg , "path_template" , False ):
460474 if is_optional (arg .type ):
461475 field_type = Path | bool | None
462- # Will default to None and not be inserted into the command
476+ attrs_kwargs = { " default" : None }
463477 else :
464478 field_type = Path | bool
465- attrs_kwargs = {"default" : True }
479+ attrs_kwargs = {"default" : True } # use the template by default
466480 elif is_optional (arg .type ):
467481 field_type = Path | None
468482 else :
@@ -988,12 +1002,10 @@ def check_explicit_fields_are_none(klass, inputs, outputs):
9881002
9891003def _get_attrs_kwargs (field : Field ) -> dict [str , ty .Any ]:
9901004 kwargs = {}
991- if not hasattr (field , "default" ):
992- kwargs ["factory" ] = nothing_factory
993- elif field .default is not NO_DEFAULT :
1005+ if field .default is not NO_DEFAULT :
9941006 kwargs ["default" ] = field .default
995- elif is_optional (field .type ):
996- kwargs ["default" ] = None
1007+ # elif is_optional(field.type):
1008+ # kwargs["default"] = None
9971009 else :
9981010 kwargs ["factory" ] = nothing_factory
9991011 if field .hash_eq :
@@ -1005,9 +1017,9 @@ def nothing_factory():
10051017 return attrs .NOTHING
10061018
10071019
1008- def set_none_default_if_optional (field : Field ) -> None :
1009- if is_optional (field .type ) and field .default is NO_DEFAULT :
1010- field .default = None
1020+ # def set_none_default_if_optional(field: Field) -> None:
1021+ # if is_optional(field.type) and field.default is NO_DEFAULT:
1022+ # field.default = None
10111023
10121024
10131025white_space_re = re .compile (r"\s+" )
0 commit comments