Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions .github/workflows/quality.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ jobs:
strategy:
matrix:
python-version: ['3.10', '3.11', '3.12'] # Need to add 3.13 once we resolve outlines issues.
env:
CICD: 1
OLLAMA_HOST: "127.0.0.1:5000"
steps:
- uses: actions/checkout@v4
- name: Install uv and set the python version
Expand All @@ -31,9 +34,22 @@ jobs:
path: ~/.cache/pre-commit
key: pre-commit|${{ env.PY }}|${{ hashFiles('.pre-commit-config.yaml') }}
- name: Install dependencies
run: uv sync --frozen --all-extras
run: uv sync --frozen --all-extras --group dev
- name: Check style and run tests
run: pre-commit run --all-files
- name: Send failure message
- name: Send failure message pre-commit
if: failure() # This step will only run if a previous step failed
run: echo "The quality verification failed. Please run precommit "
- name: Install Ollama
run: curl -fsSL https://ollama.com/install.sh | sh
- name: Start serving ollama
run: nohup ollama serve &
- name: Pull Llama 3.2:1b model
run: ollama pull llama3.2:1b

- name: Run Tests
run: uv run -m pytest -v test
- name: Send failure message tests
if: failure() # This step will only run if a previous step failed
run: echo "Tests failed. Please verify that tests are working locally."

4 changes: 4 additions & 0 deletions mellea/backends/model_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ class ModelIdentifier:
ollama_name="llama-guard3:1b", hf_model_name="unsloth/Llama-Guard-3-1B"
)

META_LLAMA_3_2_1B = ModelIdentifier(
ollama_name="llama3.2:1b", hf_model_name="unsloth/Llama-3.2-1B"
)

########################
#### Mistral models ####
########################
Expand Down
15 changes: 6 additions & 9 deletions mellea/stdlib/mify.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def _get_all_fields(self) -> dict[str, Any]:
if self._fields_exclude:
fields_exclude = self._fields_exclude

# This includes fields defined by any superclasses, as long as it's not object.
all_fields = _get_non_duplicate_fields(self, object)
# This includes fields defined by any superclasses, as long as it's not Protocol.
all_fields = _get_non_duplicate_fields(self, Protocol)

# It does matter if include is an empty set. Handle its cases here.
if self._fields_include is not None:
Expand Down Expand Up @@ -366,18 +366,15 @@ def mification(obj: T) -> T:


def _get_non_duplicate_members(
object: object, check_duplicates: object
obj: object, check_duplicates: object
) -> dict[str, Callable]:
"""Returns all methods/functions unique to the object."""
members = dict(
inspect.getmembers(
object,
obj,
# Checks for ismethod or isfunction because of the methods added from the MifiedProtocol.
predicate=lambda x: inspect.ismethod(x)
or (
inspect.isfunction(x)
and x.__name__ not in dict(inspect.getmembers(check_duplicates)).keys()
),
predicate=lambda x: (inspect.ismethod(x) or inspect.isfunction(x))
and x.__name__ not in dict(inspect.getmembers(check_duplicates)).keys(),
)
)
return members
Expand Down
25 changes: 0 additions & 25 deletions mellea/stdlib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,31 +414,6 @@ def validate(

return rvs

def req(self, *args, **kwargs):
"""Shorthand for Requirement.__init__(...)."""
return req(*args, **kwargs)

def check(self, *args, **kwargs):
"""Shorthand for Requirement.__init__(..., check_only=True)."""
return check(*args, **kwargs)

def load_default_aloras(self):
"""Loads the default Aloras for this model, if they exist and if the backend supports."""
from mellea.backends.huggingface import LocalHFBackend

if self.backend.model_id == IBM_GRANITE_3_2_8B and isinstance(
self.backend, LocalHFBackend
):
from mellea.backends.aloras.huggingface.granite_aloras import (
add_granite_aloras,
)

add_granite_aloras(self.backend)
return
self._session_logger.warning(
"This model/backend combination does not support any aloras."
)

def genslot(
self,
gen_slot: Component,
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,8 @@ skip = 'requirements.txt,uv.lock'
[tool.mypy]
disable_error_code = ["empty-body", "import-untyped"]
python_version = "3.10"

[tool.pytest.ini_options]
markers = [
"qualitative: Marks the test as needing an exact output from an LLM; set by an ENV variable for CICD. All tests marked with this will xfail in CI/CD"
]
16 changes: 11 additions & 5 deletions test/backends/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ def session(backend):
yield session
session.reset()


@pytest.mark.qualitative
def test_system_prompt(session):
result = session.chat(
"Where are we going?",
model_options={ModelOption.SYSTEM_PROMPT: "Talk like a pirate."},
)
print(result)


@pytest.mark.qualitative
def test_constraint_alora(session, backend):
answer = session.instruct(
"Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa. Be concise and don't write code to answer the question.",
Expand All @@ -63,7 +63,7 @@ def test_constraint_alora(session, backend):
)
assert alora_output in ["Y", "N"], alora_output


@pytest.mark.qualitative
def test_constraint_lora_with_requirement(session, backend):
answer = session.instruct(
"Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
Expand All @@ -80,6 +80,7 @@ def test_constraint_lora_with_requirement(session, backend):
assert str(val_result.reason) in ["Y", "N"]


@pytest.mark.qualitative
def test_constraint_lora_override(session, backend):
backend.default_to_constraint_checking_alora = False # type: ignore
answer = session.instruct(
Expand All @@ -95,6 +96,7 @@ def test_constraint_lora_override(session, backend):
backend.default_to_constraint_checking_alora = True


@pytest.mark.qualitative
def test_constraint_lora_override_does_not_override_alora(session, backend):
backend.default_to_constraint_checking_alora = False # type: ignore
answer = session.instruct(
Expand All @@ -112,6 +114,7 @@ def test_constraint_lora_override_does_not_override_alora(session, backend):
backend.default_to_constraint_checking_alora = True


@pytest.mark.qualitative
def test_llmaj_req_does_not_use_alora(session, backend):
backend.default_to_constraint_checking_alora = True # type: ignore
answer = session.instruct(
Expand All @@ -127,12 +130,13 @@ def test_llmaj_req_does_not_use_alora(session, backend):
assert isinstance(val_result, ValidationResult)
assert str(val_result.reason) not in ["Y", "N"]


@pytest.mark.qualitative
def test_instruct(session):
result = session.instruct("Compute 1+1.")
print(result)


@pytest.mark.qualitative
def test_multiturn(session):
session.instruct("Compute 1+1")
beta = session.instruct(
Expand All @@ -143,6 +147,7 @@ def test_multiturn(session):
print(words)


@pytest.mark.qualitative
def test_format(session):
class Person(pydantic.BaseModel):
name: str
Expand Down Expand Up @@ -172,7 +177,7 @@ class Email(pydantic.BaseModel):
"The email address should be at example.com"
)


@pytest.mark.qualitative
def test_generate_from_raw(session):
prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]

Expand All @@ -183,6 +188,7 @@ def test_generate_from_raw(session):
assert len(results) == len(prompts)


@pytest.mark.qualitative
def test_generate_from_raw_with_format(session):
prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]

Expand Down
31 changes: 20 additions & 11 deletions test/backends/test_ollama.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from mellea import start_session, SimpleContext
from mellea.stdlib.base import CBlock
from mellea.stdlib.requirement import Requirement
import pydantic
import json

import pydantic
import pytest
from typing_extensions import Annotated

from mellea import SimpleContext, start_session
from mellea.backends.types import ModelOption
import pytest
from mellea.stdlib.base import CBlock
from mellea.stdlib.requirement import Requirement


@pytest.fixture(scope="function")
Expand All @@ -15,6 +17,8 @@ def session():
yield session
session.reset()


@pytest.mark.qualitative
def test_simple_instruct(session):
result = session.instruct(
"Write an email to Hendrik trying to sell him self-sealing stembolts."
Expand All @@ -23,6 +27,8 @@ def test_simple_instruct(session):
assert "chat_response" in result._meta
assert result._meta["chat_response"].message.role == "assistant"


@pytest.mark.qualitative
def test_instruct_with_requirement(session):
response = session.instruct(
"Write an email to Hendrik convincing him to buy some self-sealing stembolts."
Expand All @@ -45,12 +51,14 @@ def test_instruct_with_requirement(session):
)
print(results)

@pytest.mark.qualitative
def test_chat(session):
output_message = session.chat("What is 1+1?")
assert (
"2" in output_message.content
), f"Expected a message with content containing 2 but found {output_message}"
assert "2" in output_message.content, (
f"Expected a message with content containing 2 but found {output_message}"
)

@pytest.mark.qualitative
def test_format(session):
class Person(pydantic.BaseModel):
name: str
Expand Down Expand Up @@ -83,6 +91,7 @@ class Email(pydantic.BaseModel):
# assert email.to.email_address.endswith("example.com")
pass

@pytest.mark.qualitative
def test_generate_from_raw(session):
prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]

Expand Down Expand Up @@ -113,9 +122,9 @@ class Answer(pydantic.BaseModel):
try:
answer = Answer.model_validate_json(random_result.value)
except pydantic.ValidationError as e:
assert (
False
), f"formatting directive failed for {random_result.value}: {e.json()}"
assert False, (
f"formatting directive failed for {random_result.value}: {e.json()}"
)


if __name__ == "__main__":
Expand Down
52 changes: 33 additions & 19 deletions test/backends/test_openai_ollama.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,54 @@
# test/rits_backend_tests/test_openai_integration.py
from mellea import MelleaSession
from mellea.stdlib.base import CBlock, LinearContext, ModelOutputThunk
from mellea.backends.openai import OpenAIBackend
from mellea.backends.formatter import TemplateFormatter
from mellea.backends.types import ModelOption
import os

import pydantic
from typing_extensions import Annotated
import pytest
from typing_extensions import Annotated

from mellea import MelleaSession
from mellea.backends.formatter import TemplateFormatter
from mellea.backends.model_ids import META_LLAMA_3_2_1B
from mellea.backends.openai import OpenAIBackend
from mellea.backends.types import ModelOption
from mellea.stdlib.base import CBlock, LinearContext, ModelOutputThunk


@pytest.fixture(scope="module")
def backend():
def backend(gh_run: int):
"""Shared OpenAI backend configured for Ollama."""
return OpenAIBackend(
model_id="granite3.3:8b",
formatter=TemplateFormatter(model_id="ibm-granite/granite-3.2-8b-instruct"),
base_url="http://localhost:11434/v1",
if gh_run == 1:
return OpenAIBackend(
model_id=META_LLAMA_3_2_1B,
formatter=TemplateFormatter(model_id=META_LLAMA_3_2_1B),
base_url=f"http://{os.environ.get('OLLAMA_HOST', 'localhost:11434')}/v1",
api_key="ollama",
)
else:
return OpenAIBackend(
model_id="granite3.3:8b",
formatter=TemplateFormatter(model_id="ibm-granite/granite-3.2-8b-instruct"),
base_url=f"http://{os.environ.get('OLLAMA_HOST', 'localhost:11434')}/v1",
api_key="ollama",
)


@pytest.fixture(scope="function")
def session(backend):
def m_session(backend):
"""Fresh OpenAI session for each test."""
session = MelleaSession(backend, ctx=LinearContext(is_chat_context=True))
yield session
session.reset()

def test_instruct(session):
result = session.instruct("Compute 1+1.")
@pytest.mark.qualitative
def test_instruct(m_session):
result = m_session.instruct("Compute 1+1.")
assert isinstance(result, ModelOutputThunk)
assert "2" in result.value # type: ignore

def test_multiturn(session):
session.instruct("What is the capital of France?")
answer = session.instruct("Tell me the answer to the previous question.")
@pytest.mark.qualitative
def test_multiturn(m_session):
m_session.instruct("What is the capital of France?")
answer = m_session.instruct("Tell me the answer to the previous question.")
assert "Paris" in answer.value # type: ignore

# def test_api_timeout_error(self):
Expand All @@ -53,7 +66,8 @@ def test_multiturn(session):
# assert "granite3.3:8b" in result.value
# self.m.reset()

def test_format(session):
@pytest.mark.qualitative
def test_format(m_session):
class Person(pydantic.BaseModel):
name: str
# it does not support regex patterns in json schema
Expand All @@ -68,7 +82,7 @@ class Email(pydantic.BaseModel):
subject: str
body: str

output = session.instruct(
output = m_session.instruct(
"Write a short email to Olivia, thanking her for organizing a sailing activity. Her email server is example.com. No more than two sentences. ",
format=Email,
model_options={ModelOption.MAX_NEW_TOKENS: 2**8},
Expand Down
Loading