Skip to content

Commit 1eccbef

Browse files
committed
Cleanup role validation and block unknown roles in read, thanks @RobGallo!
1 parent 6ef0ec4 commit 1eccbef

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

src/gptcmd/cli.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,15 @@ def _confirm(prompt: str) -> bool:
132132
def _complete_from_key(d: Dict, text: str) -> List[str]:
133133
return [k for k, v in d.items() if k.startswith(text)]
134134

135-
@staticmethod
136-
def _complete_role(text: str) -> List[str]:
137-
ROLES = ("user", "assistant", "system")
138-
return [role for role in ROLES if role.startswith(text)]
135+
KNOWN_ROLES = ("user", "assistant", "system")
136+
137+
@classmethod
138+
def _complete_role(cls, text: str) -> List[str]:
139+
return [role for role in cls.ROLES if role.startswith(text)]
140+
141+
@classmethod
142+
def _validate_role(cls, role: str) -> bool:
143+
return role in cls.KNOWN_ROLES
139144

140145
def emptyline(self):
141146
"Disable Python cmd's repeat last command behaviour."
@@ -633,15 +638,16 @@ def do_rename(self, arg):
633638
"""
634639
m = re.match(
635640
(
636-
r"^(user|assistant|system)\s+"
641+
f"^({'|'.join(self.__class__.KNOWN_ROLES)})\s+"
637642
r"((?:-?\d+|\.)(?:\s+-?\d+|\s*\.)*)"
638643
r"(?:\s+([a-zA-Z0-9_-]{1,64}))?$"
639644
),
640645
arg,
641646
)
642647
if not m:
643648
print(
644-
"Usage: rename <user|assistant|system> <message range> [name]"
649+
f"Usage: rename <{'|'.join(self.__class__.KNOWN_ROLES)}>"
650+
" <message range> [name]"
645651
)
646652
return
647653
role, ref, name = m.groups()
@@ -775,10 +781,17 @@ def do_read(self, arg):
775781
"""
776782
args = arg.split()
777783
if len(args) < 2:
778-
print("Usage: read <path> <user|assistant|system>")
784+
print(
785+
"Usage: read <path> <{'|'.join(self.__class__.KNOWN_ROLES)}>"
786+
)
779787
return
780788
path = " ".join(args[:-1])
781789
role = args[-1]
790+
if not self.__class__._validate_role(role):
791+
print(
792+
f"Usage: read <path> <{'|'.join(self.__class__.KNOWN_ROLES)}>"
793+
)
794+
return
782795
try:
783796
with open(path, encoding="utf-8", errors="ignore") as fin:
784797
self._current_thread.append(

0 commit comments

Comments
 (0)