diff --git a/.github/workflows/inference-mock.yml b/.github/workflows/inference-mock.yml
new file mode 100644
index 0000000..8529543
--- /dev/null
+++ b/.github/workflows/inference-mock.yml
@@ -0,0 +1,35 @@
+name: Inference Mock Tests
+
+on:
+ pull_request:
+ branches:
+ - "main"
+ - "release-**"
+ paths:
+ - 'actions/inference-mock/**'
+ - '.github/workflows/inference-mock.yml' # This workflow
+
+jobs:
+ inference-mock-unit-tests:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ["3.11"]
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ sparse-checkout: |
+ actions/inference-mock
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ working-directory: actions/inference-mock
+ run: |
+ python -m pip install --upgrade pip
+ pip install -r requirements.txt
+ - name: Run Unit Tests
+ working-directory: actions/inference-mock
+ run: |
+ python -m unittest test/test.py
diff --git a/.pylintrc b/.pylintrc
index 22c302b..23f78a1 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -627,7 +627,7 @@ missing-member-max-choices=1
mixin-class-rgx=.*[Mm]ixin
# List of decorators that change the signature of a decorated function.
-signature-mutators=unittest.mock.patch,unittest.mock.patch.object
+signature-mutators=unittest.mock.patch,unittest.mock.patch.object,click.decorators.option
[VARIABLES]
diff --git a/README.md b/README.md
index 07aa875..5beade3 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,7 @@ Below is a list of the in-house GitHub actions stored in this repository:
| [launch-ec2-runner-with-fallback](./actions/launch-ec2-runner-with-fallback/launch-ec2-runner-with-fallback.md) | Used launch an EC2 instance in AWS, either as a spot instance or a dedicated instance. If your preferred availability zone lacks availability for your instance type, "backup" availability zones will be tried. |
- Insufficient capacity in AWS (i.e., AWS lacks availablility for your desired EC2 instance type in your preferred availability zone)
- Cost savings (i.e., You want to try launching your EC2 runner as a spot instance first)
|
| [validate-notebooks](./actions/launch-ec2-runner-with-fallback/launch-ec2-runner-with-fallback.md) | Used to validate `.ipynb` files | - I maintain a collection of `.ipynb` files and run ci jobs to test them. I would like to quickly verify that the files are formatted correctly before spinning up more complex or expensive CI jobs to test those notebooks.
| [update-constraints](./actions/update-constraints/update-constraints.md) | Used to update `constraints-dev.txt` file | - I pin all project dependencies in CI using `constraints-dev.txt` file. I would like to monitor new dependency releases and periodically post PRs to move pins forward.
+| [inference-mock](./actions/inference-mock/README.md) | Used to mock LLM calls | - I have a notebook under test that makes an LLM call, but the test does not depend on real model output, so I don't need to run a real inference server.
|
## ❓ How to Use One or More In-House GitHub Actions
diff --git a/actions/inference-mock/README.md b/actions/inference-mock/README.md
new file mode 100644
index 0000000..547aa04
--- /dev/null
+++ b/actions/inference-mock/README.md
@@ -0,0 +1,62 @@
+# Inference Mock
+
+## Overview
+
+Inference Mock is a tool that starts a Flask server as a background process. OpenAI-compatible calls can be made to its completions API.
+Based on how the server is configured, it sends back a set of pre-programmed responses.
+
+## When to Use it?
+
+Testing notebooks is difficult because they often contain no functions or unit tests. If you want to mock an LLM call and its response,
+this is an easy way to set that up in your testing environment. It works well for unit, integration, and smoke tests. It is not a real
+inference service, so it's best suited to testing code that makes occasional calls to an LLM to perform a task.
+
+## Usage
+
+This is a composite action that can be referenced from any GitHub Actions workflow. First, you will need to create a config file. You can set the following fields:
+
+```yaml
+# debug: enable debug logging and debug mode in Flask
+# optional: this defaults to False
+debug: True
+
+# port: the port the server will listen on
+# optional: this defaults to 11434
+port: 11434
+
+# matches: a list of matching strategies for expected prompt/response pairs. The following strategies are available:
+# - contains: accepts a list of substrings. An incoming prompt must contain all listed substrings for the match to be positive
+# - response: passing only a response is an `Always` match strategy. If no other strategy has matched yet, this will always be a positive match.
+#
+# note: the strategies are evaluated in the order listed, and the first successful match is accepted. If you start with an `Always` strategy, its
+# response will be the only response returned.
+matches:
+
+ # this is an example of a `contains` strategy. If the prompt contains the substrings, it returns the response.
+ - contains:
+ - 'I need you to generate three questions that must be answered only with information contained in this passage, and nothing else.'
+ response: '{"fact_single": "What are some common ways to assign rewards to partial answers?", "fact_single_answer": "There are three: prod, which takes the product of rewards across all steps; min, which selects the minimum reward over all steps; and last, which uses the reward from the final step.", "reasoning": "What is the best method for rewarding models?", "reasoning_answer": "That depends on whether the training data is prepared using MC rollout, human annotation, or model annotation.", "summary": "How does QWEN implement model reward?", "summary_answer": "Qwen computes the aggregate reward based on the entire partial reward trajectory. I also uses a method that feeds the performance reference model with partial answers, then only considering the final reward token."}'
+
+ # this is an example of an `Always` strategy. It will always match, and return this response.
+ - response: "hi I am the default response"
+```
+
+This config must be passed to this action as an input. Here is an example of a workflow that invokes this action to create a mock server.
+
+```yaml
+jobs:
+ example-job:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout "inference-mock" in-house CI action
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ with:
+ repository: instructlab/ci-actions
+ path: ci-actions
+ sparse-checkout: |
+ actions/inference-mock
+ - name: Inference Mock
+ uses: ./ci-actions/actions/inference-mock
+ with:
+ config: "example-config.yml"
+```
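+
+Once the server is running, later steps in the same job can exercise it over HTTP. Below is a minimal sketch of such a call against the mock's `/v1/completions` route, assuming the server is listening on the default port (11434); the model name in the payload is just an illustrative placeholder.
+
+```python
+# Minimal sketch: query the mock completions endpoint from a test step.
+# Assumes inference-mock is already running on the default port 11434.
+import json
+import urllib.request
+
+payload = json.dumps(
+    {"model": "placeholder-model", "prompt": "example prompt"}
+).encode("utf-8")
+req = urllib.request.Request(
+    "http://localhost:11434/v1/completions",
+    data=payload,
+    headers={"Content-Type": "application/json"},
+)
+with urllib.request.urlopen(req) as resp:
+    body = json.load(resp)
+
+# with the example config above, this prints the default `Always` response
+print(body["choices"][0]["text"])
+```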
diff --git a/actions/inference-mock/action.yml b/actions/inference-mock/action.yml
new file mode 100644
index 0000000..4c7b58b
--- /dev/null
+++ b/actions/inference-mock/action.yml
@@ -0,0 +1,21 @@
+name: 'Inference Mock'
+description: 'Creates and runs a server that returns mock OpenAI completions as a background process'
+author: "InstructLab"
+
+inputs:
+ config:
+ description: the path to a config.yml file for the inference mock server
+ required: true
+ type: string
+
+runs:
+ using: "composite"
+ steps:
+ - name: Install Dependencies
+ shell: bash
+ run: pip install -r ${{ github.action_path }}/requirements.txt
+ - name: Run Inference Mock Server
+ shell: bash
+ run: |
+ nohup python ${{ github.action_path }}/app.py --config ${{ inputs.config }} &
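+        # brief pause so the background server is listening before later steps call it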
+ sleep 2
diff --git a/actions/inference-mock/app.py b/actions/inference-mock/app.py
new file mode 100644
index 0000000..c957e4a
--- /dev/null
+++ b/actions/inference-mock/app.py
@@ -0,0 +1,85 @@
+# Standard
+from dataclasses import dataclass
+import logging
+import pprint
+
+# Third Party
+from completions.completion import create_chat_completion
+from flask import Flask, request # type: ignore[import-not-found]
+from matching.matching import Matcher
+from werkzeug import exceptions # type: ignore[import-not-found]
+import click # type: ignore[import-not-found]
+import yaml
+
+# Globals
+app = Flask(__name__)
+strategies: Matcher  # read-only Matcher holding the configured matching strategies
+
+
+# Routes
+@app.route("/v1/completions", methods=["POST"])
+def completions():
+ data = request.get_json()
+ if not data or "prompt" not in data:
+ raise exceptions.BadRequest("prompt is empty or None")
+
+ prompt = data.get("prompt")
+    prompt_debug_str = prompt
+    if len(prompt) > 90:
+        prompt_debug_str = prompt[:90] + "..."
+
+    app.logger.debug(
+        f"{request.method} {request.url} {data.get('model')} {prompt_debug_str}"
+    )
+
+ chat_response = strategies.find_match(
+ prompt
+ ) # handle prompt and generate correct response
+
+ response = create_chat_completion(chat_response, model=data.get("model"))
+ app.logger.debug(f"response: {pprint.pformat(response, compact=True)}")
+ return response
+
+
+# config
+@dataclass
+class Config:
+ matches: list[dict]
+ port: int = 11434
+ debug: bool = False
+
+
+@click.command()
+@click.option(
+ "-c",
+ "--config",
+ "config",
+ type=click.File(mode="r", encoding="utf-8"),
+ required=True,
+ help="yaml config file",
+)
+def start_server(config):
+ # get config
+ yaml_data = yaml.safe_load(config)
+ if not isinstance(yaml_data, dict):
+        raise ValueError(f"config file {config.name} must be a set of key-value pairs")
+
+ conf = Config(**yaml_data)
+
+ # configure logger
+ if conf.debug:
+ app.logger.setLevel(logging.DEBUG)
+ app.logger.debug("debug mode enabled")
+ else:
+ app.logger.setLevel(logging.INFO)
+
+ # create match strategy object
+ global strategies # pylint: disable=global-statement
+ strategies = Matcher(conf.matches)
+
+ # init server
+ app.run(debug=conf.debug, port=conf.port)
+
+
+if __name__ == "__main__":
+ start_server() # pylint: disable=no-value-for-parameter
diff --git a/actions/inference-mock/completions/__init__.py b/actions/inference-mock/completions/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/actions/inference-mock/completions/completion.py b/actions/inference-mock/completions/completion.py
new file mode 100644
index 0000000..b238e03
--- /dev/null
+++ b/actions/inference-mock/completions/completion.py
@@ -0,0 +1,33 @@
+# mock OpenAI completion responses
+# credit: https://github.com/openai/openai-python/issues/715#issuecomment-1809203346
+# License: MIT
+
+# Standard
+import random
+
+
+# TODO: use a library to return and validate completions so this doesn't need to be maintained
+def create_chat_completion(content: str, model: str = "gpt-3.5") -> dict:
+ response = {
+ "id": "chatcmpl-2nYZXNHxx1PeK1u8xXcE1Fqr1U6Ve",
+ "object": "chat.completion",
+ "created": "12345678",
+ "model": model,
+ "system_fingerprint": "fp_44709d6fcb",
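+        # each choice carries both a completions-style "text" field and a
+        # chat-style "content" field so either client shape can read the reply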
+ "choices": [
+ {
+ "text": content,
+ "content": content,
+ "index": 0,
+ "logprobs": None,
+ "finish_reason": "length",
+ },
+ ],
+ "usage": {
+ "prompt_tokens": random.randint(10, 500),
+ "completion_tokens": random.randint(10, 500),
+ "total_tokens": random.randint(10, 500),
+ },
+ }
+
+ return response
diff --git a/actions/inference-mock/matching/__init__.py b/actions/inference-mock/matching/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/actions/inference-mock/matching/matching.py b/actions/inference-mock/matching/matching.py
new file mode 100644
index 0000000..38efc3d
--- /dev/null
+++ b/actions/inference-mock/matching/matching.py
@@ -0,0 +1,99 @@
+# Standard
+from abc import abstractmethod
+from typing import Protocol
+import pprint
+
+
+class Match(Protocol):
+ """
+ Match represents a single prompt matching
+ strategy. When a match is successful,
+ the response is what should be returned.
+ """
+
+ response: str
+
+ @abstractmethod
+ def match(self, prompt: str) -> str | None:
+ raise NotImplementedError
+
+
+class Always:
+ """
+    Always is a matching strategy that matches
+    any non-empty prompt.
+
+    It is best used when only one prompt/response
+    pair is expected.
+ """
+
+ def __init__(self, response: str):
+ self.response = response
+
+ def match(self, prompt: str) -> str | None:
+ if prompt:
+ return self.response
+ return None
+
+
+class Contains:
+ """
+ Contains is a matching strategy that checks
+ if the prompt string contains all of
+ the substrings in the `contains` attribute.
+ """
+
+ contains: list[str]
+
+ def __init__(self, contains: list[str], response: str):
+        if not contains:
+ raise ValueError("contains must not be empty")
+ self.response = response
+ self.contains = contains
+
+ def match(self, prompt: str) -> str | None:
+ if not prompt:
+ return None
+ for context in self.contains:
+ if context not in prompt:
+ return None
+
+ return self.response
+
+
+# helper function pulled out for easier testing
+def to_match(pattern: dict) -> Match:
+ response = pattern.get("response")
+ if not response:
+ raise ValueError(
+ f"matching strategy must have a response: {pprint.pformat(pattern, compact=True)}"
+ )
+ if "contains" in pattern:
+ return Contains(**pattern)
+ return Always(**pattern)
+
+
+class Matcher:
+ """
+    Matcher checks a prompt against the configured
+    strategies and returns the first matching reply.
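+
+    Example (illustrative values):
+        matcher = Matcher([{"contains": ["foo"], "response": "bar"}])
+        matcher.find_match("a foo prompt")  # returns "bar"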
+ """
+
+ strategies: list[Match]
+
+ def __init__(self, matching_patterns: list[dict]):
+ if not matching_patterns:
+ raise ValueError(
+ "matching strategies must contain at least one Match strategy"
+ )
+
+ self.strategies: list[Match] = []
+ for matching_pattern in matching_patterns:
+ self.strategies.append(to_match(matching_pattern))
+
+ def find_match(self, prompt: str) -> str:
+ for strategy in self.strategies:
+ response = strategy.match(prompt)
+ if response:
+ return response
+ return ""
diff --git a/actions/inference-mock/requirements.txt b/actions/inference-mock/requirements.txt
new file mode 100644
index 0000000..20d3485
--- /dev/null
+++ b/actions/inference-mock/requirements.txt
@@ -0,0 +1,4 @@
+flask
+werkzeug
+click
+pyyaml
diff --git a/actions/inference-mock/test/__init__.py b/actions/inference-mock/test/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/actions/inference-mock/test/test.py b/actions/inference-mock/test/test.py
new file mode 100644
index 0000000..2157a67
--- /dev/null
+++ b/actions/inference-mock/test/test.py
@@ -0,0 +1,119 @@
+# Standard
+import unittest
+
+# Third Party
+from matching.matching import Always, Contains, Matcher, to_match
+
+
+class TestAlways(unittest.TestCase):
+ # match on any prompt
+ def test_always(self):
+ expect_response = "expected response"
+ prompt = "example prompt"
+ always = Always(expect_response)
+ actual_response = always.match(prompt)
+ self.assertEqual(actual_response, expect_response)
+
+ # reject empty prompts
+ def test_always_empty_prompt(self):
+ response = "expected response"
+ prompt = ""
+ always = Always(response)
+ actual_response = always.match(prompt)
+ self.assertIsNone(actual_response)
+
+
+class TestContains(unittest.TestCase):
+ def test_contains(self):
+ expect_response = "expected response"
+ prompt = "example prompt"
+ match_on = ["example"]
+ contains = Contains(match_on, expect_response)
+ actual_response = contains.match(prompt)
+ self.assertEqual(actual_response, expect_response)
+
+ def test_contains_many(self):
+ expect_response = "expected response"
+ prompt = "a much longer example prompt so we can match on many substring elements of this string"
+ match_on = ["example", "many substring elements", "match on"]
+ contains = Contains(match_on, expect_response)
+ actual_response = contains.match(prompt)
+ self.assertEqual(actual_response, expect_response)
+
+ # if any substrings don't match, return None
+ def test_contains_mismatch(self):
+ response = "expected response"
+ prompt = "a much longer example prompt so we can match on many substring elements of this string"
+ match_on = ["example", "many substring elements", "match on", "banana"]
+ contains = Contains(match_on, response)
+ actual_response = contains.match(prompt)
+ self.assertIsNone(actual_response)
+
+ # reject empty prompts
+ def test_contains_empty(self):
+ response = "expected response"
+ prompt = ""
+ match_on = ["example"]
+ contains = Contains(match_on, response)
+ actual_response = contains.match(prompt)
+ self.assertIsNone(actual_response)
+
+
+class TestMatcher(unittest.TestCase):
+ def test_to_contains(self):
+ response = "I am a response"
+ substr = ["a", "b", "c"]
+ pattern = {"contains": substr, "response": response}
+ contains = to_match(pattern)
+ self.assertIsInstance(contains, Contains)
+ self.assertEqual(contains.response, response)
+
+ def test_to_always(self):
+ response = "I am a response"
+ always_pattern = {"response": response}
+ always = to_match(always_pattern)
+ self.assertIsInstance(always, Always)
+ self.assertEqual(always.response, response)
+
+ def test_to_invalid(self):
+ response = "I am a response"
+ invalid_pattern = {"banana": "foo", "response": response}
+ self.assertRaises(Exception, to_match, invalid_pattern)
+
+ def test_find_match_contains(self):
+ expect_response = "I am a response"
+ substr = ["example", "p"]
+ patterns = [{"contains": substr, "response": expect_response}]
+ matcher = Matcher(patterns)
+
+ prompt = "example prompt"
+ actual_response = matcher.find_match(prompt)
+ self.assertEqual(actual_response, expect_response)
+
+ def test_find_match_always(self):
+ expect_response = "I am a response"
+ patterns = [{"response": expect_response}]
+ matcher = Matcher(patterns)
+
+ prompt = "example prompt"
+ actual_response = matcher.find_match(prompt)
+ self.assertEqual(actual_response, expect_response)
+
+ # test that order is preserved and responses fall back until a match or end of strategies
+ def test_find_match_fallback(self):
+ patterns = [
+ {
+ "contains": ["this is the fallback response"],
+ "response": "a response you will not get",
+ },
+ {"response": "this is the fallback response"},
+ ]
+ matcher = Matcher(patterns)
+ always_response = matcher.find_match(prompt="example prompt")
+ self.assertEqual(always_response, "this is the fallback response")
+ contains_response = matcher.find_match(prompt="this is the fallback response")
+ self.assertEqual(contains_response, "a response you will not get")
+
+
+if __name__ == "__main__":
+ unittest.main()