Skip to content

Commit d557ca0

Browse files
feat: did you know for variant chocie (#47)
Signed-off-by: Jinzhe Zeng <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c7f79a3 commit d557ca0

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

dargs/dargs.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
We also need to pay special attention to flat the keys of its choices.
1717
"""
1818

19+
import difflib
1920
import fnmatch
2021
import json
2122
import re
@@ -800,7 +801,12 @@ def get_choice(self, argdict: dict, path=None) -> "Argument":
800801
return self.choice_dict[self.choice_alias[tag]]
801802
else:
802803
raise ArgumentValueError(
803-
path, f"get invalid choice `{tag}` for flag key `{self.flag_name}`."
804+
path,
805+
f"get invalid choice `{tag}` for flag key `{self.flag_name}`."
806+
+ did_you_mean(
807+
tag,
808+
list(self.choice_dict.keys()) + list(self.choice_alias.keys()),
809+
),
804810
)
805811
elif self.optional:
806812
return self.choice_dict[self.default_tag]
@@ -1042,3 +1048,22 @@ def default(self, obj) -> Dict[str, Union[str, bool, List]]:
10421048
elif isinstance(obj, type):
10431049
return obj.__name__
10441050
return json.JSONEncoder.default(self, obj)
1051+
1052+
1053+
def did_you_mean(choice: str, choices: List[str]) -> str:
1054+
"""Get did you mean message.
1055+
1056+
Parameters
1057+
----------
1058+
choice : str
1059+
the user's wrong choice
1060+
choices : list[str]
1061+
all the choices
1062+
1063+
Returns
1064+
-------
1065+
str
1066+
did you mean error message
1067+
"""
1068+
matches = difflib.get_close_matches(choice, choices)
1069+
return f"Did you mean: {matches[0]}?" if matches else ""

tests/test_checker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,9 @@ def test_sub_variants(self):
249249
"vnt2_1": 21,
250250
}
251251
}
252-
with self.assertRaises(ArgumentValueError):
252+
with self.assertRaises(ArgumentValueError) as cm:
253253
ca.check(err_dict2)
254+
self.assertIn("Did you mean: type3?", str(cm.exception))
254255
# test optional choice
255256
test_dict1["base"].pop("vnt_flag")
256257
with self.assertRaises(ArgumentKeyError):

0 commit comments

Comments
 (0)