
import mellea
from mellea.stdlib.base import ModelOutputThunk
-from mellea.stdlib.requirement import Requirement
from mellea.stdlib.test_based_eval import TestBasedEval
from mellea.backends.types import ModelOption

from rich.console import Console
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn
-from rich.table import Table

console = Console()

@@ -25,7 +23,7 @@ def __init__(
        model_output: str,
        validation_passed: bool,
        score: int,
-        validation_reason: str,
+        validation_reason: str,  # TODO: add input_id
    ):
        self.input_text = input_text
        self.model_output = model_output
@@ -52,15 +50,15 @@ def __init__(self, test_eval: TestBasedEval, input_results: list[InputEvalResult

    def to_dict(self):
        return {
-            "conversation_id": self.test_eval.conversation_id,
-            "category": self.test_eval.category,
+            "test_id": self.test_eval.test_id,
+            "source": self.test_eval.source,
+            "name": self.test_eval.name,
+            "instructions": self.test_eval.instructions,
            "input_results": [r.to_dict() for r in self.input_results],
            "expected_targets": self.test_eval.targets,
-            "unit_test_instructions": self.test_eval.unit_test_instructions,
            "passed": self.passed_count,
            "total_count": self.total_count,
            "pass_rate": self.pass_rate,
-            "metadata": self.test_eval.metadata,
        }

    @property
@@ -76,7 +74,7 @@ def pass_rate(self) -> float:
        return self.passed_count / self.total_count if self.total_count > 0 else 0.0


-def create_session(backend: str, model: str | None) -> mellea.MelleaSession:
+def create_session(backend: str, model: str | None, max_tokens: int | None) -> mellea.MelleaSession:
    """Create a mellea session with the specified backend and model."""

    model_id = None
@@ -98,35 +96,35 @@ def create_session(backend: str, model: str | None) -> mellea.MelleaSession:
            from mellea.backends.ollama import OllamaModelBackend

            backend_instance = OllamaModelBackend(
-                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: 256}
+                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: max_tokens}
            )

        elif backend_lower == "openai":
            from mellea.backends.openai import OpenAIBackend

            backend_instance = OpenAIBackend(
-                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: 256}
+                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: max_tokens}
            )

        elif backend_lower in ["hf", "huggingface"]:
            from mellea.backends.huggingface import LocalHFBackend

            backend_instance = LocalHFBackend(
-                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: 256}
+                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: max_tokens},
            )

        elif backend_lower == "watsonx":
            from mellea.backends.watsonx import WatsonxAIBackend

            backend_instance = WatsonxAIBackend(
-                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: 256}
+                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: max_tokens}
            )

        elif backend_lower == "litellm":
            from mellea.backends.litellm import LiteLLMBackend

            backend_instance = LiteLLMBackend(
-                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: 256}
+                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: max_tokens}
            )

        else:
@@ -139,7 +137,7 @@ def create_session(backend: str, model: str | None) -> mellea.MelleaSession:

        session = mellea.MelleaSession(
            backend=backend_instance, ctx=SimpleContext()
-        )  # need to reset to SimpleContext? print what is being judged by the judge (input)
+        )
        return session

    except Exception as e:
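For orientation, a minimal sketch of how the updated signature might be called from the CLI layer; the backend names and token budgets below are illustrative placeholders, not values taken from this change:

# Illustrative usage only: generator and judge sessions with independent token budgets.
gen_session = create_session(backend="ollama", model=None, max_tokens=512)
judge_session = create_session(backend="watsonx", model=None, max_tokens=256)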
@@ -153,8 +151,10 @@ def run_evaluations(
    test_files: List[str],
    backend: str,
    model: str | None,
+    max_gen_tokens: int | None,
    judge_backend: str | None,
    judge_model: str | None,
+    max_judge_tokens: int | None,
    output_path: str,
    output_format: str,
    verbose: bool,
@@ -176,16 +176,17 @@ def run_evaluations(
        return

    console.print(f"Total test evals to run: {len(all_test_evals)}")
+    total_inputs = sum(len(te.inputs) for te in all_test_evals)
+    console.print(f"Total inputs to run: {total_inputs}")

    console.print(f"Generation model: {model}")
    console.print(f"Judge model: {judge_model}")

-    m = create_session(backend=backend, model=model)
-    judge_session = create_session(backend=judge_backend, model=judge_model)
+    m = create_session(backend=backend, model=model, max_tokens=max_gen_tokens)
+    judge_session = create_session(backend=judge_backend, model=judge_model, max_tokens=max_judge_tokens)

    all_results = []

-    # some visuals on progress with rich, we can take out / modify
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
@@ -203,7 +204,7 @@ def run_evaluations(
                )
                all_results.append(result)
            except Exception as e:
-                console.print(f"Error {e} on test {test_eval.conversation_id}")
+                console.print(f"Error {e} on test {test_eval.test_id}")
                if not continue_on_error:
                    raise

@@ -229,23 +230,20 @@ def execute_test_eval(
    input_results = []

    # for all inputs, generate responses with generator
-    for input_text in test_eval.inputs:
+    for idx, input_text in enumerate(test_eval.inputs):
        result: ModelOutputThunk = generation_session.act(input_text)
        model_output = str(result)
-        console.print(model_output)

        judge_session.ctx = judge_session.ctx.add(result)

-        requirement = Requirement(
-            description=create_judge_requirement(test_eval, input_text, model_output)
-        )
-        validation_results = judge_session.validate(requirement)
-        validation_result = validation_results[0]
+        targets_for_input = (test_eval.targets[idx] if idx < len(test_eval.targets) else [])

-        judge_output = validation_result.reason or ""
+        # query the judge
+        judge_prompt = create_judge_requirement(test_eval, input_text, model_output, targets_for_input)
+        judge_output_thunk = judge_session.act(judge_prompt)
+        judge_output = str(judge_output_thunk)
        score, justification = parse_judge_output(judge_output)
-
-        passed = score == 1 if score is not None else validation_result.as_bool()
+        passed = score == 1 if score is not None else False

        input_result = InputEvalResult(
            input_text=input_text,
@@ -256,7 +254,7 @@ def execute_test_eval(
        )
        input_results.append(input_result)

-        # reset both generator and judge -- might not be necessary since SimpleContext doesn't retain history
+        # reset both generator and judge
        generation_session.reset()
        judge_session.reset()

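With this change the judge's raw text is parsed directly instead of going through a Requirement/validate round trip. parse_judge_output itself is not shown in this diff; a minimal sketch of what such a parser could look like, assuming the judge is asked to reply with a line like "Score: 1" followed by a justification, is:

import re

def parse_judge_output(judge_output: str) -> tuple[int | None, str]:
    """Hypothetical parser: extract a 0/1 score plus free-text justification.

    Assumes the judge prompt requests a line such as 'Score: 1'; returns
    (None, full_text) when no score is found so the caller can fail closed.
    """
    match = re.search(r"score\s*[:=]\s*([01])", judge_output, flags=re.IGNORECASE)
    if match is None:
        return None, judge_output.strip()
    return int(match.group(1)), judge_output.strip()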
@@ -265,24 +263,24 @@ def execute_test_eval(


def create_judge_requirement(
-    test_eval: TestBasedEval, input_text: str, model_output: str
+    test_eval: TestBasedEval, input_text: str, model_output: str, targets_for_input: list[str]
):
    """Create judge requirement description"""

-    if len(test_eval.targets) == 0:  # no reference
-        target_text = "N/A"  # another way to handle this?
-    elif len(test_eval.targets) == 1:
-        target_text = test_eval.targets[0]
-    else:  # enumerate the multiple targets
+    if len(targets_for_input) == 0:  # no reference
+        target_text = "N/A"
+    elif len(targets_for_input) == 1:
+        target_text = targets_for_input[0]
+    else:  # enumerate when there are multiple targets
        target_text = "\n".join(
-            [f"{i}. {target}" for i, target in enumerate(test_eval.targets, 1)]
+            [f"{i}. {target}" for i, target in enumerate(targets_for_input, 1)]
        )

    judge_prompt = test_eval.judge_prompt.format(
        input=input_text,
        prediction=model_output,
        target=target_text,
-        guidelines=test_eval.unit_test_instructions,
+        guidelines=test_eval.instructions,
    )

    return judge_prompt
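The .format() call above fixes the placeholder names a judge template must expose: input, prediction, target, and guidelines. As an illustration only (the real template ships with each TestBasedEval and is not part of this diff), a compatible judge_prompt could look like:

# Hypothetical template; any string exposing these four placeholders would work.
EXAMPLE_JUDGE_PROMPT = """You are grading a model response against a test.
Guidelines: {guidelines}
Input: {input}
Model response: {prediction}
Reference target(s): {target}
Reply with 'Score: 1' if the response satisfies the guidelines, otherwise 'Score: 0',
then give a one-sentence justification."""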
@@ -324,7 +322,7 @@ def save_results(results: List[TestEvalResult], output_path: str, output_format:
                f.write(json.dumps(result.to_dict()) + "\n")
    else:  # json
        summary = {
-            "total_unit_tests": len(results),
+            "total_tests": len(results),
            "total_inputs": total_inputs,
            "passed_inputs": passed_inputs,
            "failed_inputs": total_inputs - passed_inputs,
@@ -348,11 +346,11 @@ def summary_stats(results: List[TestEvalResult]):

    console.print(f"Total number of inputs across tests: {total_inputs}")
    console.print(f"Number of inputs passed across tests: {passed_inputs}")
-    console.print(f"Cumulative Pass Rate: {overall_pass_rate}")
+    console.print(f"Cumulative Pass Rate: {overall_pass_rate * 100:.1f}%")

    if len(results) > 1:
        console.print("Per-Test Breakdown:")
        for result in results:
            console.print(
-                f"{result.test_eval.conversation_id}:\n\t{result.passed_count}/{result.total_count} ({result.pass_rate * 100:.1f}%)\n\n"
+                f"{result.test_eval.name}:\n\t{result.passed_count}/{result.total_count} ({result.pass_rate * 100:.1f}%)\n\n"
            )