1+ import pydantic
2+ import pytest
3+ from typing_extensions import Annotated
4+
15from mellea import MelleaSession
2- from mellea .stdlib .base import CBlock , LinearContext
3- from mellea .backends .huggingface import LocalHFBackend
46from mellea .backends .aloras .huggingface .granite_aloras import add_granite_aloras
7+ from mellea .backends .cache import SimpleLRUCache
8+ from mellea .backends .formatter import TemplateFormatter
9+ from mellea .backends .huggingface import LocalHFBackend
10+ from mellea .backends .types import ModelOption
11+ from mellea .stdlib .base import CBlock , LinearContext
512from mellea .stdlib .requirement import (
6- Requirement ,
713 ALoraRequirement ,
814 LLMaJRequirement ,
15+ Requirement ,
916 ValidationResult ,
1017)
11- from mellea .backends .formatter import TemplateFormatter
12- from mellea .backends .cache import SimpleLRUCache
13- from mellea .backends .types import ModelOption
14- import pydantic
15-
16- from typing_extensions import Annotated
17-
18- import pytest
1918
2019
2120@pytest .fixture (scope = "module" )
@@ -49,9 +48,13 @@ def test_system_prompt(session):
4948def test_constraint_alora (session , backend ):
5049 answer = session .instruct (
5150 "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa. Be concise and don't write code to answer the question." ,
52- model_options = {ModelOption .MAX_NEW_TOKENS : 300 }, # Until aloras get a bit better, try not to abruptly end generation.
51+ model_options = {
52+ ModelOption .MAX_NEW_TOKENS : 300
53+ }, # Until aloras get a bit better, try not to abruptly end generation.
5354 )
54- alora_output = backend .get_aloras ()[0 ].generate_using_strings (
55+ alora_output = backend .get_aloras ()[
56+ 0
57+ ].generate_using_strings (
5558 input = "Find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ,
5659 response = str (answer ),
5760 constraint = "The answer mention that there is a b in the middle of one of the strings but not the other." ,
@@ -68,12 +71,12 @@ def test_constraint_lora_with_requirement(session, backend):
6871 assert session .backend ._use_caches
6972 assert backend ._cache .current_size () != 0
7073 validation_outputs = session .validate (
71- "The answer should mention that there is a b in the middle of one of the strings but not the other." ,
72- return_full_validation_results = True ,
74+ "The answer should mention that there is a b in the middle of one of the strings but not the other."
7375 )
7476 assert len (validation_outputs ) == 1
75- alora_output , valuation_boolean = validation_outputs [0 ]
76- assert str (alora_output ) in ["Y" , "N" ]
77+ val_result = validation_outputs [0 ]
78+ assert isinstance (val_result , ValidationResult )
79+ assert str (val_result .reason ) in ["Y" , "N" ]
7780
7881
7982def test_constraint_lora_override (session , backend ):
@@ -82,12 +85,14 @@ def test_constraint_lora_override(session, backend):
8285 "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
8386 )
8487 validation_outputs = session .validate (
85- "The answer should mention that there is a b in the middle of one of the strings but not the other." ,
86- return_full_validation_results = True ,
88+ "The answer should mention that there is a b in the middle of one of the strings but not the other."
8789 )
8890 assert len (validation_outputs ) == 1
89- non_alora_output , _ = validation_outputs [0 ]
90- assert str (non_alora_output ) not in ["Y" , "N" ]
91+ val_result = validation_outputs [0 ]
92+ assert isinstance (val_result , ValidationResult )
93+ assert (
94+ str (val_result .reason ) in ["Y" , "N" , "Yes." , "No." ]
95+ ) # Checking for any type of result that LLM may output. But might need to be more robust.
9196 backend .default_to_constraint_checking_alora = True
9297
9398
@@ -99,12 +104,12 @@ def test_constraint_lora_override_does_not_override_alora(session, backend):
99104 validation_outputs = session .validate (
100105 ALoraRequirement (
101106 "The answer should mention that there is a b in the middle of one of the strings but not the other."
102- ),
103- return_full_validation_results = True ,
107+ )
104108 )
105109 assert len (validation_outputs ) == 1
106- non_alora_output , _ = validation_outputs [0 ]
107- assert str (non_alora_output ) in ["Y" , "N" ]
110+ val_result = validation_outputs [0 ]
111+ assert isinstance (val_result , ValidationResult )
112+ assert str (val_result .reason ) in ["Y" , "N" ]
108113 backend .default_to_constraint_checking_alora = True
109114
110115
@@ -116,12 +121,12 @@ def test_llmaj_req_does_not_use_alora(session, backend):
116121 validation_outputs = session .validate (
117122 LLMaJRequirement (
118123 "The answer should mention that there is a b in the middle of one of the strings but not the other."
119- ),
120- return_full_validation_results = True ,
124+ )
121125 )
122126 assert len (validation_outputs ) == 1
123- non_alora_output , _ = validation_outputs [0 ]
124- assert str (non_alora_output ) not in ["Y" , "N" ]
127+ val_result = validation_outputs [0 ]
128+ assert isinstance (val_result , ValidationResult )
129+ assert str (val_result .reason ) not in ["Y" , "N" ]
125130
126131
127132def test_instruct (session ):
@@ -135,18 +140,15 @@ def test_multiturn(session):
135140 "Take the result of the previous sum and find the corresponding letter in the greek alphabet."
136141 )
137142 assert "β" in str (beta ).lower ()
138- words = session .instruct (
139- "Now list five English words that start with that letter."
140- )
143+ words = session .instruct ("Now list five English words that start with that letter." )
141144 print (words )
142145
143146
144147def test_format (session ):
145148 class Person (pydantic .BaseModel ):
146149 name : str
147150 email_address : Annotated [
148- str ,
149- pydantic .StringConstraints (pattern = r"[a-zA-Z]{5,10}@example\.com" ),
151+ str , pydantic .StringConstraints (pattern = r"[a-zA-Z]{5,10}@example\.com" )
150152 ]
151153
152154 class Email (pydantic .BaseModel ):
@@ -166,12 +168,10 @@ class Email(pydantic.BaseModel):
166168 print (email )
167169
168170 print ("address:" , email .to .email_address )
169- assert (
170- "@" in email .to .email_address
171- ), "The @ sign should be in the meail address."
172- assert email .to .email_address .endswith (
173- "example.com"
174- ), "The email address should be at example.com"
171+ assert "@" in email .to .email_address , "The @ sign should be in the meail address."
172+ assert email .to .email_address .endswith ("example.com" ), (
173+ "The email address should be at example.com"
174+ )
175175
176176
177177def test_generate_from_raw (session ):
@@ -203,12 +203,12 @@ class Answer(pydantic.BaseModel):
203203 try :
204204 answer = Answer .model_validate_json (random_result .value )
205205 except pydantic .ValidationError as e :
206- assert (
207- False
208- ), f"formatting directive failed for { random_result . value } : { e . json () } "
206+ assert False , (
207+ f"formatting directive failed for { random_result . value } : { e . json () } "
208+ )
209209
210210
211211if __name__ == "__main__" :
212212 import pytest
213213
214- pytest .main ([__file__ ])
214+ pytest .main ([__file__ ])
0 commit comments