22from mellea .stdlib .base import CBlock , LinearContext
33from mellea .backends .huggingface import LocalHFBackend
44from mellea .backends .aloras .huggingface .granite_aloras import add_granite_aloras
5- from mellea .stdlib .requirement import Requirement , ALoraRequirement , LLMaJRequirement
5+ from mellea .stdlib .requirement import ALoraRequirement , LLMaJRequirement
66from mellea .backends .formatter import TemplateFormatter
77from mellea .backends .cache import SimpleLRUCache
88from mellea .backends .types import ModelOption
1313import pytest
1414
1515
16- class TestHFALoraStuff :
@pytest.fixture(scope="module")
def backend():
    """Create the HuggingFace backend that every test in this module shares.

    The backend is built once per module and passed through
    ``add_granite_aloras`` before it is returned to the tests.
    """
    formatter = TemplateFormatter(model_id="ibm-granite/granite-4.0-tiny-preview")
    lru_cache = SimpleLRUCache(5)
    hf_backend = LocalHFBackend(
        model_id="ibm-granite/granite-3.2-8b-instruct",
        formatter=formatter,
        cache=lru_cache,
    )
    add_granite_aloras(hf_backend)
    return hf_backend
2426
25- def test_system_prompt (self ):
26- self .m .reset ()
27- result = self .m .chat (
28- "Where are we going?" ,
29- model_options = {ModelOption .SYSTEM_PROMPT : "Talk like a pirate." },
30- )
31- print (result )
32-
33- def test_constraint_alora (self ):
34- self .m .reset ()
35- answer = self .m .instruct (
36- "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa. Be concise and don't write code to answer the question." ,
37- model_options = {ModelOption .MAX_NEW_TOKENS : 300 }, # Until aloras get a bit better, try not to abruptly end generation.
38- )
39- alora_output = self .backend .get_aloras ()[0 ].generate_using_strings (
40- input = "Find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ,
41- response = str (answer ),
42- constraint = "The answer mention that there is a b in the middle of one of the strings but not the other." ,
43- force_yn = False , # make sure that the alora naturally output Y and N without constrained generation
44- )
45- assert alora_output in ["Y" , "N" ], alora_output
46- self .m .reset ()
47-
48- def test_constraint_lora_with_requirement (self ):
49- self .m .reset ()
50- answer = self .m .instruct (
51- "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
52- )
53- assert self .m .backend ._cache is not None # type: ignore
54- assert self .m .backend ._use_caches
55- assert self .backend ._cache .current_size () != 0
56- validation_outputs = self .m .validate (
57- "The answer should mention that there is a b in the middle of one of the strings but not the other." ,
58- return_full_validation_results = True ,
59- )
60- assert len (validation_outputs ) == 1
61- alora_output , valuation_boolean = validation_outputs [0 ]
62- assert str (alora_output ) in ["Y" , "N" ]
63- self .m .reset ()
64-
65- def test_constraint_lora_override (self ):
66- self .m .reset ()
67- self .backend .default_to_constraint_checking_alora = False # type: ignore
68- answer = self .m .instruct (
69- "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
70- )
71- validation_outputs = self .m .validate (
72- "The answer should mention that there is a b in the middle of one of the strings but not the other." ,
73- return_full_validation_results = True ,
74- )
75- assert len (validation_outputs ) == 1
76- non_alora_output , _ = validation_outputs [0 ]
77- assert str (non_alora_output ) not in ["Y" , "N" ]
78- self .backend .default_to_constraint_checking_alora = True
79- self .m .reset ()
80-
81- def test_constraint_lora_override_does_not_override_alora (self ):
82- self .m .reset ()
83- self .backend .default_to_constraint_checking_alora = False # type: ignore
84- answer = self .m .instruct (
85- "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
86- )
87- validation_outputs = self .m .validate (
88- ALoraRequirement (
89- "The answer should mention that there is a b in the middle of one of the strings but not the other."
90- ),
91- return_full_validation_results = True ,
92- )
93- assert len (validation_outputs ) == 1
94- non_alora_output , _ = validation_outputs [0 ]
95- assert str (non_alora_output ) in ["Y" , "N" ]
96- self .backend .default_to_constraint_checking_alora = True
97- self .m .reset ()
98-
99- def test_llmaj_req_does_not_use_alora (self ):
100- self .m .reset ()
101- self .backend .default_to_constraint_checking_alora = True # type: ignore
102- answer = self .m .instruct (
103- "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
104- )
105- validation_outputs = self .m .validate (
106- LLMaJRequirement (
107- "The answer should mention that there is a b in the middle of one of the strings but not the other."
108- ),
109- return_full_validation_results = True ,
110- )
111- assert len (validation_outputs ) == 1
112- non_alora_output , _ = validation_outputs [0 ]
113- assert str (non_alora_output ) not in ["Y" , "N" ]
114- self .m .reset ()
115-
116- def test_instruct (self ):
117- self .m .reset ()
118- result = self .m .instruct ("Compute 1+1." )
119- print (result )
120- self .m .reset ()
121-
122- def test_multiturn (self ):
123- self .m .instruct ("Compute 1+1" )
124- beta = self .m .instruct (
125- "Take the result of the previous sum and find the corresponding letter in the greek alphabet."
126- )
127- assert "β" in str (beta ).lower ()
128- words = self .m .instruct (
129- "Now list five English words that start with that letter."
130- )
131- print (words )
132- self .m .reset ()
133-
134- def test_format (self ):
135- class Person (pydantic .BaseModel ):
136- name : str
137- email_address : Annotated [
138- str ,
139- pydantic .StringConstraints (pattern = r"[a-zA-Z]{5,10}@example\.com" ),
140- ]
141-
142- class Email (pydantic .BaseModel ):
143- to : Person
144- subject : str
145- body : str
146-
147- output = self .m .instruct (
148- "Write a short email to Olivia, thanking her for organizing a sailing activity. Her email server is example.com. No more than two sentences. " ,
149- format = Email ,
150- model_options = {ModelOption .MAX_NEW_TOKENS : 2 ** 8 },
151- )
152- print ("Formatted output:" )
153- email = Email .model_validate_json (
154- output .value
155- ) # this should succeed because the output should be JSON because we passed in a format= argument...
156- print (email )
157-
158- print ("address:" , email .to .email_address )
159- assert (
160- "@" in email .to .email_address
161- ), "The @ sign should be in the meail address."
162- assert email .to .email_address .endswith (
163- "example.com"
164- ), "The email address should be at example.com"
16527
166- def test_generate_from_raw (self ):
167- prompts = ["what is 1+1?" , "what is 2+2?" , "what is 3+3?" , "what is 4+4?" ]
@pytest.fixture(scope="function")
def session(backend):
    """Provide a new session over the shared backend for a single test.

    The session starts with an empty ``LinearContext`` and is reset once the
    test that requested it has finished.
    """
    fresh_session = MelleaSession(backend, ctx=LinearContext())
    yield fresh_session
    fresh_session.reset()
34+
35+
def test_system_prompt(session):
    """Chat with a system prompt passed via model options and print the reply."""
    pirate_options = {ModelOption.SYSTEM_PROMPT: "Talk like a pirate."}
    reply = session.chat("Where are we going?", model_options=pirate_options)
    print(reply)
42+
43+
def test_constraint_alora(session, backend):
    """Check that the first aLoRA on the backend emits "Y" or "N" on its own.

    An answer is generated through the session, then handed to the aLoRA
    together with the constraint text, with forced Y/N decoding turned off.
    """
    # Until aloras get a bit better, try not to abruptly end generation.
    generation_options = {ModelOption.MAX_NEW_TOKENS: 300}
    answer = session.instruct(
        "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa. Be concise and don't write code to answer the question.",
        model_options=generation_options,
    )

    constraint_alora = backend.get_aloras()[0]
    verdict = constraint_alora.generate_using_strings(
        input="Find the difference between these two strings: aaaaaaaaaa aaaaabaaaa",
        response=str(answer),
        constraint="The answer mention that there is a b in the middle of one of the strings but not the other.",
        # Make sure that the alora naturally outputs Y and N without constrained generation.
        force_yn=False,
    )
    assert verdict in ["Y", "N"], verdict
56+
57+
def test_constraint_lora_with_requirement(session, backend):
    """Validate a plain-string requirement and expect a "Y"/"N" verdict.

    Also checks that, after one instruct call, the backend has a cache,
    has caching enabled, and that the cache is no longer empty.
    """
    session.instruct(
        "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
    )

    # Cache state is inspected right after generation, before validation runs.
    assert session.backend._cache is not None  # type: ignore
    assert session.backend._use_caches
    assert backend._cache.current_size() != 0

    requirement_text = "The answer should mention that there is a b in the middle of one of the strings but not the other."
    validation_outputs = session.validate(
        requirement_text,
        return_full_validation_results=True,
    )
    assert len(validation_outputs) == 1

    alora_output, _ = validation_outputs[0]
    assert str(alora_output) in ("Y", "N")
72+
73+
def test_constraint_lora_override(session, backend):
    """Check validation output when aLoRA constraint checking is switched off.

    With ``default_to_constraint_checking_alora`` set to False on the backend,
    validating a plain-string requirement must not yield a bare "Y" or "N".
    """
    backend.default_to_constraint_checking_alora = False  # type: ignore
    try:
        session.instruct(
            "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
        )
        validation_outputs = session.validate(
            "The answer should mention that there is a b in the middle of one of the strings but not the other.",
            return_full_validation_results=True,
        )
        assert len(validation_outputs) == 1
        non_alora_output, _ = validation_outputs[0]
        assert str(non_alora_output) not in ["Y", "N"]
    finally:
        # The flag lives on the backend object passed in, which other tests
        # may reuse; restore it even when an assertion above fails so the
        # override cannot leak into later tests.
        backend.default_to_constraint_checking_alora = True
87+
88+
def test_constraint_lora_override_does_not_override_alora(session, backend):
    """Check that an explicit ``ALoraRequirement`` still yields "Y" or "N".

    ``default_to_constraint_checking_alora`` is set to False on the backend,
    yet validating an ``ALoraRequirement`` is expected to produce a bare
    "Y"/"N" verdict.
    """
    backend.default_to_constraint_checking_alora = False  # type: ignore
    try:
        session.instruct(
            "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
        )
        validation_outputs = session.validate(
            ALoraRequirement(
                "The answer should mention that there is a b in the middle of one of the strings but not the other."
            ),
            return_full_validation_results=True,
        )
        assert len(validation_outputs) == 1
        alora_output, _ = validation_outputs[0]
        assert str(alora_output) in ["Y", "N"]
    finally:
        # The flag lives on the backend object passed in, which other tests
        # may reuse; restore it even when an assertion above fails so the
        # override cannot leak into later tests.
        backend.default_to_constraint_checking_alora = True
168104
169- results = self .m .backend ._generate_from_raw (
170- actions = [CBlock (value = prompt ) for prompt in prompts ], generate_logs = None
171- )
172105
173- assert len (results ) == len (prompts )
def test_llmaj_req_does_not_use_alora(session, backend):
    """Check that an ``LLMaJRequirement`` does not produce a bare "Y"/"N".

    The backend flag ``default_to_constraint_checking_alora`` is set to True
    before generating, and the validation output must still differ from
    "Y" and "N".
    """
    backend.default_to_constraint_checking_alora = True  # type: ignore
    session.instruct(
        "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa"
    )

    requirement = LLMaJRequirement(
        "The answer should mention that there is a b in the middle of one of the strings but not the other."
    )
    validation_outputs = session.validate(
        requirement,
        return_full_validation_results=True,
    )
    assert len(validation_outputs) == 1

    judge_output, _ = validation_outputs[0]
    assert str(judge_output) not in ("Y", "N")
174120
175- def test_generate_from_raw_with_format (self ):
176- prompts = ["what is 1+1?" , "what is 2+2?" , "what is 3+3?" , "what is 4+4?" ]
177121
178- class Answer ( pydantic . BaseModel ):
179- name : str
180- value : int
def test_instruct(session):
    """Run one instruct call and print whatever it returns."""
    print(session.instruct("Compute 1+1."))
181125
182- results = self .m .backend ._generate_from_raw (
183- actions = [CBlock (value = prompt ) for prompt in prompts ],
184- format = Answer ,
185- generate_logs = None ,
186- )
187126
188- assert len (results ) == len (prompts )
def test_multiturn(session):
    """Issue three dependent instructions on one session.

    The second reply is expected to contain the Greek letter beta; the third
    reply is only printed.
    """
    session.instruct("Compute 1+1")

    greek_letter_reply = session.instruct(
        "Take the result of the previous sum and find the corresponding letter in the greek alphabet."
    )
    lowered_reply = str(greek_letter_reply).lower()
    assert "β" in lowered_reply

    word_list_reply = session.instruct(
        "Now list five English words that start with that letter."
    )
    print(word_list_reply)
137+
138+
def test_format(session):
    """Request output constrained to a pydantic schema and validate it.

    The ``Email`` model is passed as ``format=`` to ``session.instruct``; the
    returned value is then parsed back into ``Email`` and the recipient's
    address is checked for an "@" and an "example.com" suffix.
    """

    class Person(pydantic.BaseModel):
        name: str
        # The address must match this pattern for the model to validate.
        email_address: Annotated[
            str,
            pydantic.StringConstraints(pattern=r"[a-zA-Z]{5,10}@example\.com"),
        ]

    class Email(pydantic.BaseModel):
        to: Person
        subject: str
        body: str

    output = session.instruct(
        "Write a short email to Olivia, thanking her for organizing a sailing activity. Her email server is example.com. No more than two sentences. ",
        format=Email,
        model_options={ModelOption.MAX_NEW_TOKENS: 2**8},
    )
    print("Formatted output:")
    # This should succeed because the output should be JSON because we passed
    # in a format= argument.
    email = Email.model_validate_json(output.value)
    print(email)

    print("address:", email.to.email_address)
    assert (
        "@" in email.to.email_address
    ), "The @ sign should be in the email address."
    assert email.to.email_address.endswith(
        "example.com"
    ), "The email address should be at example.com"
170+
171+
def test_generate_from_raw(session):
    """Send four raw prompts to the backend and expect one result per prompt."""
    prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]
    actions = [CBlock(value=text) for text in prompts]

    results = session.backend._generate_from_raw(actions=actions, generate_logs=None)

    assert len(results) == len(prompts)
180+
181+
def test_generate_from_raw_with_format(session):
    """Send four raw prompts with a pydantic format and validate the output.

    Expects one result per prompt, and expects the first result's value to
    parse as JSON matching the ``Answer`` model.
    """
    prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]

    class Answer(pydantic.BaseModel):
        name: str
        value: int

    results = session.backend._generate_from_raw(
        actions=[CBlock(value=prompt) for prompt in prompts],
        format=Answer,
        generate_logs=None,
    )

    assert len(results) == len(prompts)

    # Only the first result is checked against the schema.
    first_result = results[0]
    try:
        Answer.model_validate_json(first_result.value)
    except pydantic.ValidationError as e:
        # pytest.fail is used instead of `assert False`, which would be
        # stripped when Python runs with -O.
        pytest.fail(
            f"formatting directive failed for {first_result.value}: {e.json()}"
        )
197204
198205
if __name__ == "__main__":
    import pytest

    # Propagate pytest's exit code so running this file as a script reports
    # test failures through the process exit status.
    raise SystemExit(pytest.main([__file__]))
0 commit comments