Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/inference-mock.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ jobs:
- name: Run Unit Tests
working-directory: actions/inference-mock
run: |
python -m unittest test/test.py
pytest test/test.py
51 changes: 20 additions & 31 deletions actions/inference-mock/app.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
# Standard
from dataclasses import dataclass
import logging
import pprint

# Third Party
from completions.completion import create_chat_completion
from config import Config
from flask import Flask, request # type: ignore[import-not-found]
from matching.matching import Matcher
from werkzeug import exceptions # type: ignore[import-not-found]
import click # type: ignore[import-not-found]
import yaml

# Globals
app = Flask(__name__)
Expand Down Expand Up @@ -41,45 +40,35 @@ def completions():
return response


# config
@dataclass
class Config:
matches: list[dict]
port: int = 11434
debug: bool = False


@click.command()
@click.option(
"-c",
"--config",
"config",
type=click.File(mode="r", encoding="utf-8"),
required=True,
help="yaml config file",
)
def start_server(config):
# get config
yaml_data = yaml.safe_load(config)
if not isinstance(yaml_data, dict):
raise ValueError(f"config file {config} must be a set of key-value pairs")

conf = Config(**yaml_data)

def start_server(config: Config):
# configure logger
if conf.debug:
if config.debug:
app.logger.setLevel(logging.DEBUG)
app.logger.debug("debug mode enabled")
else:
app.logger.setLevel(logging.INFO)

# create match strategy object
global strategies # pylint: disable=global-statement
strategies = Matcher(conf.matches)
strategies = Matcher(config.matches)

# init server
app.run(debug=conf.debug, port=conf.port)
app.run(debug=config.debug, port=config.port)


@click.command()
@click.option(
"-c",
"--config",
"config_file",
type=click.File(mode="r", encoding="utf-8"),
required=True,
help="yaml config file",
)
def start_server_cli(config_file):
config = Config.from_file(config_file)
start_server(config)


if __name__ == "__main__":
start_server() # pylint: disable=no-value-for-parameter
start_server_cli() # pylint: disable=no-value-for-parameter
1 change: 0 additions & 1 deletion actions/inference-mock/completions/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import random


# TODO: use a library to return and validate completions so this doesn't need to be maintained
def create_chat_completion(content: str, model: str = "gpt-3.5") -> dict:
response = {
"id": "chatcmpl-2nYZXNHxx1PeK1u8xXcE1Fqr1U6Ve",
Expand Down
35 changes: 35 additions & 0 deletions actions/inference-mock/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Standard
from dataclasses import dataclass

# Third Party
import yaml


# config
@dataclass
class Config:
debug: bool
port: int
matches: list[dict]

def __init__(self, matches: list[dict], port: int = 11434, debug: bool = False):
if not matches:
raise ValueError("matches must not be empty")
self.port = port
self.debug = debug
self.matches = matches

@staticmethod
def from_file(config_file) -> "Config":
"""
Create a Server instance from a config file.
:param config_file: path to the config file
:return: Server instance
"""
yaml_data = yaml.safe_load(config_file)
if not isinstance(yaml_data, dict):
raise ValueError(
f"config file {config_file} must be a set of key-value pairs"
)

return Config(**yaml_data)
11 changes: 5 additions & 6 deletions actions/inference-mock/matching/matching.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# Standard
from abc import abstractmethod
from typing import Protocol
from abc import ABC, abstractmethod
import pprint


class Match(Protocol):
class Match(ABC):
"""
Match represents a single prompt matching
strategy. When a match is successful,
Expand All @@ -15,10 +14,10 @@ class Match(Protocol):

@abstractmethod
def match(self, prompt: str) -> str | None:
raise NotImplementedError
pass


class Always:
class Always(Match):
"""
Always is a matching strategy that always
is a positive match on a given prompt.
Expand All @@ -36,7 +35,7 @@ def match(self, prompt: str) -> str | None:
return None


class Contains:
class Contains(Match):
"""
Contains is a matching strategy that checks
if the prompt string contains all of
Expand Down
1 change: 1 addition & 0 deletions actions/inference-mock/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ flask
werkzeug
click
pyyaml
pytest
45 changes: 20 additions & 25 deletions actions/inference-mock/test/test.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,42 @@
# Standard
import unittest

# Third Party
from matching.matching import Always, Contains, Matcher, to_match
import pytest


class TestAlways(unittest.TestCase):
class TestAlways:
# match on any prompt
def test_always(self):
expect_response = "expected response"
prompt = "example prompt"
always = Always(expect_response)
actual_response = always.match(prompt)
self.assertEqual(actual_response, expect_response)
assert actual_response == expect_response

# reject empty prompts
def test_always_empty_prompt(self):
response = "expected response"
prompt = ""
always = Always(response)
actual_response = always.match(prompt)
self.assertIsNone(actual_response)
assert actual_response is None


class TestContains(unittest.TestCase):
class TestContains:
def test_contains(self):
expect_response = "expected response"
prompt = "example prompt"
match_on = ["example"]
contains = Contains(match_on, expect_response)
actual_response = contains.match(prompt)
self.assertEqual(actual_response, expect_response)
assert actual_response == expect_response

def test_contains_many(self):
expect_response = "expected response"
prompt = "a much longer example prompt so we can match on many substring elements of this string"
match_on = ["example", "many substring elements", "match on"]
contains = Contains(match_on, expect_response)
actual_response = contains.match(prompt)
self.assertEqual(actual_response, expect_response)
assert actual_response == expect_response

# if any substrings don't match, return None
def test_contains_mismatch(self):
Expand All @@ -47,7 +45,7 @@ def test_contains_mismatch(self):
match_on = ["example", "many substring elements", "match on", "banana"]
contains = Contains(match_on, response)
actual_response = contains.match(prompt)
self.assertIsNone(actual_response)
assert actual_response is None

# reject empty prompts
def test_contains_empty(self):
Expand All @@ -56,29 +54,30 @@ def test_contains_empty(self):
match_on = ["example"]
contains = Contains(match_on, response)
actual_response = contains.match(prompt)
self.assertIsNone(actual_response)
assert actual_response is None


class TestMatcher(unittest.TestCase):
class TestMatcher:
def test_to_contains(self):
response = "I am a response"
substr = ["a", "b", "c"]
pattern = {"contains": substr, "response": response}
contains = to_match(pattern)
self.assertIsInstance(contains, Contains)
self.assertEqual(contains.response, response)
assert isinstance(contains, Contains)
assert contains.response == response

def test_to_always(self):
response = "I am a response"
always_pattern = {"response": response}
always = to_match(always_pattern)
self.assertIsInstance(always, Always)
self.assertEqual(always.response, response)
assert isinstance(always, Always)
assert always.response == response

def test_to_invalid(self):
response = "I am a response"
invalid_pattern = {"banana": "foo", "response": response}
self.assertRaises(Exception, to_match, invalid_pattern)
with pytest.raises(TypeError):
to_match(invalid_pattern)

def test_find_match_contains(self):
expect_response = "I am a response"
Expand All @@ -88,7 +87,7 @@ def test_find_match_contains(self):

prompt = "example prompt"
actual_response = matcher.find_match(prompt)
self.assertEqual(actual_response, expect_response)
assert actual_response == expect_response

def test_find_match_always(self):
expect_response = "I am a response"
Expand All @@ -97,7 +96,7 @@ def test_find_match_always(self):

prompt = "example prompt"
actual_response = matcher.find_match(prompt)
self.assertEqual(actual_response, expect_response)
assert actual_response == expect_response

# test that order is preserved and responses fall back until a match or end of strategies
def test_find_match_fallback(self):
Expand All @@ -110,10 +109,6 @@ def test_find_match_fallback(self):
]
matcher = Matcher(patterns)
always_response = matcher.find_match(prompt="example prompt")
self.assertEqual(always_response, "this is the fallback response")
assert always_response == "this is the fallback response"
contains_response = matcher.find_match(prompt="this is the fallback response")
self.assertEqual(contains_response, "a response you will not get")


if __name__ == "__main__":
unittest.main()
assert contains_response == "a response you will not get"