Skip to content

Commit 84ba019

Browse files
committed
Fix windows installation and add tests
1 parent 8e9b090 commit 84ba019

11 files changed

+497
-47
lines changed

askai/entrypoint_config.py

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from pathlib import Path
2+
13
import click
24

5+
from .constants import CONFIG_PATH
36
from .utils import ConfigHelper, PrintHelper
47

58

@@ -11,10 +14,15 @@ def config() -> None:
1114

1215
@config.command(help="Reset the config to default")
1316
def reset() -> None:
17+
_reset()
18+
19+
20+
def _reset(config_path: Path = CONFIG_PATH) -> None:
21+
"""Separate function for testing"""
1422
user_verification = input("Do you want to reset your config to the default values? [y/Y]? ")
1523

1624
if user_verification.lower() in ["y", "yes"]:
17-
ConfigHelper().reset()
25+
ConfigHelper().reset(config_path=config_path)
1826
else:
1927
click.echo("Config not reset. Aborted!")
2028

@@ -31,6 +39,11 @@ def update() -> None:
3139

3240
@update.command("all", help="Interface to update the full default config")
3341
def update_all() -> None:
42+
_update_all()
43+
44+
45+
def _update_all(config_path: Path = CONFIG_PATH) -> None:
46+
"""Separate function for testing"""
3447
config_helper = ConfigHelper()
3548
PrintHelper.update_config()
3649

@@ -62,60 +75,95 @@ def update_all() -> None:
6275
PrintHelper.presence_penalty()
6376
config_helper.input_presence_penalty()
6477

65-
config_helper.update()
78+
config_helper.update(config_path=config_path)
6679

6780

6881
@update.command(help="Update model")
6982
def model() -> None:
83+
_model()
84+
85+
86+
def _model(config_path: Path = CONFIG_PATH) -> None:
87+
"""Separate function for testing"""
7088
PrintHelper.model()
7189
config_helper = ConfigHelper.from_file()
7290
config_helper.input_model()
73-
config_helper.update()
91+
config_helper.update(config_path=config_path)
7492

7593

7694
@update.command(help="Update number of altenative answers generated per question")
7795
def num_answers() -> None:
96+
_num_answers()
97+
98+
99+
def _num_answers(config_path: Path = CONFIG_PATH) -> None:
100+
"""Separate function for testing"""
78101
PrintHelper.num_answers()
79102
config_helper = ConfigHelper.from_file()
80103
config_helper.input_num_answer()
81-
config_helper.update()
104+
config_helper.update(config_path=config_path)
82105

83106

84107
@update.command(help="Update maximum number of tokens")
85108
def max_tokens() -> None:
109+
_max_tokens()
110+
111+
112+
def _max_tokens(config_path: Path = CONFIG_PATH) -> None:
113+
"""Separate function for testing"""
86114
PrintHelper.max_tokens()
87115
config_helper = ConfigHelper().from_file()
88116
config_helper.input_max_token()
89-
config_helper.update()
117+
config_helper.update(config_path=config_path)
90118

91119

92120
@update.command(help="Update temperature")
93121
def temperature() -> None:
122+
_temperature()
123+
124+
125+
def _temperature(config_path: Path = CONFIG_PATH) -> None:
126+
"""Separate function for testing"""
94127
PrintHelper.temperature()
95128
config_helper = ConfigHelper().from_file()
96129
config_helper.input_temperature()
97-
config_helper.update()
130+
config_helper.update(config_path=config_path)
98131

99132

100133
@update.command(help="Update top_p")
101134
def top_p() -> None:
135+
_top_p()
136+
137+
138+
def _top_p(config_path: Path = CONFIG_PATH) -> None:
139+
"""Separate function for testing"""
102140
PrintHelper.top_p()
103141
config_helper = ConfigHelper().from_file()
104142
config_helper.input_top_p()
105-
config_helper.update()
143+
config_helper.update(config_path=config_path)
106144

107145

108146
@update.command(help="Update frequency penalty")
109147
def frequency_penalty() -> None:
148+
_frequency_penalty()
149+
150+
151+
def _frequency_penalty(config_path: Path = CONFIG_PATH) -> None:
152+
"""Separate function for testing"""
110153
PrintHelper.frequency_penalty()
111154
config_helper = ConfigHelper.from_file()
112155
config_helper.input_frequency_penalty()
113-
config_helper.update()
156+
config_helper.update(config_path=config_path)
114157

115158

116159
@update.command(help="Update presence penalty")
117160
def presence_penalty() -> None:
161+
_frequency_penalty()
162+
163+
164+
def _presence_penalty(config_path: Path = CONFIG_PATH) -> None:
165+
"""Separate function for testing"""
118166
PrintHelper.presence_penalty()
119167
config_helper = ConfigHelper.from_file()
120168
config_helper.input_presence_penalty()
121-
config_helper.update()
169+
config_helper.update(config_path=config_path)

askai/entrypoint_init.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,27 @@
1+
from pathlib import Path
2+
13
import click
24

5+
from .constants import CONFIG_PATH, API_KEY_PATH
36
from .utils import KeyHelper, ConfigHelper, PrintHelper
47

58

69
@click.command()
710
def init() -> None:
811
"""Initialize askai."""
9-
key_helper = KeyHelper()
10-
config_helper = ConfigHelper()
12+
_init()
13+
14+
15+
def _init(config_path: Path = CONFIG_PATH, api_key_path: Path = API_KEY_PATH) -> None:
16+
"""Separate function for testing"""
1117
PrintHelper.logo()
1218

19+
key_helper = KeyHelper()
1320
PrintHelper.key()
1421
key_helper.input()
15-
key_helper.save()
22+
key_helper.save(api_key_path=api_key_path)
1623

17-
config_helper.reset()
24+
config_helper = ConfigHelper()
25+
config_helper.reset(config_path=config_path)
1826

1927
click.echo("Initialization done!")

askai/entrypoint_key.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from pathlib import Path
2+
13
import click
24

35
from .utils import KeyHelper, PrintHelper
@@ -13,22 +15,32 @@ def key():
1315
@key.command()
1416
def add() -> None:
1517
"""Add API key"""
18+
_add()
19+
20+
21+
def _add(api_key_path: Path = API_KEY_PATH) -> None:
22+
"""Separate function for testing"""
1623
if API_KEY_PATH.is_file():
1724
PrintHelper.key_exists()
1825

1926
key_helper = KeyHelper()
2027
key_helper.input()
21-
key_helper.save()
28+
key_helper.save(api_key_path=api_key_path)
2229

2330

2431
@key.command()
2532
def remove() -> None:
2633
"""Remove your stored API key"""
34+
_remove()
35+
36+
37+
def _remove(api_key_path: Path = API_KEY_PATH) -> None:
38+
"""Separate function for testing"""
2739
if not API_KEY_PATH.is_file():
2840
PrintHelper.no_key()
2941
else:
3042
user_verification = input("Do you want to remove your API key? [y/Y]? ")
3143
if user_verification.lower() in ["y", "yes"]:
32-
KeyHelper().remove()
44+
KeyHelper().remove(api_key_path=api_key_path)
3345
else:
3446
click.echo("API key not removed.")

askai/utils.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class ConfigHelper:
5151
@classmethod
5252
def from_file(cls, config_path: Path = CONFIG_PATH) -> 'ConfigHelper':
5353
if config_path.is_file():
54-
with open(config_path, "r") as f:
54+
with open(config_path, "r", encoding="utf8") as f:
5555
config = yaml.safe_load(f)
5656
return cls(**config)
5757
else:
@@ -145,15 +145,15 @@ def as_dict(self) -> dict:
145145

146146
def update(self, config_path: Path = CONFIG_PATH) -> None:
147147
config = self.as_dict()
148-
with open(config_path, "w") as f:
148+
with open(config_path, "w", encoding="utf8") as f:
149149
yaml.dump(config, f)
150150

151151
click.echo(click.style("Config updated successfully!", fg="green"))
152152

153153
@staticmethod
154154
def reset(config_path: Path = CONFIG_PATH) -> None:
155155
config = ConfigHelper().as_dict() # Create config with default values
156-
with open(config_path, "w") as f:
156+
with open(config_path, "w", encoding="utf8") as f:
157157
yaml.dump(config, f)
158158

159159
click.echo("\nDefault config has been created with the following values:")
@@ -168,7 +168,7 @@ def show(config_path: Path = CONFIG_PATH) -> None:
168168
click.echo("No config exists. Please reset the config ('askai config reset') "
169169
"or see 'askai config --help'.\n")
170170
else:
171-
with open(config_path, "r") as f:
171+
with open(config_path, "r", encoding="utf8") as f:
172172
try:
173173
config = yaml.safe_load(f)
174174
for key, value in config.items():
@@ -229,9 +229,9 @@ def _input_float(default_value: float,
229229
exit(1)
230230

231231

232-
@dataclass
233232
class KeyHelper:
234-
api_key: str = ""
233+
def __init__(self):
234+
self._api_key: str = ""
235235

236236
def input(self) -> None:
237237
key = getpass("Enter API Key: ")
@@ -245,22 +245,26 @@ def input(self) -> None:
245245
key = getpass("Enter API Key: ")
246246
num_tries += 1
247247

248-
self.api_key = key
248+
self._api_key = key
249249

250-
def save(self) -> None:
251-
API_KEY_PATH.parent.mkdir(parents=True, exist_ok=True)
252-
API_KEY_PATH.write_text(self.api_key)
250+
def save(self, api_key_path: Path = API_KEY_PATH) -> None:
251+
api_key_path.parent.mkdir(parents=True, exist_ok=True)
252+
api_key_path.write_text(self._api_key, encoding="utf8")
253253
click.echo(click.style("Your API key has been successfully added!", fg="green"))
254254

255255
@staticmethod
256-
def remove() -> None:
257-
API_KEY_PATH.unlink()
258-
click.echo(click.style("API key removed.", fg="green"))
256+
def remove(api_key_path: Path = API_KEY_PATH) -> None:
257+
try:
258+
api_key_path.unlink()
259+
click.echo(click.style("API key removed.", fg="green"))
260+
except FileNotFoundError:
261+
click.echo(click.style("No API key found.", fg="red"))
262+
exit()
259263

260-
@classmethod
261-
def from_file(cls) -> str:
262-
if API_KEY_PATH.is_file():
263-
with open(API_KEY_PATH, "r") as f:
264+
@staticmethod
265+
def from_file(api_key_path: Path = API_KEY_PATH) -> str:
266+
if api_key_path.is_file():
267+
with open(api_key_path, "r", encoding="utf8") as f:
264268
api_key = f.read().strip()
265269
return api_key
266270
else:

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
version="1.0.3",
77
author="Max Fischer",
88
description="Your simple terminal helper",
9-
long_description=open('README.md').read(),
9+
long_description=open("README.md", encoding="utf8").read(),
1010
long_description_content_type="text/markdown",
1111
license="MIT",
1212
url="https://github.com/maxvfischer/askai",

tests/test_auxiliary.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import pytest
2+
3+
from askai.utils import _is_int, _is_float
4+
5+
6+
@pytest.mark.parametrize(
7+
"value, is_int",
8+
[
9+
("1", True),
10+
("a", False),
11+
("0.1", False),
12+
]
13+
)
14+
def test_is_int(value: str, is_int: bool) -> None:
15+
assert _is_int(value) == is_int
16+
17+
18+
@pytest.mark.parametrize(
19+
"value, is_float",
20+
[
21+
("1", True),
22+
("1.0", True),
23+
("a", False),
24+
]
25+
)
26+
def test_is_float(value: str, is_float: bool) -> None:
27+
assert _is_float(value) == is_float

0 commit comments

Comments
 (0)