Skip to content

Commit bcbca9c

Browse files
authored
do not add type of default if it is already in type (#40)
1 parent fb3cd72 commit bcbca9c

File tree

3 files changed

+28
-2
lines changed

3 files changed

+28
-2
lines changed

dargs/dargs.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,11 @@ def _reorg_dtype(self):
224224
# check conner cases
225225
if self.sub_fields or self.sub_variants:
226226
self.dtype.add(list if self.repeat else dict)
227-
if self.optional and self.default is not _Flags.NONE:
227+
if (
228+
self.optional
229+
and self.default is not _Flags.NONE
230+
and all([not isinstance_annotation(self.default, tt) for tt in self.dtype])
231+
):
228232
self.dtype.add(type(self.default))
229233
# and make it compatible with `isinstance`
230234
self.dtype = tuple(self.dtype)
@@ -968,6 +972,19 @@ def trim_by_pattern(
968972
argdict.pop(key)
969973

970974

975+
def isinstance_annotation(value, dtype) -> bool:
976+
"""Same as isinstance(), but supports arbitrary type annotations."""
977+
try:
978+
typeguard.check_type(
979+
value,
980+
dtype,
981+
collection_check_strategy=typeguard.CollectionCheckStrategy.ALL_ITEMS,
982+
)
983+
except typeguard.TypeCheckError as e:
984+
return False
985+
return True
986+
987+
971988
class ArgumentEncoder(json.JSONEncoder):
972989
"""Extended JSON Encoder to encode Argument object:
973990

dargs/sphinx.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,5 +192,11 @@ def _test_arguments() -> List[Argument]:
192192
return [
193193
Argument(name="test1", dtype=int, doc="Argument 1"),
194194
Argument(name="test2", dtype=[float, None], doc="Argument 2"),
195-
Argument(name="test3", dtype=List[str], doc="Argument 3"),
195+
Argument(
196+
name="test3",
197+
dtype=List[str],
198+
default=["test"],
199+
optional=True,
200+
doc="Argument 3",
201+
),
196202
]

tests/test_checker.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ def test_name_type(self):
3131
# list[int]
3232
ca = Argument("key1", List[float])
3333
ca.check({"key1": [1, 2.0, 3]})
34+
with self.assertRaises(ArgumentTypeError):
35+
ca.check({"key1": [1, 2.0, "3"]})
36+
ca = Argument("key1", List[float], default=[], optional=True)
3437
with self.assertRaises(ArgumentTypeError):
3538
ca.check({"key1": [1, 2.0, "3"]})
3639
# optional case

0 commit comments

Comments
 (0)