@@ -23,7 +23,7 @@ def __init__(
2323 model_output : str ,
2424 validation_passed : bool ,
2525 score : int ,
26- validation_reason : str , # add input_id
26+ validation_reason : str , # add input_id
2727 ):
2828 self .input_text = input_text
2929 self .model_output = model_output
@@ -74,7 +74,9 @@ def pass_rate(self) -> float:
7474 return self .passed_count / self .total_count if self .total_count > 0 else 0.0
7575
7676
77- def create_session (backend : str , model : str | None , max_tokens : int | None ) -> mellea .MelleaSession :
77+ def create_session (
78+ backend : str , model : str | None , max_tokens : int | None
79+ ) -> mellea .MelleaSession :
7880 """Create a mellea session with the specified backend and model."""
7981
8082 model_id = None
@@ -96,35 +98,40 @@ def create_session(backend: str, model: str | None, max_tokens: int | None) -> m
9698 from mellea .backends .ollama import OllamaModelBackend
9799
98100 backend_instance = OllamaModelBackend (
99- model_id = model_id , model_options = {ModelOption .MAX_NEW_TOKENS : max_tokens }
101+ model_id = model_id ,
102+ model_options = {ModelOption .MAX_NEW_TOKENS : max_tokens },
100103 )
101104
102105 elif backend_lower == "openai" :
103106 from mellea .backends .openai import OpenAIBackend
104107
105108 backend_instance = OpenAIBackend (
106- model_id = model_id , model_options = {ModelOption .MAX_NEW_TOKENS : max_tokens }
109+ model_id = model_id ,
110+ model_options = {ModelOption .MAX_NEW_TOKENS : max_tokens },
107111 )
108112
109113 elif backend_lower in ["hf" , "huggingface" ]:
110114 from mellea .backends .huggingface import LocalHFBackend
111115
112116 backend_instance = LocalHFBackend (
113- model_id = model_id , model_options = {ModelOption .MAX_NEW_TOKENS : max_tokens },
117+ model_id = model_id ,
118+ model_options = {ModelOption .MAX_NEW_TOKENS : max_tokens },
114119 )
115120
116121 elif backend_lower == "watsonx" :
117122 from mellea .backends .watsonx import WatsonxAIBackend
118123
119124 backend_instance = WatsonxAIBackend (
120- model_id = model_id , model_options = {ModelOption .MAX_NEW_TOKENS : max_tokens }
125+ model_id = model_id ,
126+ model_options = {ModelOption .MAX_NEW_TOKENS : max_tokens },
121127 )
122128
123129 elif backend_lower == "litellm" :
124130 from mellea .backends .litellm import LiteLLMBackend
125131
126132 backend_instance = LiteLLMBackend (
127- model_id = model_id , model_options = {ModelOption .MAX_NEW_TOKENS : max_tokens }
133+ model_id = model_id ,
134+ model_options = {ModelOption .MAX_NEW_TOKENS : max_tokens },
128135 )
129136
130137 else :
@@ -135,9 +142,7 @@ def create_session(backend: str, model: str | None, max_tokens: int | None) -> m
135142 # create session with backend instance
136143 from mellea .stdlib .base import SimpleContext
137144
138- session = mellea .MelleaSession (
139- backend = backend_instance , ctx = SimpleContext ()
140- )
145+ session = mellea .MelleaSession (backend = backend_instance , ctx = SimpleContext ())
141146 return session
142147
143148 except Exception as e :
@@ -157,7 +162,6 @@ def run_evaluations(
157162 max_judge_tokens : int | None ,
158163 output_path : str ,
159164 output_format : str ,
160- verbose : bool ,
161165 continue_on_error : bool ,
162166):
163167 """Run all 'unit test' evaluations"""
@@ -176,14 +180,16 @@ def run_evaluations(
176180 return
177181
178182 console .print (f"Total test evals to run: { len (all_test_evals )} " )
179- total_inputs = sum (len (te .inputs ) for te in all_test_evals )
183+ total_inputs = sum (len (test_eval .inputs ) for test_eval in all_test_evals )
180184 console .print (f"Total inputs to run: { total_inputs } " )
181185
182186 console .print (f"Generation model: { model } " )
183187 console .print (f"Judge model: { judge_model } " )
184188
185189 m = create_session (backend = backend , model = model , max_tokens = max_gen_tokens )
186- judge_session = create_session (backend = judge_backend , model = judge_model , max_tokens = max_judge_tokens )
190+ judge_session = create_session (
191+ backend = judge_backend , model = judge_model , max_tokens = max_judge_tokens
192+ )
187193
188194 all_results = []
189195
@@ -234,12 +240,14 @@ def execute_test_eval(
234240 result : ModelOutputThunk = generation_session .act (input_text )
235241 model_output = str (result )
236242
237- judge_session . ctx = judge_session . ctx . add ( result )
238-
239- targets_for_input = ( test_eval . targets [ idx ] if idx < len ( test_eval . targets ) else [] )
243+ targets_for_input = (
244+ test_eval . targets [ idx ] if idx < len ( test_eval . targets ) else []
245+ )
240246
241247 # query the judge
242- judge_prompt = create_judge_requirement (test_eval , input_text , model_output , targets_for_input )
248+ judge_prompt = create_judge_requirement (
249+ test_eval , input_text , model_output , targets_for_input
250+ )
243251 judge_output_thunk = judge_session .act (judge_prompt )
244252 judge_output = str (judge_output_thunk )
245253 score , justification = parse_judge_output (judge_output )
@@ -263,7 +271,10 @@ def execute_test_eval(
263271
264272
265273def create_judge_requirement (
266- test_eval : TestBasedEval , input_text : str , model_output : str , targets_for_input : list [str ]
274+ test_eval : TestBasedEval ,
275+ input_text : str ,
276+ model_output : str ,
277+ targets_for_input : list [str ],
267278):
268279 """Create judge requirement description"""
269280
0 commit comments