Skip to content

Commit a696e11

Browse files
authored
Merge pull request #320 from djarecka/enh/validator
adding a simple validator to the attrs class
2 parents 1d39820 + 892afc8 commit a696e11

File tree

10 files changed

+509
-66
lines changed

10 files changed

+509
-66
lines changed

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
attrs
1+
attrs >= 19.1.0
22
cloudpickle
33
filelock
44
git+https://github.com/AleksandarPetrov/napoleon.git@0dc3f28a309ad602be5f44a9049785a1026451b3#egg=sphinxcontrib-napoleon

pydra/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,13 @@ def check_latest_version():
3535

3636
if TaskBase._etelemetry_version_data is None:
3737
TaskBase._etelemetry_version_data = check_latest_version()
38+
39+
40+
# attr run_validators is set to False, but could be changed using use_validator
41+
import attr
42+
43+
attr.set_run_validators(False)
44+
45+
46+
def set_input_validator(flag=False):
47+
attr.set_run_validators(flag)

pydra/engine/helpers.py

Lines changed: 195 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@
1313
import re
1414
from time import strftime
1515
from traceback import format_exception
16+
import typing as ty
17+
import inspect
18+
import warnings
1619

1720

18-
from .specs import Runtime, File, Directory, attr_fields, Result
21+
from .specs import Runtime, File, Directory, attr_fields, Result, LazyField
1922
from .helpers_file import hash_file, hash_dir, copyfile, is_existing_file
2023

2124

@@ -234,8 +237,11 @@ def make_klass(spec):
234237
if len(item) == 2:
235238
if isinstance(item[1], attr._make._CountingAttr):
236239
newfields[item[0]] = item[1]
240+
newfields[item[0]].validator(custom_validator)
237241
else:
238-
newfields[item[0]] = attr.ib(type=item[1])
242+
newfields[item[0]] = attr.ib(
243+
type=item[1], validator=custom_validator
244+
)
239245
else:
240246
if (
241247
any([isinstance(ii, attr._make._CountingAttr) for ii in item])
@@ -251,17 +257,201 @@ def make_klass(spec):
251257
name, tp = item[:2]
252258
if isinstance(item[-1], dict) and "help_string" in item[-1]:
253259
mdata = item[-1]
254-
newfields[name] = attr.ib(type=tp, metadata=mdata)
260+
newfields[name] = attr.ib(
261+
type=tp, metadata=mdata, validator=custom_validator
262+
)
255263
else:
256264
dflt = item[-1]
257-
newfields[name] = attr.ib(type=tp, default=dflt)
265+
newfields[name] = attr.ib(
266+
type=tp, default=dflt, validator=custom_validator
267+
)
258268
elif len(item) == 4:
259269
name, tp, dflt, mdata = item
260-
newfields[name] = attr.ib(type=tp, default=dflt, metadata=mdata)
270+
newfields[name] = attr.ib(
271+
type=tp,
272+
default=dflt,
273+
metadata=mdata,
274+
validator=custom_validator,
275+
)
261276
fields = newfields
262277
return attr.make_class(spec.name, fields, bases=spec.bases, kw_only=True)
263278

264279

280+
def custom_validator(instance, attribute, value):
281+
"""simple custom validation
282+
take into account ty.Union, ty.List, ty.Dict (but only one level depth)
283+
adding an additional validator, if allowe_values provided
284+
"""
285+
validators = []
286+
tp_attr = attribute.type
287+
# a flag that could be changed to False, if the type is not recognized
288+
check_type = True
289+
if (
290+
value is attr.NOTHING
291+
or value is None
292+
or attribute.name.startswith("_") # e.g. _func
293+
or isinstance(value, LazyField)
294+
or tp_attr in [ty.Any, inspect._empty]
295+
):
296+
check_type = False # no checking of the type
297+
elif isinstance(tp_attr, type) or tp_attr in [File, Directory]:
298+
tp = _single_type_update(tp_attr, name=attribute.name)
299+
cont_type = None
300+
else: # more complex types
301+
cont_type, tp_attr_list = _check_special_type(tp_attr, name=attribute.name)
302+
if cont_type is ty.Union:
303+
tp, check_type = _types_updates(tp_attr_list, name=attribute.name)
304+
elif cont_type is list:
305+
tp, check_type = _types_updates(tp_attr_list, name=attribute.name)
306+
elif cont_type is dict:
307+
# assuming that it should have length of 2 for keys and values
308+
if len(tp_attr_list) != 2:
309+
check_type = False
310+
else:
311+
tp_attr_key, tp_attr_val = tp_attr_list
312+
# updating types separately for keys and values
313+
tp_k, check_k = _types_updates([tp_attr_key], name=attribute.name)
314+
tp_v, check_v = _types_updates([tp_attr_val], name=attribute.name)
315+
# assuming that I have to be able to check keys and values
316+
if not (check_k and check_v):
317+
check_type = False
318+
else:
319+
tp = {"key": tp_k, "val": tp_v}
320+
else:
321+
warnings.warn(
322+
f"no type check for {attribute.name} field, no type check implemented for value {value} and type {tp_attr}"
323+
)
324+
check_type = False
325+
326+
if check_type:
327+
validators.append(_type_validator(instance, attribute, value, tp, cont_type))
328+
329+
# checking additional requirements for values (e.g. allowed_values)
330+
meta_attr = attribute.metadata
331+
if "allowed_values" in meta_attr:
332+
validators.append(_allowed_values_validator(isinstance, attribute, value))
333+
return validators
334+
335+
336+
def _type_validator(instance, attribute, value, tp, cont_type):
337+
""" creating a customized type validator,
338+
uses validator.deep_iterable/mapping if the field is a container
339+
(i.e. ty.List or ty.Dict),
340+
it also tries to guess when the value is a list due to the splitter
341+
and validates the elements
342+
"""
343+
if cont_type is None or cont_type is ty.Union:
344+
# if tp is not (list,), we are assuming that the value is a list
345+
# due to the splitter, so checking the member types
346+
if isinstance(value, list) and tp != (list,):
347+
return attr.validators.deep_iterable(
348+
member_validator=attr.validators.instance_of(
349+
tp + (attr._make._Nothing,)
350+
)
351+
)(instance, attribute, value)
352+
else:
353+
return attr.validators.instance_of(tp + (attr._make._Nothing,))(
354+
instance, attribute, value
355+
)
356+
elif cont_type is list:
357+
return attr.validators.deep_iterable(
358+
member_validator=attr.validators.instance_of(tp + (attr._make._Nothing,))
359+
)(instance, attribute, value)
360+
elif cont_type is dict:
361+
return attr.validators.deep_mapping(
362+
key_validator=attr.validators.instance_of(tp["key"]),
363+
value_validator=attr.validators.instance_of(
364+
tp["val"] + (attr._make._Nothing,)
365+
),
366+
)(instance, attribute, value)
367+
else:
368+
raise Exception(
369+
f"container type of {attribute.name} should be None, list, dict or ty.Union, and not {cont_type}"
370+
)
371+
372+
373+
def _types_updates(tp_list, name):
374+
"""updating the type's tuple with possible additional types"""
375+
tp_upd_list = []
376+
check = True
377+
for tp_el in tp_list:
378+
tp_upd = _single_type_update(tp_el, name, simplify=True)
379+
if tp_upd is None:
380+
check = False
381+
break
382+
else:
383+
tp_upd_list += list(tp_upd)
384+
tp_upd = tuple(set(tp_upd_list))
385+
return tp_upd, check
386+
387+
388+
def _single_type_update(tp, name, simplify=False):
389+
""" updating a single type with other related types - e.g. adding bytes for str
390+
if simplify is True, than changing typing.List to list etc.
391+
(assuming that I validate only one depth, so have to simplify at some point)
392+
"""
393+
if isinstance(tp, type) or tp in [File, Directory]:
394+
if tp is str:
395+
return (str, bytes)
396+
elif tp in [File, Directory, os.PathLike]:
397+
return (os.PathLike, str)
398+
elif tp is float:
399+
return (float, int)
400+
else:
401+
return (tp,)
402+
elif simplify is True:
403+
warnings.warn(f"simplify validator for {name} field, checking only one depth")
404+
cont_tp, types_list = _check_special_type(tp, name=name)
405+
if cont_tp is list:
406+
return (list,)
407+
elif cont_tp is dict:
408+
return (dict,)
409+
elif cont_tp is ty.Union:
410+
return types_list
411+
else:
412+
warnings.warn(
413+
f"no type check for {name} field, type check not implemented for type of {tp}"
414+
)
415+
return None
416+
else:
417+
warnings.warn(
418+
f"no type check for {name} field, type check not implemented for type - {tp}, consider using simplify=True"
419+
)
420+
return None
421+
422+
423+
def _check_special_type(tp, name):
424+
"""checking if the type is a container: ty.List, ty.Dict or ty.Union """
425+
if sys.version_info.minor >= 8:
426+
return ty.get_origin(tp), ty.get_args(tp)
427+
else:
428+
if isinstance(tp, type): # simple type
429+
return None, ()
430+
else:
431+
if tp._name == "List":
432+
return list, tp.__args__
433+
elif tp._name == "Dict":
434+
return dict, tp.__args__
435+
elif tp.__origin__ is ty.Union:
436+
return ty.Union, tp.__args__
437+
else:
438+
warnings.warn(
439+
f"not type check for {name} field, type check not implemented for type {tp}"
440+
)
441+
return None, ()
442+
443+
444+
def _allowed_values_validator(instance, attribute, value):
445+
""" checking if the values is in allowed_values"""
446+
allowed = attribute.metadata["allowed_values"]
447+
if value is attr.NOTHING:
448+
pass
449+
elif value not in allowed:
450+
raise ValueError(
451+
f"value of {attribute.name} has to be from {allowed}, but {value} provided"
452+
)
453+
454+
265455
async def read_stream_and_display(stream, display):
266456
"""
267457
Read from stream line by line until EOF, display, and capture the lines.

pydra/engine/specs.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -293,26 +293,11 @@ def check_fields_input_spec(self):
293293
if required_notfound:
294294
raise AttributeError(f"{nm} requires {required_notfound}")
295295

296-
# TODO: types might be checked here
297-
self._type_checking()
298-
299296
def _file_check(self, field):
300297
file = Path(getattr(self, field.name))
301298
if not file.exists():
302299
raise AttributeError(f"the file from the {field.name} input does not exist")
303300

304-
def _type_checking(self):
305-
"""Use fld.type to check the types TODO.
306-
307-
This may be done through attr validators.
308-
309-
"""
310-
fields = attr_fields(self)
311-
allowed_keys = ["min_val", "max_val", "range", "enum"] # noqa
312-
for fld in fields:
313-
# TODO
314-
pass
315-
316301

317302
@attr.s(auto_attribs=True, kw_only=True)
318303
class ShellOutSpec(BaseSpec):

pydra/engine/tests/test_shelltask.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
from ..submitter import Submitter
1010
from ..core import Workflow
1111
from ..specs import ShellOutSpec, ShellSpec, SpecInfo, File
12-
from .utils import result_no_submitter, result_submitter
13-
12+
from .utils import result_no_submitter, result_submitter, use_validator
1413

1514
if sys.platform.startswith("win"):
1615
pytest.skip("SLURM not available in windows", allow_module_level=True)
@@ -251,7 +250,7 @@ def test_wf_shell_cmd_1(plugin):
251250

252251

253252
@pytest.mark.parametrize("results_function", [result_no_submitter, result_submitter])
254-
def test_shell_cmd_inputspec_1(plugin, results_function):
253+
def test_shell_cmd_inputspec_1(plugin, results_function, use_validator):
255254
""" a command with executable, args and one command opt,
256255
using a customized input_spec to add the opt to the command
257256
in the right place that is specified in metadata["cmd_pos"]
@@ -290,7 +289,7 @@ def test_shell_cmd_inputspec_1(plugin, results_function):
290289

291290

292291
@pytest.mark.parametrize("results_function", [result_no_submitter, result_submitter])
293-
def test_shell_cmd_inputspec_2(plugin, results_function):
292+
def test_shell_cmd_inputspec_2(plugin, results_function, use_validator):
294293
""" a command with executable, args and two command options,
295294
using a customized input_spec to add the opt to the command
296295
in the right place that is specified in metadata["cmd_pos"]
@@ -1513,6 +1512,50 @@ def test_shell_cmd_inputspec_state_1(plugin, results_function):
15131512
assert res[1].output.stdout == "hi\n"
15141513

15151514

1515+
def test_shell_cmd_inputspec_typeval_1(use_validator):
1516+
""" customized input_spec with a type that doesn't match the value
1517+
- raise an exception
1518+
"""
1519+
cmd_exec = "echo"
1520+
1521+
my_input_spec = SpecInfo(
1522+
name="Input",
1523+
fields=[
1524+
(
1525+
"text",
1526+
attr.ib(
1527+
type=int,
1528+
metadata={"position": 1, "argstr": "", "help_string": "text"},
1529+
),
1530+
)
1531+
],
1532+
bases=(ShellSpec,),
1533+
)
1534+
1535+
with pytest.raises(TypeError):
1536+
shelly = ShellCommandTask(
1537+
executable=cmd_exec, text="hello", input_spec=my_input_spec
1538+
)
1539+
1540+
1541+
def test_shell_cmd_inputspec_typeval_2(use_validator):
1542+
""" customized input_spec (shorter syntax) with a type that doesn't match the value
1543+
- raise an exception
1544+
"""
1545+
cmd_exec = "echo"
1546+
1547+
my_input_spec = SpecInfo(
1548+
name="Input",
1549+
fields=[("text", int, {"position": 1, "argstr": "", "help_string": "text"})],
1550+
bases=(ShellSpec,),
1551+
)
1552+
1553+
with pytest.raises(TypeError):
1554+
shelly = ShellCommandTask(
1555+
executable=cmd_exec, text="hello", input_spec=my_input_spec
1556+
)
1557+
1558+
15161559
@pytest.mark.parametrize("results_function", [result_no_submitter, result_submitter])
15171560
def test_shell_cmd_inputspec_state_1a(plugin, results_function):
15181561
""" adding state to the input from input_spec

0 commit comments

Comments
 (0)