Skip to content

Commit 03d93b4

Browse files
Adding basic runtime tests and gh workflow (#93)
* adding conftest.py for test configs * adding modified tests with optional LLM runs * adding llama 1b in model ids * adding tests to config and workflow * renaming workflow * small changes * trying to test workflow * adding the tests to the quality workflow * adding env variable to disable tests * changing marker name llm -> qualitative * changing test markers * addressing PR comments * changing ollama port * changing ollama port * changing ollama order * skipping hf tests till we have a 1b alora * skip rich doc test that takes too much memory * remove unused session functions * changing env var name * minor changes * ignoring more watsonx for now * minor changes * fix non-duplicate member func for mify in python 3.11 --------- Co-authored-by: Jake LoRocco <[email protected]>
1 parent 4819407 commit 03d93b4

File tree

16 files changed

+232
-142
lines changed

16 files changed

+232
-142
lines changed

.github/workflows/quality.yml

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ jobs:
1717
strategy:
1818
matrix:
1919
python-version: ['3.10', '3.11', '3.12'] # Need to add 3.13 once we resolve outlines issues.
20+
env:
21+
CICD: 1
22+
OLLAMA_HOST: "127.0.0.1:5000"
2023
steps:
2124
- uses: actions/checkout@v4
2225
- name: Install uv and set the python version
@@ -31,9 +34,22 @@ jobs:
3134
path: ~/.cache/pre-commit
3235
key: pre-commit|${{ env.PY }}|${{ hashFiles('.pre-commit-config.yaml') }}
3336
- name: Install dependencies
34-
run: uv sync --frozen --all-extras
37+
run: uv sync --frozen --all-extras --group dev
3538
- name: Check style and run tests
3639
run: pre-commit run --all-files
37-
- name: Send failure message
40+
- name: Send failure message pre-commit
3841
if: failure() # This step will only run if a previous step failed
3942
run: echo "The quality verification failed. Please run precommit "
43+
- name: Install Ollama
44+
run: curl -fsSL https://ollama.com/install.sh | sh
45+
- name: Start serving ollama
46+
run: nohup ollama serve &
47+
- name: Pull Llama 3.2:1b model
48+
run: ollama pull llama3.2:1b
49+
50+
- name: Run Tests
51+
run: uv run -m pytest -v test
52+
- name: Send failure message tests
53+
if: failure() # This step will only run if a previous step failed
54+
run: echo "Tests failed. Please verify that tests are working locally."
55+

mellea/backends/model_ids.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ class ModelIdentifier:
8989
ollama_name="llama-guard3:1b", hf_model_name="unsloth/Llama-Guard-3-1B"
9090
)
9191

92+
META_LLAMA_3_2_1B = ModelIdentifier(
93+
ollama_name="llama3.2:1b", hf_model_name="unsloth/Llama-3.2-1B"
94+
)
95+
9296
########################
9397
#### Mistral models ####
9498
########################

mellea/stdlib/mify.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,8 @@ def _get_all_fields(self) -> dict[str, Any]:
132132
if self._fields_exclude:
133133
fields_exclude = self._fields_exclude
134134

135-
# This includes fields defined by any superclasses, as long as it's not object.
136-
all_fields = _get_non_duplicate_fields(self, object)
135+
# This includes fields defined by any superclasses, as long as it's not Protocol.
136+
all_fields = _get_non_duplicate_fields(self, Protocol)
137137

138138
# It does matter if include is an empty set. Handle it's cases here.
139139
if self._fields_include is not None:
@@ -366,18 +366,15 @@ def mification(obj: T) -> T:
366366

367367

368368
def _get_non_duplicate_members(
369-
object: object, check_duplicates: object
369+
obj: object, check_duplicates: object
370370
) -> dict[str, Callable]:
371371
"""Returns all methods/functions unique to the object."""
372372
members = dict(
373373
inspect.getmembers(
374-
object,
374+
obj,
375375
# Checks for ismethod or isfunction because of the methods added from the MifiedProtocol.
376-
predicate=lambda x: inspect.ismethod(x)
377-
or (
378-
inspect.isfunction(x)
379-
and x.__name__ not in dict(inspect.getmembers(check_duplicates)).keys()
380-
),
376+
predicate=lambda x: (inspect.ismethod(x) or inspect.isfunction(x))
377+
and x.__name__ not in dict(inspect.getmembers(check_duplicates)).keys(),
381378
)
382379
)
383380
return members

mellea/stdlib/session.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -414,31 +414,6 @@ def validate(
414414

415415
return rvs
416416

417-
def req(self, *args, **kwargs):
418-
"""Shorthand for Requirement.__init__(...)."""
419-
return req(*args, **kwargs)
420-
421-
def check(self, *args, **kwargs):
422-
"""Shorthand for Requirement.__init__(..., check_only=True)."""
423-
return check(*args, **kwargs)
424-
425-
def load_default_aloras(self):
426-
"""Loads the default Aloras for this model, if they exist and if the backend supports."""
427-
from mellea.backends.huggingface import LocalHFBackend
428-
429-
if self.backend.model_id == IBM_GRANITE_3_2_8B and isinstance(
430-
self.backend, LocalHFBackend
431-
):
432-
from mellea.backends.aloras.huggingface.granite_aloras import (
433-
add_granite_aloras,
434-
)
435-
436-
add_granite_aloras(self.backend)
437-
return
438-
self._session_logger.warning(
439-
"This model/backend combination does not support any aloras."
440-
)
441-
442417
def genslot(
443418
self,
444419
gen_slot: Component,

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,8 @@ skip = 'requirements.txt,uv.lock'
161161
[tool.mypy]
162162
disable_error_code = ["empty-body", "import-untyped"]
163163
python_version = "3.10"
164+
165+
[tool.pytest.ini_options]
166+
markers = [
167+
"qualitative: Marks the test as needing an exact output from an LLM; set by an ENV variable for CICD. All tests marked with this will xfail in CI/CD"
168+
]

test/backends/test_huggingface.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,15 @@ def session(backend):
3737
yield session
3838
session.reset()
3939

40-
40+
@pytest.mark.qualitative
4141
def test_system_prompt(session):
4242
result = session.chat(
4343
"Where are we going?",
4444
model_options={ModelOption.SYSTEM_PROMPT: "Talk like a pirate."},
4545
)
4646
print(result)
4747

48-
48+
@pytest.mark.qualitative
4949
def test_constraint_alora(session, backend):
5050
answer = session.instruct(
5151
"Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa. Be concise and don't write code to answer the question.",
@@ -63,7 +63,7 @@ def test_constraint_alora(session, backend):
6363
)
6464
assert alora_output in ["Y", "N"], alora_output
6565

66-
66+
@pytest.mark.qualitative
6767
def test_constraint_lora_with_requirement(session, backend):
6868
answer = session.instruct(
6969
"Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
@@ -80,6 +80,7 @@ def test_constraint_lora_with_requirement(session, backend):
8080
assert str(val_result.reason) in ["Y", "N"]
8181

8282

83+
@pytest.mark.qualitative
8384
def test_constraint_lora_override(session, backend):
8485
backend.default_to_constraint_checking_alora = False # type: ignore
8586
answer = session.instruct(
@@ -95,6 +96,7 @@ def test_constraint_lora_override(session, backend):
9596
backend.default_to_constraint_checking_alora = True
9697

9798

99+
@pytest.mark.qualitative
98100
def test_constraint_lora_override_does_not_override_alora(session, backend):
99101
backend.default_to_constraint_checking_alora = False # type: ignore
100102
answer = session.instruct(
@@ -112,6 +114,7 @@ def test_constraint_lora_override_does_not_override_alora(session, backend):
112114
backend.default_to_constraint_checking_alora = True
113115

114116

117+
@pytest.mark.qualitative
115118
def test_llmaj_req_does_not_use_alora(session, backend):
116119
backend.default_to_constraint_checking_alora = True # type: ignore
117120
answer = session.instruct(
@@ -127,12 +130,13 @@ def test_llmaj_req_does_not_use_alora(session, backend):
127130
assert isinstance(val_result, ValidationResult)
128131
assert str(val_result.reason) not in ["Y", "N"]
129132

130-
133+
@pytest.mark.qualitative
131134
def test_instruct(session):
132135
result = session.instruct("Compute 1+1.")
133136
print(result)
134137

135138

139+
@pytest.mark.qualitative
136140
def test_multiturn(session):
137141
session.instruct("Compute 1+1")
138142
beta = session.instruct(
@@ -143,6 +147,7 @@ def test_multiturn(session):
143147
print(words)
144148

145149

150+
@pytest.mark.qualitative
146151
def test_format(session):
147152
class Person(pydantic.BaseModel):
148153
name: str
@@ -172,7 +177,7 @@ class Email(pydantic.BaseModel):
172177
"The email address should be at example.com"
173178
)
174179

175-
180+
@pytest.mark.qualitative
176181
def test_generate_from_raw(session):
177182
prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]
178183

@@ -183,6 +188,7 @@ def test_generate_from_raw(session):
183188
assert len(results) == len(prompts)
184189

185190

191+
@pytest.mark.qualitative
186192
def test_generate_from_raw_with_format(session):
187193
prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]
188194

test/backends/test_ollama.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
from mellea import start_session, SimpleContext
2-
from mellea.stdlib.base import CBlock
3-
from mellea.stdlib.requirement import Requirement
4-
import pydantic
51
import json
2+
3+
import pydantic
4+
import pytest
65
from typing_extensions import Annotated
6+
7+
from mellea import SimpleContext, start_session
78
from mellea.backends.types import ModelOption
8-
import pytest
9+
from mellea.stdlib.base import CBlock
10+
from mellea.stdlib.requirement import Requirement
911

1012

1113
@pytest.fixture(scope="function")
@@ -15,6 +17,8 @@ def session():
1517
yield session
1618
session.reset()
1719

20+
21+
@pytest.mark.qualitative
1822
def test_simple_instruct(session):
1923
result = session.instruct(
2024
"Write an email to Hendrik trying to sell him self-sealing stembolts."
@@ -23,6 +27,8 @@ def test_simple_instruct(session):
2327
assert "chat_response" in result._meta
2428
assert result._meta["chat_response"].message.role == "assistant"
2529

30+
31+
@pytest.mark.qualitative
2632
def test_instruct_with_requirement(session):
2733
response = session.instruct(
2834
"Write an email to Hendrik convincing him to buy some self-sealing stembolts."
@@ -45,12 +51,14 @@ def test_instruct_with_requirement(session):
4551
)
4652
print(results)
4753

54+
@pytest.mark.qualitative
4855
def test_chat(session):
4956
output_message = session.chat("What is 1+1?")
50-
assert (
51-
"2" in output_message.content
52-
), f"Expected a message with content containing 2 but found {output_message}"
57+
assert "2" in output_message.content, (
58+
f"Expected a message with content containing 2 but found {output_message}"
59+
)
5360

61+
@pytest.mark.qualitative
5462
def test_format(session):
5563
class Person(pydantic.BaseModel):
5664
name: str
@@ -83,6 +91,7 @@ class Email(pydantic.BaseModel):
8391
# assert email.to.email_address.endswith("example.com")
8492
pass
8593

94+
@pytest.mark.qualitative
8695
def test_generate_from_raw(session):
8796
prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]
8897

@@ -113,9 +122,9 @@ class Answer(pydantic.BaseModel):
113122
try:
114123
answer = Answer.model_validate_json(random_result.value)
115124
except pydantic.ValidationError as e:
116-
assert (
117-
False
118-
), f"formatting directive failed for {random_result.value}: {e.json()}"
125+
assert False, (
126+
f"formatting directive failed for {random_result.value}: {e.json()}"
127+
)
119128

120129

121130
if __name__ == "__main__":

test/backends/test_openai_ollama.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,54 @@
11
# test/rits_backend_tests/test_openai_integration.py
2-
from mellea import MelleaSession
3-
from mellea.stdlib.base import CBlock, LinearContext, ModelOutputThunk
4-
from mellea.backends.openai import OpenAIBackend
5-
from mellea.backends.formatter import TemplateFormatter
6-
from mellea.backends.types import ModelOption
2+
import os
73

84
import pydantic
9-
from typing_extensions import Annotated
105
import pytest
6+
from typing_extensions import Annotated
7+
8+
from mellea import MelleaSession
9+
from mellea.backends.formatter import TemplateFormatter
10+
from mellea.backends.model_ids import META_LLAMA_3_2_1B
11+
from mellea.backends.openai import OpenAIBackend
12+
from mellea.backends.types import ModelOption
13+
from mellea.stdlib.base import CBlock, LinearContext, ModelOutputThunk
1114

1215

1316
@pytest.fixture(scope="module")
14-
def backend():
17+
def backend(gh_run: int):
1518
"""Shared OpenAI backend configured for Ollama."""
16-
return OpenAIBackend(
17-
model_id="granite3.3:8b",
18-
formatter=TemplateFormatter(model_id="ibm-granite/granite-3.2-8b-instruct"),
19-
base_url="http://localhost:11434/v1",
19+
if gh_run == 1:
20+
return OpenAIBackend(
21+
model_id=META_LLAMA_3_2_1B,
22+
formatter=TemplateFormatter(model_id=META_LLAMA_3_2_1B),
23+
base_url=f"http://{os.environ.get('OLLAMA_HOST', 'localhost:11434')}/v1",
2024
api_key="ollama",
2125
)
26+
else:
27+
return OpenAIBackend(
28+
model_id="granite3.3:8b",
29+
formatter=TemplateFormatter(model_id="ibm-granite/granite-3.2-8b-instruct"),
30+
base_url=f"http://{os.environ.get('OLLAMA_HOST', 'localhost:11434')}/v1",
31+
api_key="ollama",
32+
)
2233

2334

2435
@pytest.fixture(scope="function")
25-
def session(backend):
36+
def m_session(backend):
2637
"""Fresh OpenAI session for each test."""
2738
session = MelleaSession(backend, ctx=LinearContext(is_chat_context=True))
2839
yield session
2940
session.reset()
3041

31-
def test_instruct(session):
32-
result = session.instruct("Compute 1+1.")
42+
@pytest.mark.qualitative
43+
def test_instruct(m_session):
44+
result = m_session.instruct("Compute 1+1.")
3345
assert isinstance(result, ModelOutputThunk)
3446
assert "2" in result.value # type: ignore
3547

36-
def test_multiturn(session):
37-
session.instruct("What is the capital of France?")
38-
answer = session.instruct("Tell me the answer to the previous question.")
48+
@pytest.mark.qualitative
49+
def test_multiturn(m_session):
50+
m_session.instruct("What is the capital of France?")
51+
answer = m_session.instruct("Tell me the answer to the previous question.")
3952
assert "Paris" in answer.value # type: ignore
4053

4154
# def test_api_timeout_error(self):
@@ -53,7 +66,8 @@ def test_multiturn(session):
5366
# assert "granite3.3:8b" in result.value
5467
# self.m.reset()
5568

56-
def test_format(session):
69+
@pytest.mark.qualitative
70+
def test_format(m_session):
5771
class Person(pydantic.BaseModel):
5872
name: str
5973
# it does not support regex patterns in json schema
@@ -68,7 +82,7 @@ class Email(pydantic.BaseModel):
6882
subject: str
6983
body: str
7084

71-
output = session.instruct(
85+
output = m_session.instruct(
7286
"Write a short email to Olivia, thanking her for organizing a sailing activity. Her email server is example.com. No more than two sentences. ",
7387
format=Email,
7488
model_options={ModelOption.MAX_NEW_TOKENS: 2**8},

0 commit comments

Comments
 (0)