Skip to content

Commit 55258df

Browse files
committed
fix: clean up code and use pytest
Signed-off-by: Emilio Garcia <[email protected]>
1 parent eabaa28 commit 55258df

File tree

6 files changed

+81
-63
lines changed

6 files changed

+81
-63
lines changed

actions/inference-mock/app.py

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
# Standard
2-
from dataclasses import dataclass
32
import logging
43
import pprint
54

65
# Third Party
76
from completions.completion import create_chat_completion
7+
from config import Config
88
from flask import Flask, request # type: ignore[import-not-found]
99
from matching.matching import Matcher
1010
from werkzeug import exceptions # type: ignore[import-not-found]
1111
import click # type: ignore[import-not-found]
12-
import yaml
1312

1413
# Globals
1514
app = Flask(__name__)
@@ -41,45 +40,35 @@ def completions():
4140
return response
4241

4342

44-
# config
45-
@dataclass
46-
class Config:
47-
matches: list[dict]
48-
port: int = 11434
49-
debug: bool = False
50-
51-
52-
@click.command()
53-
@click.option(
54-
"-c",
55-
"--config",
56-
"config",
57-
type=click.File(mode="r", encoding="utf-8"),
58-
required=True,
59-
help="yaml config file",
60-
)
61-
def start_server(config):
62-
# get config
63-
yaml_data = yaml.safe_load(config)
64-
if not isinstance(yaml_data, dict):
65-
raise ValueError(f"config file {config} must be a set of key-value pairs")
66-
67-
conf = Config(**yaml_data)
68-
43+
def start_server(config: Config):
6944
# configure logger
70-
if conf.debug:
45+
if config.debug:
7146
app.logger.setLevel(logging.DEBUG)
7247
app.logger.debug("debug mode enabled")
7348
else:
7449
app.logger.setLevel(logging.INFO)
7550

7651
# create match strategy object
7752
global strategies # pylint: disable=global-statement
78-
strategies = Matcher(conf.matches)
53+
strategies = Matcher(config.matches)
7954

8055
# init server
81-
app.run(debug=conf.debug, port=conf.port)
56+
app.run(debug=config.debug, port=config.port)
57+
58+
59+
@click.command()
60+
@click.option(
61+
"-c",
62+
"--config",
63+
"config_file",
64+
type=click.File(mode="r", encoding="utf-8"),
65+
required=True,
66+
help="yaml config file",
67+
)
68+
def start_server_cli(config_file):
69+
config = Config.from_file(config_file)
70+
start_server(config)
8271

8372

8473
if __name__ == "__main__":
85-
start_server() # pylint: disable=no-value-for-parameter
74+
start_server_cli() # pylint: disable=no-value-for-parameter

actions/inference-mock/completions/completion.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import random
77

88

9-
# TODO: use a library to return and validate completions so this doesn't need to be maintained
109
def create_chat_completion(content: str, model: str = "gpt-3.5") -> dict:
1110
response = {
1211
"id": "chatcmpl-2nYZXNHxx1PeK1u8xXcE1Fqr1U6Ve",

actions/inference-mock/config.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Standard
2+
from dataclasses import dataclass
3+
4+
# Third Party
5+
import yaml
6+
7+
8+
# config
9+
@dataclass
10+
class Config:
11+
debug: bool
12+
port: int
13+
matches: list[dict]
14+
15+
def __init__(self, matches: list[dict], port: int = 11434, debug: bool = False):
16+
if not matches:
17+
raise ValueError("matches must not be empty")
18+
self.port = port
19+
self.debug = debug
20+
self.matches = matches
21+
22+
@staticmethod
23+
def from_file(config_file) -> "Config":
24+
"""
25+
Create a Server instance from a config file.
26+
:param config_file: path to the config file
27+
:return: Server instance
28+
"""
29+
yaml_data = yaml.safe_load(config_file)
30+
if not isinstance(yaml_data, dict):
31+
raise ValueError(
32+
f"config file {config_file} must be a set of key-value pairs"
33+
)
34+
35+
return Config(**yaml_data)

actions/inference-mock/matching/matching.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
# Standard
2-
from abc import abstractmethod
3-
from typing import Protocol
2+
from abc import ABC, abstractmethod
43
import pprint
54

65

7-
class Match(Protocol):
6+
class Match(ABC):
87
"""
98
Match represents a single prompt matching
109
strategy. When a match is successful,
@@ -15,10 +14,10 @@ class Match(Protocol):
1514

1615
@abstractmethod
1716
def match(self, prompt: str) -> str | None:
18-
raise NotImplementedError
17+
pass
1918

2019

21-
class Always:
20+
class Always(Match):
2221
"""
2322
Always is a matching strategy that always
2423
is a positive match on a given prompt.
@@ -36,7 +35,7 @@ def match(self, prompt: str) -> str | None:
3635
return None
3736

3837

39-
class Contains:
38+
class Contains(Match):
4039
"""
4140
Contains is a matching strategy that checks
4241
if the prompt string contains all of

actions/inference-mock/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ flask
22
werkzeug
33
click
44
pyyaml
5+
pytest
Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,42 @@
1-
# Standard
2-
import unittest
3-
41
# Third Party
52
from matching.matching import Always, Contains, Matcher, to_match
3+
import pytest
64

75

8-
class TestAlways(unittest.TestCase):
6+
class TestAlways:
97
# match on any prompt
108
def test_always(self):
119
expect_response = "expected response"
1210
prompt = "example prompt"
1311
always = Always(expect_response)
1412
actual_response = always.match(prompt)
15-
self.assertEqual(actual_response, expect_response)
13+
assert actual_response == expect_response
1614

1715
# reject empty prompts
1816
def test_always_empty_prompt(self):
1917
response = "expected response"
2018
prompt = ""
2119
always = Always(response)
2220
actual_response = always.match(prompt)
23-
self.assertIsNone(actual_response)
21+
assert actual_response is None
2422

2523

26-
class TestContains(unittest.TestCase):
24+
class TestContains:
2725
def test_contains(self):
2826
expect_response = "expected response"
2927
prompt = "example prompt"
3028
match_on = ["example"]
3129
contains = Contains(match_on, expect_response)
3230
actual_response = contains.match(prompt)
33-
self.assertEqual(actual_response, expect_response)
31+
assert actual_response == expect_response
3432

3533
def test_contains_many(self):
3634
expect_response = "expected response"
3735
prompt = "a much longer example prompt so we can match on many substring elements of this string"
3836
match_on = ["example", "many substring elements", "match on"]
3937
contains = Contains(match_on, expect_response)
4038
actual_response = contains.match(prompt)
41-
self.assertEqual(actual_response, expect_response)
39+
assert actual_response == expect_response
4240

4341
# if any substrings don't match, return None
4442
def test_contains_mismatch(self):
@@ -47,7 +45,7 @@ def test_contains_mismatch(self):
4745
match_on = ["example", "many substring elements", "match on", "banana"]
4846
contains = Contains(match_on, response)
4947
actual_response = contains.match(prompt)
50-
self.assertIsNone(actual_response)
48+
assert actual_response is None
5149

5250
# reject empty prompts
5351
def test_contains_empty(self):
@@ -56,29 +54,30 @@ def test_contains_empty(self):
5654
match_on = ["example"]
5755
contains = Contains(match_on, response)
5856
actual_response = contains.match(prompt)
59-
self.assertIsNone(actual_response)
57+
assert actual_response is None
6058

6159

62-
class TestMatcher(unittest.TestCase):
60+
class TestMatcher:
6361
def test_to_contains(self):
6462
response = "I am a response"
6563
substr = ["a", "b", "c"]
6664
pattern = {"contains": substr, "response": response}
6765
contains = to_match(pattern)
68-
self.assertIsInstance(contains, Contains)
69-
self.assertEqual(contains.response, response)
66+
assert isinstance(contains, Contains)
67+
assert contains.response == response
7068

7169
def test_to_always(self):
7270
response = "I am a response"
7371
always_pattern = {"response": response}
7472
always = to_match(always_pattern)
75-
self.assertIsInstance(always, Always)
76-
self.assertEqual(always.response, response)
73+
assert isinstance(always, Always)
74+
assert always.response == response
7775

7876
def test_to_invalid(self):
7977
response = "I am a response"
8078
invalid_pattern = {"banana": "foo", "response": response}
81-
self.assertRaises(Exception, to_match, invalid_pattern)
79+
with pytest.raises(TypeError):
80+
to_match(invalid_pattern)
8281

8382
def test_find_match_contains(self):
8483
expect_response = "I am a response"
@@ -88,7 +87,7 @@ def test_find_match_contains(self):
8887

8988
prompt = "example prompt"
9089
actual_response = matcher.find_match(prompt)
91-
self.assertEqual(actual_response, expect_response)
90+
assert actual_response == expect_response
9291

9392
def test_find_match_always(self):
9493
expect_response = "I am a response"
@@ -97,7 +96,7 @@ def test_find_match_always(self):
9796

9897
prompt = "example prompt"
9998
actual_response = matcher.find_match(prompt)
100-
self.assertEqual(actual_response, expect_response)
99+
assert actual_response == expect_response
101100

102101
# test that order is preserved and responses fall back until a match or end of strategies
103102
def test_find_match_fallback(self):
@@ -110,10 +109,6 @@ def test_find_match_fallback(self):
110109
]
111110
matcher = Matcher(patterns)
112111
always_response = matcher.find_match(prompt="example prompt")
113-
self.assertEqual(always_response, "this is the fallback response")
112+
assert always_response == "this is the fallback response"
114113
contains_response = matcher.find_match(prompt="this is the fallback response")
115-
self.assertEqual(contains_response, "a response you will not get")
116-
117-
118-
if __name__ == "__main__":
119-
unittest.main()
114+
assert contains_response == "a response you will not get"

0 commit comments

Comments
 (0)