55from mellea .backends .model_ids import META_LLAMA_3_2_1B
66from mellea .backends .ollama import OllamaModelBackend
77from mellea .stdlib .base import ChatContext , Context
8- from mellea .stdlib .genslot import AsyncGenerativeSlot , GenerativeSlot , PreconditionException , SyncGenerativeSlot
8+ from mellea .stdlib .genslot import (
9+ AsyncGenerativeSlot ,
10+ GenerativeSlot ,
11+ PreconditionException ,
12+ SyncGenerativeSlot ,
13+ )
914from mellea .stdlib .requirement import Requirement , simple_validate
1015from mellea .stdlib .sampling .base import RejectionSamplingStrategy
1116from mellea .stdlib .session import MelleaSession
1217
18+
1319@pytest .fixture (scope = "module" )
1420def backend (gh_run : int ):
1521 """Shared backend."""
1622 if gh_run == 1 :
1723 return OllamaModelBackend (
18- model_id = META_LLAMA_3_2_1B .ollama_name , # type: ignore
24+ model_id = META_LLAMA_3_2_1B .ollama_name # type: ignore
1925 )
2026 else :
21- return OllamaModelBackend (
22- model_id = "granite3.3:8b" ,
23- )
27+ return OllamaModelBackend (model_id = "granite3.3:8b" )
28+
2429
2530@generative
2631def classify_sentiment (text : str ) -> Literal ["positive" , "negative" ]: ...
@@ -81,26 +86,66 @@ async def test_async_gen_slot(session):
8186 r1 = async_write_short_sentence (session , topic = "cats" )
8287 r2 = async_write_short_sentence (session , topic = "dogs" )
8388
84- r3 , c3 = await async_write_short_sentence (context = session .ctx , backend = session .backend , topic = "fish" )
89+ r3 , c3 = await async_write_short_sentence (
90+ context = session .ctx , backend = session .backend , topic = "fish"
91+ )
8592 results = await asyncio .gather (r1 , r2 )
8693
8794 assert isinstance (r3 , str )
8895 assert isinstance (c3 , Context )
8996 assert len (results ) == 2
9097
98+
9199@pytest .mark .parametrize (
92100 "arg_choices,kwarg_choices,errs" ,
93101 [
94102 pytest .param (["m" ], ["func1" , "func2" , "func3" ], False , id = "session" ),
95103 pytest .param (["context" ], ["backend" ], False , id = "context and backend" ),
96- pytest .param (["backend" ], ["func1" , "func2" , "func3" ], True , id = "backend without context" ),
104+ pytest .param (
105+ ["backend" ], ["func1" , "func2" , "func3" ], True , id = "backend without context"
106+ ),
97107 pytest .param (["m" ], ["m" ], True , id = "duplicate arg and kwarg" ),
98- pytest .param (["m" , "precondition_requirements" , "requirements" , "strategy" , "model_options" , "func1" , "func2" , "func3" ], [], True , id = "original func args as positional args" ),
99- pytest .param ([], ["m" , "func1" , "func2" , "func3" ], False , id = "session and func as kwargs" ),
100- pytest .param ([], ["m" , "precondition_requirements" , "requirements" , "strategy" , "model_options" , "func1" , "func2" , "func3" ], False , id = "all kwargs" ),
101- pytest .param ([], ["func1" , "m" , "func2" , "requirements" , "func3" ], False , id = "interspersed kwargs" ),
102- pytest .param ([], [], True , id = "missing required args" )
103- ]
108+ pytest .param (
109+ [
110+ "m" ,
111+ "precondition_requirements" ,
112+ "requirements" ,
113+ "strategy" ,
114+ "model_options" ,
115+ "func1" ,
116+ "func2" ,
117+ "func3" ,
118+ ],
119+ [],
120+ True ,
121+ id = "original func args as positional args" ,
122+ ),
123+ pytest .param (
124+ [], ["m" , "func1" , "func2" , "func3" ], False , id = "session and func as kwargs"
125+ ),
126+ pytest .param (
127+ [],
128+ [
129+ "m" ,
130+ "precondition_requirements" ,
131+ "requirements" ,
132+ "strategy" ,
133+ "model_options" ,
134+ "func1" ,
135+ "func2" ,
136+ "func3" ,
137+ ],
138+ False ,
139+ id = "all kwargs" ,
140+ ),
141+ pytest .param (
142+ [],
143+ ["func1" , "m" , "func2" , "requirements" , "func3" ],
144+ False ,
145+ id = "interspersed kwargs" ,
146+ ),
147+ pytest .param ([], [], True , id = "missing required args" ),
148+ ],
104149)
105150def test_arg_extraction (backend , arg_choices , kwarg_choices , errs ):
106151 """Tests the internal extract_args_and_kwargs function.
@@ -156,35 +201,40 @@ def test_arg_extraction(backend, arg_choices, kwarg_choices, errs):
156201 except Exception as e :
157202 found_err = True
158203 err = e
159-
204+
160205 if errs :
161206 assert found_err , "expected an exception and got none"
162207 else :
163208 assert not found_err , f"got unexpected err: { err } "
164209
210+
165211def test_disallowed_parameter_names ():
166212 with pytest .raises (ValueError ):
213+
167214 @generative
168- def test (backend ):
169- ...
215+ def test (backend ): ...
216+
170217
171218def test_precondition_failure (session ):
172219 with pytest .raises (PreconditionException ):
173220 classify_sentiment (
174221 m = session ,
175222 text = "hello" ,
176223 precondition_requirements = [
177- Requirement ("forced failure" , validation_fn = simple_validate (lambda x : (False , "" )))
178- ]
224+ Requirement (
225+ "forced failure" ,
226+ validation_fn = simple_validate (lambda x : (False , "" )),
227+ )
228+ ],
179229 )
180230
231+
181232def test_requirement (session ):
182233 classify_sentiment (
183- m = session ,
184- text = "hello" ,
185- requirements = ["req1" , "req2" , Requirement ("req3" )]
234+ m = session , text = "hello" , requirements = ["req1" , "req2" , Requirement ("req3" )]
186235 )
187236
237+
188238def test_with_no_args (session ):
189239 @generative
190240 def generate_text () -> str :
@@ -193,5 +243,6 @@ def generate_text() -> str:
193243
194244 generate_text (m = session )
195245
246+
196247if __name__ == "__main__" :
197248 pytest .main ([__file__ ])
0 commit comments