Skip to content

Commit 675c688

Browse files
committed
fix bug in error msg
1 parent 63540ca commit 675c688

File tree

1 file changed

+24
-19
lines changed

1 file changed

+24
-19
lines changed

dargs/dargs.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ def __init__(self,
8181
default: Any = _Flags.NONE,
8282
alias: Optional[Iterable[str]] = None,
8383
extra_check: Optional[Callable[[Any], bool]] = None,
84-
doc: str = ""):
84+
doc: str = "",
85+
fold_subdoc: bool = False):
8586
self.name = name
8687
self.dtype = dtype
8788
self.sub_fields : Dict[str, "Argument"] = {}
@@ -92,6 +93,7 @@ def __init__(self,
9293
self.alias = alias if alias is not None else []
9394
self.extra_check = extra_check
9495
self.doc = doc
96+
self.fold_subdoc = fold_subdoc
9597
# adding subfields and subvariants
9698
self.extend_subfields(sub_fields)
9799
self.extend_subvariants(sub_variants)
@@ -218,9 +220,12 @@ def traverse(self, argdict: dict,
218220
if path is None: path = []
219221
key_hook(self, argdict, path)
220222
if self.name in argdict:
223+
value = argdict[self.name]
224+
value_hook(self, value, path)
225+
newpath = [*path, self.name]
221226
# this is the key step that we traverse into the tree
222-
self.traverse_value(argdict[self.name],
223-
key_hook, value_hook, sub_hook, variant_hook, path)
227+
self.traverse_value(value,
228+
key_hook, value_hook, sub_hook, variant_hook, newpath)
224229

225230
def traverse_value(self, value: Any,
226231
key_hook: HookArgKType = _DUMMYHOOK,
@@ -231,7 +236,6 @@ def traverse_value(self, value: Any,
231236
# this is not private, and can be called directly
232237
# in the condition where there is no leading key
233238
if path is None: path = []
234-
value_hook(self, value, path)
235239
if isinstance(value, dict):
236240
self._traverse_sub(value,
237241
key_hook, value_hook, sub_hook, variant_hook, path)
@@ -247,14 +251,13 @@ def _traverse_sub(self, value: dict,
247251
variant_hook: HookVrntType = _DUMMYHOOK,
248252
path: Optional[List[str]] = None):
249253
assert isinstance(value, dict)
250-
if path is None: path = []
254+
if path is None: path = [self.name]
251255
sub_hook(self, value, path)
252256
for subvrnt in self.sub_variants.values():
253257
variant_hook(subvrnt, value, path)
254-
newpath = [*path, self.name]
255258
for subarg in self.flatten_sub(value, path).values():
256259
subarg.traverse(value,
257-
key_hook, value_hook, sub_hook, variant_hook, newpath)
260+
key_hook, value_hook, sub_hook, variant_hook, path)
258261

259262
# above are general traverse part
260263
# below are type checking part
@@ -297,6 +300,7 @@ def _check_value(self, value: Any, path=None):
297300

298301
def _check_strict(self, value: dict, path=None):
299302
allowed_keys = self.flatten_sub(value, path).keys()
303+
# curpath = [*path, self.name]
300304
for name in value.keys():
301305
if name not in allowed_keys:
302306
raise ArgumentKeyError(path,
@@ -408,18 +412,19 @@ def gen_doc_body(self, path: Optional[List[str]] = None, **kwargs) -> str:
408412
body_list = []
409413
if self.doc:
410414
body_list.append(self.doc + "\n")
411-
if self.repeat:
412-
body_list.append("This argument takes a list with "
413-
"each element containing the following: \n")
414-
if self.sub_fields:
415-
# body_list.append("") # genetate a blank line
416-
# body_list.append("This argument accept the following sub arguments:")
417-
for subarg in self.sub_fields.values():
418-
body_list.append(subarg.gen_doc(path, **kwargs))
419-
if self.sub_variants:
420-
showflag = len(self.sub_variants) > 1
421-
for subvrnt in self.sub_variants.values():
422-
body_list.append(subvrnt.gen_doc(path, showflag, **kwargs))
415+
if not self.fold_subdoc:
416+
if self.repeat:
417+
body_list.append("This argument takes a list with "
418+
"each element containing the following: \n")
419+
if self.sub_fields:
420+
# body_list.append("") # genetate a blank line
421+
# body_list.append("This argument accept the following sub arguments:")
422+
for subarg in self.sub_fields.values():
423+
body_list.append(subarg.gen_doc(path, **kwargs))
424+
if self.sub_variants:
425+
showflag = len(self.sub_variants) > 1
426+
for subvrnt in self.sub_variants.values():
427+
body_list.append(subvrnt.gen_doc(path, showflag, **kwargs))
423428
body = "\n".join(body_list)
424429
return body
425430

0 commit comments

Comments
 (0)