
Commit 6a406e6

Improve tests isolation
Signed-off-by: elronbandel <[email protected]>
1 parent: bde69a9 · commit: 6a406e6

File tree

6 files changed: +432 −387 lines changed


test/backends/test_huggingface.py

Lines changed: 175 additions & 168 deletions
@@ -2,7 +2,7 @@
 from mellea.stdlib.base import CBlock, LinearContext
 from mellea.backends.huggingface import LocalHFBackend
 from mellea.backends.aloras.huggingface.granite_aloras import add_granite_aloras
-from mellea.stdlib.requirement import Requirement, ALoraRequirement, LLMaJRequirement
+from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement
 from mellea.backends.formatter import TemplateFormatter
 from mellea.backends.cache import SimpleLRUCache
 from mellea.backends.types import ModelOption
@@ -13,190 +13,197 @@
 import pytest


-class TestHFALoraStuff:
+@pytest.fixture(scope="module")
+def backend():
+    """Shared HuggingFace backend for all tests in this module."""
     backend = LocalHFBackend(
         model_id="ibm-granite/granite-3.2-8b-instruct",
         formatter=TemplateFormatter(model_id="ibm-granite/granite-4.0-tiny-preview"),
         cache=SimpleLRUCache(5),
     )
-    m = MelleaSession(backend, ctx=LinearContext())
     add_granite_aloras(backend)
+    return backend

-    def test_system_prompt(self):
-        self.m.reset()
-        result = self.m.chat(
-            "Where are we going?",
-            model_options={ModelOption.SYSTEM_PROMPT: "Talk like a pirate."},
-        )
-        print(result)
-
-    def test_constraint_alora(self):
-        self.m.reset()
-        answer = self.m.instruct(
-            "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa. Be concise and don't write code to answer the question.",
-            model_options={ModelOption.MAX_NEW_TOKENS: 300}, # Until aloras get a bit better, try not to abruptly end generation.
-        )
-        alora_output = self.backend.get_aloras()[0].generate_using_strings(
-            input="Find the difference between these two strings: aaaaaaaaaa aaaaabaaaa",
-            response=str(answer),
-            constraint="The answer mention that there is a b in the middle of one of the strings but not the other.",
-            force_yn=False, # make sure that the alora naturally output Y and N without constrained generation
-        )
-        assert alora_output in ["Y", "N"], alora_output
-        self.m.reset()
-
-    def test_constraint_lora_with_requirement(self):
-        self.m.reset()
-        answer = self.m.instruct(
-            "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
-        )
-        assert self.m.backend._cache is not None # type: ignore
-        assert self.m.backend._use_caches
-        assert self.backend._cache.current_size() != 0
-        validation_outputs = self.m.validate(
-            "The answer should mention that there is a b in the middle of one of the strings but not the other.",
-            return_full_validation_results=True,
-        )
-        assert len(validation_outputs) == 1
-        alora_output, valuation_boolean = validation_outputs[0]
-        assert str(alora_output) in ["Y", "N"]
-        self.m.reset()
-
-    def test_constraint_lora_override(self):
-        self.m.reset()
-        self.backend.default_to_constraint_checking_alora = False # type: ignore
-        answer = self.m.instruct(
-            "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
-        )
-        validation_outputs = self.m.validate(
-            "The answer should mention that there is a b in the middle of one of the strings but not the other.",
-            return_full_validation_results=True,
-        )
-        assert len(validation_outputs) == 1
-        non_alora_output, _ = validation_outputs[0]
-        assert str(non_alora_output) not in ["Y", "N"]
-        self.backend.default_to_constraint_checking_alora = True
-        self.m.reset()
-
-    def test_constraint_lora_override_does_not_override_alora(self):
-        self.m.reset()
-        self.backend.default_to_constraint_checking_alora = False # type: ignore
-        answer = self.m.instruct(
-            "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
-        )
-        validation_outputs = self.m.validate(
-            ALoraRequirement(
-                "The answer should mention that there is a b in the middle of one of the strings but not the other."
-            ),
-            return_full_validation_results=True,
-        )
-        assert len(validation_outputs) == 1
-        non_alora_output, _ = validation_outputs[0]
-        assert str(non_alora_output) in ["Y", "N"]
-        self.backend.default_to_constraint_checking_alora = True
-        self.m.reset()
-
-    def test_llmaj_req_does_not_use_alora(self):
-        self.m.reset()
-        self.backend.default_to_constraint_checking_alora = True # type: ignore
-        answer = self.m.instruct(
-            "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
-        )
-        validation_outputs = self.m.validate(
-            LLMaJRequirement(
-                "The answer should mention that there is a b in the middle of one of the strings but not the other."
-            ),
-            return_full_validation_results=True,
-        )
-        assert len(validation_outputs) == 1
-        non_alora_output, _ = validation_outputs[0]
-        assert str(non_alora_output) not in ["Y", "N"]
-        self.m.reset()
-
-    def test_instruct(self):
-        self.m.reset()
-        result = self.m.instruct("Compute 1+1.")
-        print(result)
-        self.m.reset()
-
-    def test_multiturn(self):
-        self.m.instruct("Compute 1+1")
-        beta = self.m.instruct(
-            "Take the result of the previous sum and find the corresponding letter in the greek alphabet."
-        )
-        assert "β" in str(beta).lower()
-        words = self.m.instruct(
-            "Now list five English words that start with that letter."
-        )
-        print(words)
-        self.m.reset()
-
-    def test_format(self):
-        class Person(pydantic.BaseModel):
-            name: str
-            email_address: Annotated[
-                str,
-                pydantic.StringConstraints(pattern=r"[a-zA-Z]{5,10}@example\.com"),
-            ]
-
-        class Email(pydantic.BaseModel):
-            to: Person
-            subject: str
-            body: str
-
-        output = self.m.instruct(
-            "Write a short email to Olivia, thanking her for organizing a sailing activity. Her email server is example.com. No more than two sentences. ",
-            format=Email,
-            model_options={ModelOption.MAX_NEW_TOKENS: 2**8},
-        )
-        print("Formatted output:")
-        email = Email.model_validate_json(
-            output.value
-        ) # this should succeed because the output should be JSON because we passed in a format= argument...
-        print(email)
-
-        print("address:", email.to.email_address)
-        assert (
-            "@" in email.to.email_address
-        ), "The @ sign should be in the meail address."
-        assert email.to.email_address.endswith(
-            "example.com"
-        ), "The email address should be at example.com"

-    def test_generate_from_raw(self):
-        prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]
+@pytest.fixture(scope="function")
+def session(backend):
+    """Fresh HuggingFace session for each test."""
+    session = MelleaSession(backend, ctx=LinearContext())
+    yield session
+    session.reset()
+
+
+def test_system_prompt(session):
+    result = session.chat(
+        "Where are we going?",
+        model_options={ModelOption.SYSTEM_PROMPT: "Talk like a pirate."},
+    )
+    print(result)
+
+
+def test_constraint_alora(session, backend):
+    answer = session.instruct(
+        "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa. Be concise and don't write code to answer the question.",
+        model_options={ModelOption.MAX_NEW_TOKENS: 300}, # Until aloras get a bit better, try not to abruptly end generation.
+    )
+    alora_output = backend.get_aloras()[0].generate_using_strings(
+        input="Find the difference between these two strings: aaaaaaaaaa aaaaabaaaa",
+        response=str(answer),
+        constraint="The answer mention that there is a b in the middle of one of the strings but not the other.",
+        force_yn=False, # make sure that the alora naturally output Y and N without constrained generation
+    )
+    assert alora_output in ["Y", "N"], alora_output
+
+
+def test_constraint_lora_with_requirement(session, backend):
+    answer = session.instruct(
+        "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
+    )
+    assert session.backend._cache is not None # type: ignore
+    assert session.backend._use_caches
+    assert backend._cache.current_size() != 0
+    validation_outputs = session.validate(
+        "The answer should mention that there is a b in the middle of one of the strings but not the other.",
+        return_full_validation_results=True,
+    )
+    assert len(validation_outputs) == 1
+    alora_output, valuation_boolean = validation_outputs[0]
+    assert str(alora_output) in ["Y", "N"]
+
+
+def test_constraint_lora_override(session, backend):
+    backend.default_to_constraint_checking_alora = False # type: ignore
+    answer = session.instruct(
+        "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
+    )
+    validation_outputs = session.validate(
+        "The answer should mention that there is a b in the middle of one of the strings but not the other.",
+        return_full_validation_results=True,
+    )
+    assert len(validation_outputs) == 1
+    non_alora_output, _ = validation_outputs[0]
+    assert str(non_alora_output) not in ["Y", "N"]
+    backend.default_to_constraint_checking_alora = True
+
+
+def test_constraint_lora_override_does_not_override_alora(session, backend):
+    backend.default_to_constraint_checking_alora = False # type: ignore
+    answer = session.instruct(
+        "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
+    )
+    validation_outputs = session.validate(
+        ALoraRequirement(
+            "The answer should mention that there is a b in the middle of one of the strings but not the other."
+        ),
+        return_full_validation_results=True,
+    )
+    assert len(validation_outputs) == 1
+    non_alora_output, _ = validation_outputs[0]
+    assert str(non_alora_output) in ["Y", "N"]
+    backend.default_to_constraint_checking_alora = True

-        results = self.m.backend._generate_from_raw(
-            actions=[CBlock(value=prompt) for prompt in prompts], generate_logs=None
-        )

-        assert len(results) == len(prompts)
+def test_llmaj_req_does_not_use_alora(session, backend):
+    backend.default_to_constraint_checking_alora = True # type: ignore
+    answer = session.instruct(
+        "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
+    )
+    validation_outputs = session.validate(
+        LLMaJRequirement(
+            "The answer should mention that there is a b in the middle of one of the strings but not the other."
+        ),
+        return_full_validation_results=True,
+    )
+    assert len(validation_outputs) == 1
+    non_alora_output, _ = validation_outputs[0]
+    assert str(non_alora_output) not in ["Y", "N"]

-    def test_generate_from_raw_with_format(self):
-        prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]

-        class Answer(pydantic.BaseModel):
-            name: str
-            value: int
+def test_instruct(session):
+    result = session.instruct("Compute 1+1.")
+    print(result)

-        results = self.m.backend._generate_from_raw(
-            actions=[CBlock(value=prompt) for prompt in prompts],
-            format=Answer,
-            generate_logs=None,
-        )

-        assert len(results) == len(prompts)
+def test_multiturn(session):
+    session.instruct("Compute 1+1")
+    beta = session.instruct(
+        "Take the result of the previous sum and find the corresponding letter in the greek alphabet."
+    )
+    assert "β" in str(beta).lower()
+    words = session.instruct(
+        "Now list five English words that start with that letter."
+    )
+    print(words)
+
+
+def test_format(session):
+    class Person(pydantic.BaseModel):
+        name: str
+        email_address: Annotated[
+            str,
+            pydantic.StringConstraints(pattern=r"[a-zA-Z]{5,10}@example\.com"),
+        ]
+
+    class Email(pydantic.BaseModel):
+        to: Person
+        subject: str
+        body: str
+
+    output = session.instruct(
+        "Write a short email to Olivia, thanking her for organizing a sailing activity. Her email server is example.com. No more than two sentences. ",
+        format=Email,
+        model_options={ModelOption.MAX_NEW_TOKENS: 2**8},
+    )
+    print("Formatted output:")
+    email = Email.model_validate_json(
+        output.value
+    ) # this should succeed because the output should be JSON because we passed in a format= argument...
+    print(email)
+
+    print("address:", email.to.email_address)
+    assert (
+        "@" in email.to.email_address
+    ), "The @ sign should be in the meail address."
+    assert email.to.email_address.endswith(
+        "example.com"
+    ), "The email address should be at example.com"
+
+
+def test_generate_from_raw(session):
+    prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]
+
+    results = session.backend._generate_from_raw(
+        actions=[CBlock(value=prompt) for prompt in prompts], generate_logs=None
+    )

-        random_result = results[0]
-        try:
-            answer = Answer.model_validate_json(random_result.value)
-        except pydantic.ValidationError as e:
-            assert (
-                False
-            ), f"formatting directive failed for {random_result.value}: {e.json()}"
+    assert len(results) == len(prompts)
+
+
+def test_generate_from_raw_with_format(session):
+    prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]
+
+    class Answer(pydantic.BaseModel):
+        name: str
+        value: int
+
+    results = session.backend._generate_from_raw(
+        actions=[CBlock(value=prompt) for prompt in prompts],
+        format=Answer,
+        generate_logs=None,
+    )
+
+    assert len(results) == len(prompts)
+
+    random_result = results[0]
+    try:
+        answer = Answer.model_validate_json(random_result.value)
+    except pydantic.ValidationError as e:
+        assert (
+            False
+        ), f"formatting directive failed for {random_result.value}: {e.json()}"


 if __name__ == "__main__":
     import pytest

-    pytest.main([__file__])
+    pytest.main([__file__])
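The isolation change in this diff replaces the shared TestHFALoraStuff class (one MelleaSession reused across methods, each calling self.m.reset() by hand) with two pytest fixtures: a module-scoped backend built once per test module, and a function-scoped session created fresh for every test and reset at teardown. The following is a minimal sketch of that fixture-scoping pattern only; FakeBackend and FakeSession are hypothetical stand-ins, not the real LocalHFBackend and MelleaSession, so it runs without model weights.

# Sketch of the pytest fixture-scoping pattern used in this commit.
# FakeBackend/FakeSession are illustrative stand-ins for the Mellea classes.
import pytest


class FakeBackend:
    def __init__(self):
        self.loaded = True  # stands in for expensive, once-per-module setup


class FakeSession:
    def __init__(self, backend):
        self.backend = backend
        self.history = []

    def instruct(self, prompt):
        self.history.append(prompt)
        return f"response to: {prompt}"

    def reset(self):
        self.history.clear()


@pytest.fixture(scope="module")
def backend():
    # Built once and shared by every test in this module.
    return FakeBackend()


@pytest.fixture(scope="function")
def session(backend):
    # Rebuilt for every test; teardown clears any accumulated context.
    s = FakeSession(backend)
    yield s
    s.reset()


def test_one(session):
    session.instruct("Compute 1+1")
    assert len(session.history) == 1


def test_two(session):
    # A fresh session: state from test_one does not leak in.
    assert session.history == []

With function scope, pytest rebuilds session for each test so conversation state cannot leak between tests, while the expensive backend construction still happens only once per module.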

0 commit comments