Skip to content

Commit e9c749f

Browse files
committed
Add tests for available models and config helper
1 parent bdee6c3 commit e9c749f

File tree

3 files changed

+348
-34
lines changed

3 files changed

+348
-34
lines changed

askai/utils.py

Lines changed: 66 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from pathlib import Path
2+
13
import yaml
24
import click
35
import openai
@@ -47,22 +49,22 @@ class ConfigHelper:
4749
presence_penalty: float = DEFAULT_PRESENCE_PENALTY
4850

4951
@classmethod
50-
def from_file(cls) -> 'ConfigHelper':
51-
if CONFIG_PATH.is_file():
52-
with open(CONFIG_PATH, "r") as f:
52+
def from_file(cls, config_path: Path = CONFIG_PATH) -> 'ConfigHelper':
53+
if config_path.is_file():
54+
with open(config_path, "r") as f:
5355
config = yaml.safe_load(f)
5456
return cls(**config)
5557
else:
5658
click.echo(click.style("No config file found, can't initialize config. "
5759
"Run 'askai config reset' to create a default config.", fg="red"))
5860
exit()
5961

60-
def input_model(self) -> None:
62+
def input_model(self, max_input_tries: int = MAX_INPUT_TRIES) -> None:
6163
model = input("Choose model (1-4): ")
6264
num_of_tries = 1
6365

6466
while not _is_int(model) or int(model) not in range(1, 5):
65-
if num_of_tries >= MAX_INPUT_TRIES:
67+
if num_of_tries >= max_input_tries:
6668
click.echo(click.style("Too many invalid tries. Aborted!", fg="red"))
6769
exit(1)
6870

@@ -74,56 +76,84 @@ def input_model(self) -> None:
7476
click.echo(click.style(f"Model chosen: {self.model}", fg="green"))
7577
click.echo()
7678

77-
def input_num_answer(self) -> None:
79+
def input_num_answer(self,
80+
default_value: int = DEFAULT_NUM_ANSWERS,
81+
min_value: int = OPENAI_NUM_ANSWERS_MIN,
82+
max_input_tries: int = MAX_INPUT_TRIES) -> None:
7883
self.num_answers = self._input_integer(
79-
default_value=DEFAULT_NUM_ANSWERS,
80-
predicate=lambda x: x >= OPENAI_NUM_ANSWERS_MIN
84+
default_value=default_value,
85+
predicate=lambda x: x >= min_value,
86+
max_input_tries=max_input_tries
8187
)
8288

83-
def input_max_token(self) -> None:
89+
def input_max_token(self,
90+
default_value: int = DEFAULT_MAX_TOKENS,
91+
min_value: int = OPENAI_MAX_TOKENS_MIN,
92+
max_input_tries: int = MAX_INPUT_TRIES) -> None:
8493
self.max_tokens = self._input_integer(
85-
default_value=DEFAULT_MAX_TOKENS,
86-
predicate=lambda x: x >= OPENAI_MAX_TOKENS_MIN
94+
default_value=default_value,
95+
predicate=lambda x: x >= min_value,
96+
max_input_tries=max_input_tries
8797
)
8898

89-
def input_temperature(self) -> None:
99+
def input_temperature(self,
100+
default_value: float = DEFAULT_TEMPERATURE,
101+
min_value: float = OPENAI_TEMPERATURE_MIN,
102+
max_value: float = OPENAI_TEMPERATURE_MAX,
103+
max_input_tries: int = MAX_INPUT_TRIES) -> None:
90104
self.temperature = self._input_float(
91-
default_value=DEFAULT_TEMPERATURE,
92-
predicate=lambda x: OPENAI_TEMPERATURE_MIN <= x <= OPENAI_TEMPERATURE_MAX
105+
default_value=default_value,
106+
predicate=lambda x: min_value <= x <= max_value,
107+
max_input_tries=max_input_tries
93108
)
94109

95-
def input_top_p(self) -> None:
110+
def input_top_p(self,
111+
default_value: float = DEFAULT_TOP_P,
112+
min_value: float = OPENAI_TOP_P_MIN,
113+
max_value: float = OPENAI_TOP_P_MAX,
114+
max_input_tries: int = MAX_INPUT_TRIES) -> None:
96115
self.top_p = self._input_float(
97-
default_value=DEFAULT_TOP_P,
98-
predicate=lambda x: OPENAI_TOP_P_MIN <= x <= OPENAI_TOP_P_MAX
116+
default_value=default_value,
117+
predicate=lambda x: min_value <= x <= max_value,
118+
max_input_tries=max_input_tries
99119
)
100120

101-
def input_frequency_penalty(self) -> None:
121+
def input_frequency_penalty(self,
122+
default_value: float = DEFAULT_FREQUENCY_PENALTY,
123+
min_value: float = OPENAI_FREQUENCY_PENALTY_MIN,
124+
max_value: float = OPENAI_FREQUENCY_PENALTY_MAX,
125+
max_input_tries: int = MAX_INPUT_TRIES) -> None:
102126
self.frequency_penalty = self._input_float(
103-
default_value=DEFAULT_FREQUENCY_PENALTY,
104-
predicate=lambda x: OPENAI_FREQUENCY_PENALTY_MIN <= x <= OPENAI_FREQUENCY_PENALTY_MAX
127+
default_value=default_value,
128+
predicate=lambda x: min_value <= x <= max_value,
129+
max_input_tries=max_input_tries
105130
)
106131

107-
def input_presence_penalty(self) -> None:
132+
def input_presence_penalty(self,
133+
default_value: float = DEFAULT_PRESENCE_PENALTY,
134+
min_value: float = OPENAI_PRESENCE_PENALTY_MIN,
135+
max_value: float = OPENAI_PRESENCE_PENALTY_MAX,
136+
max_input_tries: int = MAX_INPUT_TRIES) -> None:
108137
self.presence_penalty = self._input_float(
109-
default_value=DEFAULT_PRESENCE_PENALTY,
110-
predicate=lambda x: OPENAI_PRESENCE_PENALTY_MIN <= x <= OPENAI_PRESENCE_PENALTY_MAX
138+
default_value=default_value,
139+
predicate=lambda x: min_value <= x <= max_value,
140+
max_input_tries=max_input_tries
111141
)
112142

113143
def as_dict(self) -> dict:
114144
return asdict(self)
115145

116-
def update(self) -> None:
146+
def update(self, config_path: Path = CONFIG_PATH) -> None:
117147
config = self.as_dict()
118-
with open(CONFIG_PATH, "w") as f:
148+
with open(config_path, "w") as f:
119149
yaml.dump(config, f)
120150

121151
click.echo(click.style("Config updated successfully!", fg="green"))
122152

123153
@staticmethod
124-
def reset() -> None:
154+
def reset(config_path: Path = CONFIG_PATH) -> None:
125155
config = ConfigHelper().as_dict() # Create config with default values
126-
with open(CONFIG_PATH, "w") as f:
156+
with open(config_path, "w") as f:
127157
yaml.dump(config, f)
128158

129159
click.echo("\nDefault config has been created with the following values:")
@@ -133,12 +163,12 @@ def reset() -> None:
133163
click.echo("To change the config, please see: 'askai config --help'\n")
134164

135165
@staticmethod
136-
def show() -> None:
137-
if not CONFIG_PATH.is_file():
166+
def show(config_path: Path = CONFIG_PATH) -> None:
167+
if not config_path.is_file():
138168
click.echo("No config exists. Please reset the config ('askai config reset') "
139169
"or see 'askai config --help'.\n")
140170
else:
141-
with open(CONFIG_PATH, "r") as f:
171+
with open(config_path, "r") as f:
142172
try:
143173
config = yaml.safe_load(f)
144174
for key, value in config.items():
@@ -148,8 +178,9 @@ def show() -> None:
148178

149179
@staticmethod
150180
def _input_integer(default_value: int,
151-
predicate: Callable[[int], bool] = lambda x: True) -> int:
152-
for _ in range(MAX_INPUT_TRIES):
181+
predicate: Callable[[int], bool] = lambda x: True,
182+
max_input_tries: int = MAX_INPUT_TRIES) -> int:
183+
for _ in range(max_input_tries):
153184
input_value = input(f"Choose (press enter for default = {default_value}): ")
154185

155186
if input_value == "":
@@ -173,8 +204,9 @@ def _input_integer(default_value: int,
173204

174205
@staticmethod
175206
def _input_float(default_value: float,
176-
predicate: Callable[[float], bool] = lambda x: True) -> float:
177-
for _ in range(MAX_INPUT_TRIES):
207+
predicate: Callable[[float], bool] = lambda x: True,
208+
max_input_tries: int = MAX_INPUT_TRIES) -> float:
209+
for _ in range(max_input_tries):
178210
input_value = input(f"Choose (press enter for default = {default_value}): ")
179211

180212
if input_value == "":

tests/test_available_models.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from askai.utils import AvailableModels
2+
3+
4+
def test_name_values() -> None:
5+
assert "TEXT_ADA_001" in AvailableModels.__members__
6+
assert "TEXT_BABBAGE_001" in AvailableModels.__members__
7+
assert "TEXT_CURIE_001" in AvailableModels.__members__
8+
assert "TEXT_DAVINCI_003" in AvailableModels.__members__
9+
10+
11+
def test_members_as_list() -> None:
12+
enum_list = AvailableModels.members_as_list(openai_style=False)
13+
assert type(enum_list) == list
14+
assert len(AvailableModels) == len(enum_list)
15+
for name in AvailableModels.__members__:
16+
assert name in enum_list
17+
18+
19+
def test_members_as_list_openai_style() -> None:
20+
enum_list = AvailableModels.members_as_list(openai_style=True)
21+
assert type(enum_list) == list
22+
assert len(AvailableModels) == len(enum_list)
23+
for name in AvailableModels.__members__:
24+
assert name.replace("_", "-").lower() in enum_list
25+

0 commit comments

Comments
 (0)