Skip to content

Commit d32eae0

Browse files
committed
clean up code (with a bug left)
1 parent fe51315 commit d32eae0

File tree

3 files changed

+20
-9
lines changed

3 files changed

+20
-9
lines changed

dargs/dargs.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
r"""
22
Some (ocaml) pseudo-code here to show the intended type structure::
3-
3+
44
type args = {key: str; value: data; optional: bool; doc: str} list
55
and data =
66
| Arg of dtype
@@ -19,7 +19,7 @@
1919

2020

2121
from typing import Union, Any, List, Iterable, Optional, Callable
22-
from textwrap import wrap, fill, indent
22+
from textwrap import indent
2323
from copy import deepcopy
2424
from enum import Enum
2525
import fnmatch, re
@@ -30,6 +30,7 @@
3030
class _Flags(Enum): NONE = 0 # for no value in dict
3131

3232
class Argument:
33+
"""Define possible arguments and their types and properties."""
3334

3435
def __init__(self,
3536
name: str,
@@ -44,7 +45,9 @@ def __init__(self,
4445
doc: str = ""):
4546
self.name = name
4647
self.dtype = dtype
48+
assert sub_fields is None or all(isinstance(s, Argument) for s in sub_fields)
4749
self.sub_fields = sub_fields if sub_fields is not None else []
50+
assert sub_variants is None or all(isinstance(s, Variant) for s in sub_variants)
4851
self.sub_variants = sub_variants if sub_variants is not None else []
4952
self.repeat = repeat
5053
self.optional = optional
@@ -67,6 +70,9 @@ def __eq__(self, other: "Argument") -> bool:
6770
and self.repeat == other.repeat
6871
and self.optional == other.optional)
6972

73+
def __repr__(self) -> str:
74+
return f"<Argument {self.name}: {' | '.join(dd.__name__ for dd in self.dtype)}>"
75+
7076
def reorg_dtype(self):
7177
if isinstance(self.dtype, type) or self.dtype is None:
7278
self.dtype = [self.dtype]
@@ -145,7 +151,7 @@ def traverse_value(self, value: Any,
145151
key_hook, value_hook, sub_hook, variant_hook)
146152
self._traverse_subvariant(item,
147153
key_hook, value_hook, sub_hook, variant_hook)
148-
154+
149155
def _traverse_subfield(self, value: dict, *args, **kwargs):
150156
assert isinstance(value, dict)
151157
for subarg in self.sub_fields:
@@ -174,7 +180,7 @@ def check_value(self, argdict: dict, strict: bool = False):
174180
key_hook=Argument._check_exist,
175181
value_hook=Argument._check_value,
176182
sub_hook=Argument._check_strict if strict else DUMMYHOOK)
177-
183+
178184
def _check_exist(self, argdict: dict):
179185
if self.optional is True:
180186
return
@@ -198,7 +204,7 @@ def _check_strict(self, value: dict):
198204
if name not in allowed_set:
199205
raise KeyError(f"undefined key `{name}` is "
200206
"not allowed in strict mode")
201-
207+
202208
def _get_allowed_sub(self, value: dict) -> List[str]:
203209
allowed = [subarg.name for subarg in self.sub_fields]
204210
for subvrnt in self.sub_variants:
@@ -322,9 +328,10 @@ def gen_doc_body(self, paths: Optional[List[str]] = None, **kwargs) -> str:
322328
body_list.append(subvrnt.gen_doc(paths, **kwargs))
323329
body = "\n".join(body_list)
324330
return body
325-
331+
326332

327333
class Variant:
334+
"""Define multiple choices of possible argument sets."""
328335

329336
def __init__(self,
330337
flag_name: str,
@@ -349,6 +356,9 @@ def __eq__(self, other: "Variant") -> bool:
349356
and self.choice_dict == other.choice_dict
350357
and self.optional == other.optional
351358
and self.default_tag == other.default_tag)
359+
360+
def __repr__(self) -> str:
361+
return f"<Variant {self.flag_name} in {{ {', '.join(self.choice_dict.keys())} }}>"
352362

353363
def extend_choices(self, choices: Iterable["Argument"]):
354364
# choices is a list of arguments
@@ -398,7 +408,7 @@ def traverse(self, argdict: dict,
398408
variant_hook: Callable[["Variant", dict], None] = DUMMYHOOK):
399409
variant_hook(self, argdict)
400410
choice = self._load_choice(argdict)
401-
# here we use check_value to flatten the tag
411+
# here we use traverse_value to flatten the tag
402412
choice.traverse_value(argdict,
403413
key_hook, value_hook, sub_hook, variant_hook)
404414

@@ -430,7 +440,7 @@ def _convert_alias(self, argdict: dict):
430440

431441
# above are type checking part
432442
# below are doc generation part
433-
443+
434444
def gen_doc(self, paths: Optional[List[str]] = None, **kwargs) -> str:
435445
body_list = [""]
436446
body_list.append(f"Depending on the value of *{self.flag_name}*, "

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
url="https://github.com/deepmodeling/dargs",
2626
packages=['dargs'],
2727
classifiers=[
28-
"Programming Language :: Python :: 3.6",
28+
"Programming Language :: Python :: 3.7",
2929
"License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)",
3030
],
3131
install_requires=install_requires,

tests/test_checker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def test_sub_variants(self):
171171
with self.assertRaises(KeyError):
172172
ca.check(err_dict1)
173173
err_dict1["base"]["vnt_flag"] = "type1"
174+
ca.check(err_dict1, strict=True) # this should pass
174175
err_dict1["base"]["additional"] = "hahaha"
175176
ca.check(err_dict1) # now should pass
176177
with self.assertRaises(KeyError):

0 commit comments

Comments
 (0)