Skip to content

Commit 586e08c

Browse files
committed
rewriting some huggingface tests to comply with new API
1 parent 03e8b01 commit 586e08c

File tree

1 file changed

+44
-44
lines changed

1 file changed

+44
-44
lines changed

test/backends/test_huggingface.py

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
1+
import pydantic
2+
import pytest
3+
from typing_extensions import Annotated
4+
15
from mellea import MelleaSession
2-
from mellea.stdlib.base import CBlock, LinearContext
3-
from mellea.backends.huggingface import LocalHFBackend
46
from mellea.backends.aloras.huggingface.granite_aloras import add_granite_aloras
7+
from mellea.backends.cache import SimpleLRUCache
8+
from mellea.backends.formatter import TemplateFormatter
9+
from mellea.backends.huggingface import LocalHFBackend
10+
from mellea.backends.types import ModelOption
11+
from mellea.stdlib.base import CBlock, LinearContext
512
from mellea.stdlib.requirement import (
6-
Requirement,
713
ALoraRequirement,
814
LLMaJRequirement,
15+
Requirement,
916
ValidationResult,
1017
)
11-
from mellea.backends.formatter import TemplateFormatter
12-
from mellea.backends.cache import SimpleLRUCache
13-
from mellea.backends.types import ModelOption
14-
import pydantic
15-
16-
from typing_extensions import Annotated
17-
18-
import pytest
1918

2019

2120
@pytest.fixture(scope="module")
@@ -49,9 +48,13 @@ def test_system_prompt(session):
4948
def test_constraint_alora(session, backend):
5049
answer = session.instruct(
5150
"Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa. Be concise and don't write code to answer the question.",
52-
model_options={ModelOption.MAX_NEW_TOKENS: 300}, # Until aloras get a bit better, try not to abruptly end generation.
51+
model_options={
52+
ModelOption.MAX_NEW_TOKENS: 300
53+
}, # Until aloras get a bit better, try not to abruptly end generation.
5354
)
54-
alora_output = backend.get_aloras()[0].generate_using_strings(
55+
alora_output = backend.get_aloras()[
56+
0
57+
].generate_using_strings(
5558
input="Find the difference between these two strings: aaaaaaaaaa aaaaabaaaa",
5659
response=str(answer),
5760
constraint="The answer mention that there is a b in the middle of one of the strings but not the other.",
@@ -68,12 +71,12 @@ def test_constraint_lora_with_requirement(session, backend):
6871
assert session.backend._use_caches
6972
assert backend._cache.current_size() != 0
7073
validation_outputs = session.validate(
71-
"The answer should mention that there is a b in the middle of one of the strings but not the other.",
72-
return_full_validation_results=True,
74+
"The answer should mention that there is a b in the middle of one of the strings but not the other."
7375
)
7476
assert len(validation_outputs) == 1
75-
alora_output, valuation_boolean = validation_outputs[0]
76-
assert str(alora_output) in ["Y", "N"]
77+
val_result = validation_outputs[0]
78+
assert isinstance(val_result, ValidationResult)
79+
assert str(val_result.reason) in ["Y", "N"]
7780

7881

7982
def test_constraint_lora_override(session, backend):
@@ -82,12 +85,14 @@ def test_constraint_lora_override(session, backend):
8285
"Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
8386
)
8487
validation_outputs = session.validate(
85-
"The answer should mention that there is a b in the middle of one of the strings but not the other.",
86-
return_full_validation_results=True,
88+
"The answer should mention that there is a b in the middle of one of the strings but not the other."
8789
)
8890
assert len(validation_outputs) == 1
89-
non_alora_output, _ = validation_outputs[0]
90-
assert str(non_alora_output) not in ["Y", "N"]
91+
val_result = validation_outputs[0]
92+
assert isinstance(val_result, ValidationResult)
93+
assert (
94+
str(val_result.reason) in ["Y", "N", "Yes.", "No."]
95+
) # Checking for any type of result that LLM may output. But might need to be more robust.
9196
backend.default_to_constraint_checking_alora = True
9297

9398

@@ -99,12 +104,12 @@ def test_constraint_lora_override_does_not_override_alora(session, backend):
99104
validation_outputs = session.validate(
100105
ALoraRequirement(
101106
"The answer should mention that there is a b in the middle of one of the strings but not the other."
102-
),
103-
return_full_validation_results=True,
107+
)
104108
)
105109
assert len(validation_outputs) == 1
106-
non_alora_output, _ = validation_outputs[0]
107-
assert str(non_alora_output) in ["Y", "N"]
110+
val_result = validation_outputs[0]
111+
assert isinstance(val_result, ValidationResult)
112+
assert str(val_result.reason) in ["Y", "N"]
108113
backend.default_to_constraint_checking_alora = True
109114

110115

@@ -116,12 +121,12 @@ def test_llmaj_req_does_not_use_alora(session, backend):
116121
validation_outputs = session.validate(
117122
LLMaJRequirement(
118123
"The answer should mention that there is a b in the middle of one of the strings but not the other."
119-
),
120-
return_full_validation_results=True,
124+
)
121125
)
122126
assert len(validation_outputs) == 1
123-
non_alora_output, _ = validation_outputs[0]
124-
assert str(non_alora_output) not in ["Y", "N"]
127+
val_result = validation_outputs[0]
128+
assert isinstance(val_result, ValidationResult)
129+
assert str(val_result.reason) not in ["Y", "N"]
125130

126131

127132
def test_instruct(session):
@@ -135,18 +140,15 @@ def test_multiturn(session):
135140
"Take the result of the previous sum and find the corresponding letter in the greek alphabet."
136141
)
137142
assert "β" in str(beta).lower()
138-
words = session.instruct(
139-
"Now list five English words that start with that letter."
140-
)
143+
words = session.instruct("Now list five English words that start with that letter.")
141144
print(words)
142145

143146

144147
def test_format(session):
145148
class Person(pydantic.BaseModel):
146149
name: str
147150
email_address: Annotated[
148-
str,
149-
pydantic.StringConstraints(pattern=r"[a-zA-Z]{5,10}@example\.com"),
151+
str, pydantic.StringConstraints(pattern=r"[a-zA-Z]{5,10}@example\.com")
150152
]
151153

152154
class Email(pydantic.BaseModel):
@@ -166,12 +168,10 @@ class Email(pydantic.BaseModel):
166168
print(email)
167169

168170
print("address:", email.to.email_address)
169-
assert (
170-
"@" in email.to.email_address
171-
), "The @ sign should be in the meail address."
172-
assert email.to.email_address.endswith(
173-
"example.com"
174-
), "The email address should be at example.com"
171+
assert "@" in email.to.email_address, "The @ sign should be in the meail address."
172+
assert email.to.email_address.endswith("example.com"), (
173+
"The email address should be at example.com"
174+
)
175175

176176

177177
def test_generate_from_raw(session):
@@ -203,12 +203,12 @@ class Answer(pydantic.BaseModel):
203203
try:
204204
answer = Answer.model_validate_json(random_result.value)
205205
except pydantic.ValidationError as e:
206-
assert (
207-
False
208-
), f"formatting directive failed for {random_result.value}: {e.json()}"
206+
assert False, (
207+
f"formatting directive failed for {random_result.value}: {e.json()}"
208+
)
209209

210210

211211
if __name__ == "__main__":
212212
import pytest
213213

214-
pytest.main([__file__])
214+
pytest.main([__file__])

0 commit comments

Comments
 (0)