
Commit 68cf8c6

adding modified tests with optional LLM runs
1 parent 9498f0f commit 68cf8c6

File tree

10 files changed (+142 -94 lines)

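The diffs below tag every backend test that needs a live model with a new @pytest.mark.llm marker, which is what makes the LLM runs optional. The marker registration itself is not visible in this commit; the following is a minimal sketch of the registration and local deselection this implies (hook placement and wording are assumptions, not taken from the commit):

# Sketch only -- how the "llm" marker is typically registered (not shown in this commit).
def pytest_configure(config):
    # Declare the custom marker so pytest does not warn about unknown marks.
    config.addinivalue_line(
        "markers", "llm: test requires a live LLM backend (Ollama, HuggingFace, Watsonx, OpenAI)"
    )


# Local runs can then deselect LLM-dependent tests with: pytest -m "not llm"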

test/backends/test_huggingface.py

Lines changed: 12 additions & 11 deletions
@@ -21,6 +21,7 @@
 @pytest.fixture(scope="module")
 def backend():
     """Shared HuggingFace backend for all tests in this module."""
+    # TODO: find a smaller 1B model to do Alora stuff on github actions.
     backend = LocalHFBackend(
         model_id="ibm-granite/granite-3.2-8b-instruct",
         formatter=TemplateFormatter(model_id="ibm-granite/granite-4.0-tiny-preview"),
@@ -37,15 +38,15 @@ def session(backend):
     yield session
     session.reset()

-
+@pytest.mark.llm
 def test_system_prompt(session):
     result = session.chat(
         "Where are we going?",
         model_options={ModelOption.SYSTEM_PROMPT: "Talk like a pirate."},
     )
     print(result)

-
+@pytest.mark.llm
 def test_constraint_alora(session, backend):
     answer = session.instruct(
         "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa. Be concise and don't write code to answer the question.",
@@ -63,7 +64,7 @@ def test_constraint_alora(session, backend):
     )
     assert alora_output in ["Y", "N"], alora_output

-
+@pytest.mark.llm
 def test_constraint_lora_with_requirement(session, backend):
     answer = session.instruct(
         "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
@@ -79,7 +80,7 @@ def test_constraint_lora_with_requirement(session, backend):
     assert isinstance(val_result, ValidationResult)
     assert str(val_result.reason) in ["Y", "N"]

-
+@pytest.mark.llm
 def test_constraint_lora_override(session, backend):
     backend.default_to_constraint_checking_alora = False  # type: ignore
     answer = session.instruct(
@@ -94,7 +95,7 @@ def test_constraint_lora_override(session, backend):
     assert isinstance(default_output_to_bool(str(val_result.reason)), bool)
     backend.default_to_constraint_checking_alora = True

-
+@pytest.mark.llm
 def test_constraint_lora_override_does_not_override_alora(session, backend):
     backend.default_to_constraint_checking_alora = False  # type: ignore
     answer = session.instruct(
@@ -111,7 +112,7 @@ def test_constraint_lora_override_does_not_override_alora(session, backend):
     assert str(val_result.reason) in ["Y", "N"]
     backend.default_to_constraint_checking_alora = True

-
+@pytest.mark.llm
 def test_llmaj_req_does_not_use_alora(session, backend):
     backend.default_to_constraint_checking_alora = True  # type: ignore
     answer = session.instruct(
@@ -127,12 +128,12 @@ def test_llmaj_req_does_not_use_alora(session, backend):
     assert isinstance(val_result, ValidationResult)
     assert str(val_result.reason) not in ["Y", "N"]

-
+@pytest.mark.llm
 def test_instruct(session):
     result = session.instruct("Compute 1+1.")
     print(result)

-
+@pytest.mark.llm
 def test_multiturn(session):
     session.instruct("Compute 1+1")
     beta = session.instruct(
@@ -142,7 +143,7 @@ def test_multiturn(session):
     words = session.instruct("Now list five English words that start with that letter.")
     print(words)

-
+@pytest.mark.llm
 def test_format(session):
     class Person(pydantic.BaseModel):
         name: str
@@ -172,7 +173,7 @@ class Email(pydantic.BaseModel):
         "The email address should be at example.com"
     )

-
+@pytest.mark.llm
 def test_generate_from_raw(session):
     prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]

@@ -182,7 +183,7 @@ def test_generate_from_raw(session):

     assert len(results) == len(prompts)

-
+@pytest.mark.llm
 def test_generate_from_raw_with_format(session):
     prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]


test/backends/test_ollama.py

Lines changed: 20 additions & 11 deletions
@@ -1,11 +1,13 @@
-from mellea import start_session, SimpleContext
-from mellea.stdlib.base import CBlock
-from mellea.stdlib.requirement import Requirement
-import pydantic
 import json
+
+import pydantic
+import pytest
 from typing_extensions import Annotated
+
+from mellea import SimpleContext, start_session
 from mellea.backends.types import ModelOption
-import pytest
+from mellea.stdlib.base import CBlock
+from mellea.stdlib.requirement import Requirement


 @pytest.fixture(scope="function")
@@ -15,6 +17,8 @@ def session():
     yield session
     session.reset()

+
+@pytest.mark.llm
 def test_simple_instruct(session):
     result = session.instruct(
         "Write an email to Hendrik trying to sell him self-sealing stembolts."
@@ -23,6 +27,8 @@ def test_simple_instruct(session):
     assert "chat_response" in result._meta
     assert result._meta["chat_response"].message.role == "assistant"

+
+@pytest.mark.llm
 def test_instruct_with_requirement(session):
     response = session.instruct(
         "Write an email to Hendrik convincing him to buy some self-sealing stembolts."
@@ -45,12 +51,14 @@ def test_instruct_with_requirement(session):
     )
     print(results)

+@pytest.mark.llm
 def test_chat(session):
     output_message = session.chat("What is 1+1?")
-    assert (
-        "2" in output_message.content
-    ), f"Expected a message with content containing 2 but found {output_message}"
+    assert "2" in output_message.content, (
+        f"Expected a message with content containing 2 but found {output_message}"
+    )

+@pytest.mark.llm
 def test_format(session):
     class Person(pydantic.BaseModel):
         name: str
@@ -83,6 +91,7 @@ class Email(pydantic.BaseModel):
     # assert email.to.email_address.endswith("example.com")
     pass

+@pytest.mark.llm
 def test_generate_from_raw(session):
     prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]

@@ -113,9 +122,9 @@ class Answer(pydantic.BaseModel):
     try:
         answer = Answer.model_validate_json(random_result.value)
     except pydantic.ValidationError as e:
-        assert (
-            False
-        ), f"formatting directive failed for {random_result.value}: {e.json()}"
+        assert False, (
+            f"formatting directive failed for {random_result.value}: {e.json()}"
+        )


 if __name__ == "__main__":

test/backends/test_openai_ollama.py

Lines changed: 30 additions & 18 deletions
@@ -1,41 +1,52 @@
 # test/rits_backend_tests/test_openai_integration.py
+import pydantic
+import pytest
+from typing_extensions import Annotated
+
 from mellea import MelleaSession
-from mellea.stdlib.base import CBlock, LinearContext, ModelOutputThunk
-from mellea.backends.openai import OpenAIBackend
 from mellea.backends.formatter import TemplateFormatter
+from mellea.backends.model_ids import META_LLAMA_3_2_1B
+from mellea.backends.openai import OpenAIBackend
 from mellea.backends.types import ModelOption
-
-import pydantic
-from typing_extensions import Annotated
-import pytest
+from mellea.stdlib.base import CBlock, LinearContext, ModelOutputThunk


 @pytest.fixture(scope="module")
-def backend():
+def backend(gh_run: int):
     """Shared OpenAI backend configured for Ollama."""
-    return OpenAIBackend(
-        model_id="granite3.3:8b",
-        formatter=TemplateFormatter(model_id="ibm-granite/granite-3.2-8b-instruct"),
+    if gh_run == 1:
+        return OpenAIBackend(
+            model_id=META_LLAMA_3_2_1B,
+            formatter=TemplateFormatter(model_id=META_LLAMA_3_2_1B),
            base_url="http://localhost:11434/v1",
            api_key="ollama",
        )
+    else:
+        return OpenAIBackend(
+            model_id="granite3.3:8b",
+            formatter=TemplateFormatter(model_id="ibm-granite/granite-3.2-8b-instruct"),
+            base_url="http://localhost:11434/v1",
+            api_key="ollama",
+        )


 @pytest.fixture(scope="function")
-def session(backend):
+def m_session(backend):
     """Fresh OpenAI session for each test."""
     session = MelleaSession(backend, ctx=LinearContext(is_chat_context=True))
     yield session
     session.reset()

-def test_instruct(session):
-    result = session.instruct("Compute 1+1.")
+@pytest.mark.llm
+def test_instruct(m_session):
+    result = m_session.instruct("Compute 1+1.")
     assert isinstance(result, ModelOutputThunk)
     assert "2" in result.value  # type: ignore

-def test_multiturn(session):
-    session.instruct("What is the capital of France?")
-    answer = session.instruct("Tell me the answer to the previous question.")
+@pytest.mark.llm
+def test_multiturn(m_session):
+    m_session.instruct("What is the capital of France?")
+    answer = m_session.instruct("Tell me the answer to the previous question.")
     assert "Paris" in answer.value  # type: ignore

 # def test_api_timeout_error(self):
@@ -53,7 +64,8 @@ def test_multiturn(session):
 #     assert "granite3.3:8b" in result.value
 #     self.m.reset()

-def test_format(session):
+@pytest.mark.llm
+def test_format(m_session):
     class Person(pydantic.BaseModel):
         name: str
         # it does not support regex patterns in json schema
@@ -68,7 +80,7 @@ class Email(pydantic.BaseModel):
         subject: str
         body: str

-    output = session.instruct(
+    output = m_session.instruct(
         "Write a short email to Olivia, thanking her for organizing a sailing activity. Her email server is example.com. No more than two sentences. ",
         format=Email,
         model_options={ModelOption.MAX_NEW_TOKENS: 2**8},

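The reworked backend fixture above now takes a gh_run argument, but the fixture providing that value is not visible in this commit. Below is a minimal sketch of what such a fixture could look like, assuming it mirrors the GITHUB_ACTION check used in test/conftest.py; the name, scope, and placement are assumptions, not taken from the diff.

# Hypothetical gh_run fixture (not shown in this commit).
import os

import pytest


@pytest.fixture(scope="session")
def gh_run() -> int:
    # Returns 1 when GITHUB_ACTION is set to 1 (GitHub workflow runs), else 0,
    # mirroring the check in test/conftest.py.
    return int(os.environ.get("GITHUB_ACTION", 0))
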
test/backends/test_watsonx.py

Lines changed: 14 additions & 8 deletions
@@ -1,14 +1,15 @@
 # test/rits_backend_tests/test_watsonx_integration.py
 import os
-from mellea import MelleaSession
-from mellea.stdlib.base import CBlock, LinearContext, ModelOutputThunk
-from mellea.backends.watsonx import WatsonxAIBackend
-from mellea.backends.formatter import TemplateFormatter
-from mellea.backends.types import ModelOption

 import pydantic
-from typing_extensions import Annotated
 import pytest
+from typing_extensions import Annotated
+
+from mellea import MelleaSession
+from mellea.backends.formatter import TemplateFormatter
+from mellea.backends.types import ModelOption
+from mellea.backends.watsonx import WatsonxAIBackend
+from mellea.stdlib.base import CBlock, LinearContext, ModelOutputThunk


 @pytest.fixture(scope="module")
@@ -28,18 +29,21 @@ def session(backend):
     session.reset()


-
-
+@pytest.mark.llm
 def test_instruct(session):
     result = session.instruct("Compute 1+1.")
     assert isinstance(result, ModelOutputThunk)
     assert "2" in result.value  # type: ignore

+
+@pytest.mark.llm
 def test_multiturn(session):
     session.instruct("What is the capital of France?")
     answer = session.instruct("Tell me the answer to the previous question.")
     assert "Paris" in answer.value  # type: ignore

+
+@pytest.mark.llm
 def test_format(session):
     class Person(pydantic.BaseModel):
         name: str
@@ -72,6 +76,8 @@ class Email(pydantic.BaseModel):
     # assert email.to.email_address.endswith("example.com")
     pass

+
+@pytest.mark.llm
 def test_generate_from_raw(session):
     prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]


test/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ def pytest_runtest_setup(item):
     gh_run = int(os.environ.get("GITHUB_ACTION", 0))

     if gh_run == 1:
-        pytest.skip(
+        pytest.xfail(
            reason="Skipping LLM test: got env variable GITHUB_ACTION == 1. Used only in gh workflows."
        )

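The conftest hunk above shows only the tail of pytest_runtest_setup: the CI guard now calls pytest.xfail instead of pytest.skip, so these tests surface as expected failures instead of silently disappearing from the report. A minimal sketch of how the full hook plausibly fits together follows; the llm-marker guard is an assumption, since it sits outside the displayed context.

# Sketch of the surrounding hook (only the skip -> xfail change is shown above).
import os

import pytest


def pytest_runtest_setup(item):
    # Assumed guard: only tests carrying the new "llm" marker are affected.
    if "llm" not in item.keywords:
        return

    gh_run = int(os.environ.get("GITHUB_ACTION", 0))

    if gh_run == 1:
        pytest.xfail(
            reason="Skipping LLM test: got env variable GITHUB_ACTION == 1. Used only in gh workflows."
        )
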
test/stdlib_basics/test_base.py

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,4 @@
-from mellea.stdlib.base import Component, CBlock
-from mellea.stdlib.base import LinearContext
+from mellea.stdlib.base import CBlock, Component, LinearContext


 def test_cblock():

test/stdlib_basics/test_chat_view.py

Lines changed: 3 additions & 2 deletions
@@ -1,7 +1,8 @@

 import pytest
-from mellea.stdlib.base import ModelOutputThunk, LinearContext
-from mellea.stdlib.chat import as_chat_history, Message
+
+from mellea.stdlib.base import LinearContext, ModelOutputThunk
+from mellea.stdlib.chat import Message, as_chat_history
 from mellea.stdlib.session import start_session
