diff --git a/.github/workflows/inference-mock.yml b/.github/workflows/inference-mock.yml index 8529543..845a97a 100644 --- a/.github/workflows/inference-mock.yml +++ b/.github/workflows/inference-mock.yml @@ -32,4 +32,4 @@ jobs: - name: Run Unit Tests working-directory: actions/inference-mock run: | - python -m unittest test/test.py + pytest test/test.py diff --git a/actions/inference-mock/app.py b/actions/inference-mock/app.py index c957e4a..1da62f8 100644 --- a/actions/inference-mock/app.py +++ b/actions/inference-mock/app.py @@ -1,15 +1,14 @@ # Standard -from dataclasses import dataclass import logging import pprint # Third Party from completions.completion import create_chat_completion +from config import Config from flask import Flask, request # type: ignore[import-not-found] from matching.matching import Matcher from werkzeug import exceptions # type: ignore[import-not-found] import click # type: ignore[import-not-found] -import yaml # Globals app = Flask(__name__) @@ -41,33 +40,9 @@ def completions(): return response -# config -@dataclass -class Config: - matches: list[dict] - port: int = 11434 - debug: bool = False - - -@click.command() -@click.option( - "-c", - "--config", - "config", - type=click.File(mode="r", encoding="utf-8"), - required=True, - help="yaml config file", -) -def start_server(config): - # get config - yaml_data = yaml.safe_load(config) - if not isinstance(yaml_data, dict): - raise ValueError(f"config file {config} must be a set of key-value pairs") - - conf = Config(**yaml_data) - +def start_server(config: Config): # configure logger - if conf.debug: + if config.debug: app.logger.setLevel(logging.DEBUG) app.logger.debug("debug mode enabled") else: @@ -75,11 +50,25 @@ def start_server(config): # create match strategy object global strategies # pylint: disable=global-statement - strategies = Matcher(conf.matches) + strategies = Matcher(config.matches) # init server - app.run(debug=conf.debug, port=conf.port) + app.run(debug=config.debug, port=config.port) + + +@click.command() +@click.option( + "-c", + "--config", + "config_file", + type=click.File(mode="r", encoding="utf-8"), + required=True, + help="yaml config file", +) +def start_server_cli(config_file): + config = Config.from_file(config_file) + start_server(config) if __name__ == "__main__": - start_server() # pylint: disable=no-value-for-parameter + start_server_cli() # pylint: disable=no-value-for-parameter diff --git a/actions/inference-mock/completions/completion.py b/actions/inference-mock/completions/completion.py index b238e03..76eee0a 100644 --- a/actions/inference-mock/completions/completion.py +++ b/actions/inference-mock/completions/completion.py @@ -6,7 +6,6 @@ import random -# TODO: use a library to return and validate completions so this doesn't need to be maintained def create_chat_completion(content: str, model: str = "gpt-3.5") -> dict: response = { "id": "chatcmpl-2nYZXNHxx1PeK1u8xXcE1Fqr1U6Ve", diff --git a/actions/inference-mock/config.py b/actions/inference-mock/config.py new file mode 100644 index 0000000..e5f8017 --- /dev/null +++ b/actions/inference-mock/config.py @@ -0,0 +1,35 @@ +# Standard +from dataclasses import dataclass + +# Third Party +import yaml + + +# config +@dataclass +class Config: + debug: bool + port: int + matches: list[dict] + + def __init__(self, matches: list[dict], port: int = 11434, debug: bool = False): + if not matches: + raise ValueError("matches must not be empty") + self.port = port + self.debug = debug + self.matches = matches + + @staticmethod + def from_file(config_file) -> "Config": + """ + Create a Server instance from a config file. + :param config_file: path to the config file + :return: Server instance + """ + yaml_data = yaml.safe_load(config_file) + if not isinstance(yaml_data, dict): + raise ValueError( + f"config file {config_file} must be a set of key-value pairs" + ) + + return Config(**yaml_data) diff --git a/actions/inference-mock/matching/matching.py b/actions/inference-mock/matching/matching.py index 38efc3d..b54f94c 100644 --- a/actions/inference-mock/matching/matching.py +++ b/actions/inference-mock/matching/matching.py @@ -1,10 +1,9 @@ # Standard -from abc import abstractmethod -from typing import Protocol +from abc import ABC, abstractmethod import pprint -class Match(Protocol): +class Match(ABC): """ Match represents a single prompt matching strategy. When a match is successful, @@ -15,10 +14,10 @@ class Match(Protocol): @abstractmethod def match(self, prompt: str) -> str | None: - raise NotImplementedError + pass -class Always: +class Always(Match): """ Always is a matching strategy that always is a positive match on a given prompt. @@ -36,7 +35,7 @@ def match(self, prompt: str) -> str | None: return None -class Contains: +class Contains(Match): """ Contains is a matching strategy that checks if the prompt string contains all of diff --git a/actions/inference-mock/requirements.txt b/actions/inference-mock/requirements.txt index 20d3485..6d68c11 100644 --- a/actions/inference-mock/requirements.txt +++ b/actions/inference-mock/requirements.txt @@ -2,3 +2,4 @@ flask werkzeug click pyyaml +pytest diff --git a/actions/inference-mock/test/test.py b/actions/inference-mock/test/test.py index 2157a67..6552d23 100644 --- a/actions/inference-mock/test/test.py +++ b/actions/inference-mock/test/test.py @@ -1,18 +1,16 @@ -# Standard -import unittest - # Third Party from matching.matching import Always, Contains, Matcher, to_match +import pytest -class TestAlways(unittest.TestCase): +class TestAlways: # match on any prompt def test_always(self): expect_response = "expected response" prompt = "example prompt" always = Always(expect_response) actual_response = always.match(prompt) - self.assertEqual(actual_response, expect_response) + assert actual_response == expect_response # reject empty prompts def test_always_empty_prompt(self): @@ -20,17 +18,17 @@ def test_always_empty_prompt(self): prompt = "" always = Always(response) actual_response = always.match(prompt) - self.assertIsNone(actual_response) + assert actual_response is None -class TestContains(unittest.TestCase): +class TestContains: def test_contains(self): expect_response = "expected response" prompt = "example prompt" match_on = ["example"] contains = Contains(match_on, expect_response) actual_response = contains.match(prompt) - self.assertEqual(actual_response, expect_response) + assert actual_response == expect_response def test_contains_many(self): expect_response = "expected response" @@ -38,7 +36,7 @@ def test_contains_many(self): match_on = ["example", "many substring elements", "match on"] contains = Contains(match_on, expect_response) actual_response = contains.match(prompt) - self.assertEqual(actual_response, expect_response) + assert actual_response == expect_response # if any substrings don't match, return None def test_contains_mismatch(self): @@ -47,7 +45,7 @@ def test_contains_mismatch(self): match_on = ["example", "many substring elements", "match on", "banana"] contains = Contains(match_on, response) actual_response = contains.match(prompt) - self.assertIsNone(actual_response) + assert actual_response is None # reject empty prompts def test_contains_empty(self): @@ -56,29 +54,30 @@ def test_contains_empty(self): match_on = ["example"] contains = Contains(match_on, response) actual_response = contains.match(prompt) - self.assertIsNone(actual_response) + assert actual_response is None -class TestMatcher(unittest.TestCase): +class TestMatcher: def test_to_contains(self): response = "I am a response" substr = ["a", "b", "c"] pattern = {"contains": substr, "response": response} contains = to_match(pattern) - self.assertIsInstance(contains, Contains) - self.assertEqual(contains.response, response) + assert isinstance(contains, Contains) + assert contains.response == response def test_to_always(self): response = "I am a response" always_pattern = {"response": response} always = to_match(always_pattern) - self.assertIsInstance(always, Always) - self.assertEqual(always.response, response) + assert isinstance(always, Always) + assert always.response == response def test_to_invalid(self): response = "I am a response" invalid_pattern = {"banana": "foo", "response": response} - self.assertRaises(Exception, to_match, invalid_pattern) + with pytest.raises(TypeError): + to_match(invalid_pattern) def test_find_match_contains(self): expect_response = "I am a response" @@ -88,7 +87,7 @@ def test_find_match_contains(self): prompt = "example prompt" actual_response = matcher.find_match(prompt) - self.assertEqual(actual_response, expect_response) + assert actual_response == expect_response def test_find_match_always(self): expect_response = "I am a response" @@ -97,7 +96,7 @@ def test_find_match_always(self): prompt = "example prompt" actual_response = matcher.find_match(prompt) - self.assertEqual(actual_response, expect_response) + assert actual_response == expect_response # test that order is preserved and responses fall back until a match or end of strategies def test_find_match_fallback(self): @@ -110,10 +109,6 @@ def test_find_match_fallback(self): ] matcher = Matcher(patterns) always_response = matcher.find_match(prompt="example prompt") - self.assertEqual(always_response, "this is the fallback response") + assert always_response == "this is the fallback response" contains_response = matcher.find_match(prompt="this is the fallback response") - self.assertEqual(contains_response, "a response you will not get") - - -if __name__ == "__main__": - unittest.main() + assert contains_response == "a response you will not get"