Commit c9bc4b5

Keshav Ramji <Keshav.Ramji@ibm.com> authored and committed
Update v1 data format and judge call
1 parent c5f3d9f commit c9bc4b5

File tree

4 files changed: +93 −129 lines changed


.gitignore

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,8 @@
 kr_results/
 kr_data/
 xet/
+job.sh
+hub/
 
 # Python-generated files
 __pycache__/

cli/eval/commands.py

Lines changed: 4 additions & 0 deletions
@@ -9,10 +9,12 @@ def eval_run(
     ),
     backend: str = typer.Option("ollama", "--backend", "-b", help="Generation backend"),
     model: str = typer.Option(None, "--model", help="Generation model name"),
+    max_gen_tokens: int = typer.Option(256, "--max-gen-tokens", help="Max tokens to generate for responses"),
     judge_backend: str = typer.Option(
         None, "--judge-backend", "-jb", help="Judge backend"
     ),
     judge_model: str = typer.Option(None, "--judge-model", help="Judge model name"),
+    max_judge_tokens: int = typer.Option(256, "--max-judge-tokens", help="Max tokens for the judge model's judgement."),
     output_path: str = typer.Option(
         "eval_results", "--output-path", "-o", help="Output path for results"
     ),
@@ -28,8 +30,10 @@ def eval_run(
         test_files=test_files,
         backend=backend,
         model=model,
+        max_gen_tokens=max_gen_tokens,
         judge_backend=judge_backend,
         judge_model=judge_model,
+        max_judge_tokens=max_judge_tokens,
         output_path=output_path,
         output_format=output_format,
         verbose=verbose,
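For reference, the two new flags follow the standard Typer option pattern: a default of 256 tokens that can be overridden per run and is forwarded unchanged to the runner. A minimal, self-contained sketch of that pattern follows; the standalone app and the demo_run name are illustrative only, not the project's actual entry point.

import typer

app = typer.Typer()

@app.command()
def demo_run(
    max_gen_tokens: int = typer.Option(256, "--max-gen-tokens", help="Max tokens to generate for responses"),
    max_judge_tokens: int = typer.Option(256, "--max-judge-tokens", help="Max tokens for the judge model's judgement"),
) -> None:
    # In the real command these values are passed straight through to run_evaluations(...).
    typer.echo(f"max_gen_tokens={max_gen_tokens}, max_judge_tokens={max_judge_tokens}")

if __name__ == "__main__":
    app()  # e.g. `python demo.py --max-gen-tokens 512 --max-judge-tokens 256`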

cli/eval/runner.py

Lines changed: 38 additions & 40 deletions
@@ -5,13 +5,11 @@
 
 import mellea
 from mellea.stdlib.base import ModelOutputThunk
-from mellea.stdlib.requirement import Requirement
 from mellea.stdlib.test_based_eval import TestBasedEval
 from mellea.backends.types import ModelOption
 
 from rich.console import Console
 from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn
-from rich.table import Table
 
 console = Console()
 
@@ -25,7 +23,7 @@ def __init__(
         model_output: str,
         validation_passed: bool,
         score: int,
-        validation_reason: str,
+        validation_reason: str,  # add input_id
     ):
         self.input_text = input_text
         self.model_output = model_output
@@ -52,15 +50,15 @@ def __init__(self, test_eval: TestBasedEval, input_results: list[InputEvalResult
 
     def to_dict(self):
         return {
-            "conversation_id": self.test_eval.conversation_id,
-            "category": self.test_eval.category,
+            "test_id": self.test_eval.test_id,
+            "source": self.test_eval.source,
+            "name": self.test_eval.name,
+            "instructions": self.test_eval.instructions,
             "input_results": [r.to_dict() for r in self.input_results],
             "expected_targets": self.test_eval.targets,
-            "unit_test_instructions": self.test_eval.unit_test_instructions,
             "passed": self.passed_count,
             "total_count": self.total_count,
             "pass_rate": self.pass_rate,
-            "metadata": self.test_eval.metadata,
         }
 
     @property
@@ -76,7 +74,7 @@ def pass_rate(self) -> float:
         return self.passed_count / self.total_count if self.total_count > 0 else 0.0
 
 
-def create_session(backend: str, model: str | None) -> mellea.MelleaSession:
+def create_session(backend: str, model: str | None, max_tokens: int | None) -> mellea.MelleaSession:
     """Create a mellea session with the specified backend and model."""
 
     model_id = None
@@ -98,35 +96,35 @@ def create_session(backend: str, model: str | None) -> mellea.MelleaSession:
            from mellea.backends.ollama import OllamaModelBackend
 
            backend_instance = OllamaModelBackend(
-                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: 256}
+                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: max_tokens}
            )
 
        elif backend_lower == "openai":
            from mellea.backends.openai import OpenAIBackend
 
            backend_instance = OpenAIBackend(
-                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: 256}
+                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: max_tokens}
            )
 
        elif backend_lower in ["hf", "huggingface"]:
            from mellea.backends.huggingface import LocalHFBackend
 
            backend_instance = LocalHFBackend(
-                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: 256}
+                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: max_tokens},
            )
 
        elif backend_lower == "watsonx":
            from mellea.backends.watsonx import WatsonxAIBackend
 
            backend_instance = WatsonxAIBackend(
-                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: 256}
+                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: max_tokens}
            )
 
        elif backend_lower == "litellm":
            from mellea.backends.litellm import LiteLLMBackend
 
            backend_instance = LiteLLMBackend(
-                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: 256}
+                model_id=model_id, model_options={ModelOption.MAX_NEW_TOKENS: max_tokens}
            )
 
        else:
@@ -139,7 +137,7 @@ def create_session(backend: str, model: str | None) -> mellea.MelleaSession:
 
        session = mellea.MelleaSession(
            backend=backend_instance, ctx=SimpleContext()
-        )  # need to reset to SimpleContext? print what is being judged by the judge (input)
+        )
        return session
 
    except Exception as e:
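With the hard-coded `ModelOption.MAX_NEW_TOKENS: 256` replaced by the new max_tokens parameter, the generator and the judge can be given independent token budgets. A rough usage sketch, assuming the import path cli.eval.runner and placeholder model names, with the chosen backend available locally:

from cli.eval.runner import create_session  # import path assumed from the repo layout

# Placeholder model names; any models served by the chosen backend would work.
gen_session = create_session(backend="ollama", model="generator-model", max_tokens=512)
judge_session = create_session(backend="ollama", model="judge-model", max_tokens=256)
# Both sessions are created with SimpleContext, so each act() call starts from a clean context.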
@@ -153,8 +151,10 @@ def run_evaluations(
     test_files: List[str],
     backend: str,
     model: str | None,
+    max_gen_tokens: int | None,
     judge_backend: str | None,
     judge_model: str | None,
+    max_judge_tokens: int | None,
     output_path: str,
     output_format: str,
     verbose: bool,
@@ -176,16 +176,17 @@ def run_evaluations(
         return
 
     console.print(f"Total test evals to run: {len(all_test_evals)}")
+    total_inputs = sum(len(te.inputs) for te in all_test_evals)
+    console.print(f"Total inputs to run: {total_inputs}")
 
     console.print(f"Generation model: {model}")
     console.print(f"Judge model: {judge_model}")
 
-    m = create_session(backend=backend, model=model)
-    judge_session = create_session(backend=judge_backend, model=judge_model)
+    m = create_session(backend=backend, model=model, max_tokens=max_gen_tokens)
+    judge_session = create_session(backend=judge_backend, model=judge_model, max_tokens=max_judge_tokens)
 
     all_results = []
 
-    # some visuals on progress with rich, we can take out / modify
     with Progress(
         SpinnerColumn(),
         TextColumn("[progress.description]{task.description}"),
@@ -203,7 +204,7 @@ def run_evaluations(
                 )
                 all_results.append(result)
             except Exception as e:
-                console.print(f"Error {e} on test {test_eval.conversation_id}")
+                console.print(f"Error {e} on test {test_eval.test_id}")
                 if not continue_on_error:
                     raise
 
@@ -229,23 +230,20 @@ def execute_test_eval(
     input_results = []
 
     # for all inputs, generate responses with generator
-    for input_text in test_eval.inputs:
+    for idx, input_text in enumerate(test_eval.inputs):
         result: ModelOutputThunk = generation_session.act(input_text)
         model_output = str(result)
-        console.print(model_output)
 
         judge_session.ctx = judge_session.ctx.add(result)
 
-        requirement = Requirement(
-            description=create_judge_requirement(test_eval, input_text, model_output)
-        )
-        validation_results = judge_session.validate(requirement)
-        validation_result = validation_results[0]
+        targets_for_input = (test_eval.targets[idx] if idx < len(test_eval.targets) else [])
 
-        judge_output = validation_result.reason or ""
+        # query the judge
+        judge_prompt = create_judge_requirement(test_eval, input_text, model_output, targets_for_input)
+        judge_output_thunk = judge_session.act(judge_prompt)
+        judge_output = str(judge_output_thunk)
         score, justification = parse_judge_output(judge_output)
-
-        passed = score == 1 if score is not None else validation_result.as_bool()
+        passed = score == 1 if score is not None else False
 
         input_result = InputEvalResult(
             input_text=input_text,
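In the new flow the judge is queried with a plain act() call and its raw text is handed to parse_judge_output, which is not part of this diff. A minimal sketch of what such a parser might look like, assuming the judge is prompted to reply with a 0/1 score followed by a justification (the real helper and its expected format live elsewhere in the repo):

import re

def parse_judge_output_sketch(judge_output: str) -> tuple[int | None, str]:
    """Hypothetical stand-in for parse_judge_output: extract a 0/1 score and a
    free-text justification from the judge's raw response."""
    match = re.search(r"\bscore\s*[:=]?\s*([01])\b", judge_output, re.IGNORECASE)
    score = int(match.group(1)) if match else None
    justification = judge_output.strip()
    return score, justification

# Mirrors the call site above: passed = score == 1 if score is not None else False
score, justification = parse_judge_output_sketch("Score: 1\nThe response follows the guidelines.")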
@@ -256,7 +254,7 @@
         )
         input_results.append(input_result)
 
-        # reset both generator and judge -- might not be necessary since SimpleContext doesn't retain history
+        # reset both generator and judge
         generation_session.reset()
         judge_session.reset()
 
@@ -265,24 +263,24 @@
 
 
 def create_judge_requirement(
-    test_eval: TestBasedEval, input_text: str, model_output: str
+    test_eval: TestBasedEval, input_text: str, model_output: str, targets_for_input: list[str]
 ):
     """Create judge requirement description"""
 
-    if len(test_eval.targets) == 0:  # no reference
-        target_text = "N/A"  # another way to handle this?
-    elif len(test_eval.targets) == 1:
-        target_text = test_eval.targets[0]
-    else:  # enumerate the multiple targets
+    if len(targets_for_input) == 0:  # no reference
+        target_text = "N/A"
+    elif len(targets_for_input) == 1:
+        target_text = targets_for_input[0]
+    else:  # enumerate when there are multiple targets
         target_text = "\n".join(
-            [f"{i}. {target}" for i, target in enumerate(test_eval.targets, 1)]
+            [f"{i}. {target}" for i, target in enumerate(targets_for_input, 1)]
         )
 
     judge_prompt = test_eval.judge_prompt.format(
         input=input_text,
         prediction=model_output,
         target=target_text,
-        guidelines=test_eval.unit_test_instructions,
+        guidelines=test_eval.instructions,
     )
 
     return judge_prompt
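Taken together with the updated to_dict() above, a v1 test record carries test_id, source, name, instructions, a list of inputs, one list of reference targets per input, and a judge_prompt template with {input}, {prediction}, {target}, and {guidelines} slots. A hedged sketch of such a record and of how the prompt is filled for one input; the field values and the template wording are invented for illustration, and the actual schema is defined by TestBasedEval:

# Illustrative v1-style record: the keys mirror to_dict() and the format() call
# above, but the values and the judge_prompt text are made up.
test_record = {
    "test_id": "t-0001",
    "source": "example-suite",
    "name": "capital-city-check",
    "instructions": "Answer factually in one sentence.",
    "inputs": ["What is the capital of France?"],
    "targets": [["Paris"]],  # one list of reference targets per input
    "judge_prompt": (
        "Guidelines: {guidelines}\n"
        "Input: {input}\n"
        "Prediction: {prediction}\n"
        "Target(s): {target}\n"
        "Reply with 'Score: 0' or 'Score: 1' and a short justification."
    ),
}

idx = 0
targets_for_input = test_record["targets"][idx] if idx < len(test_record["targets"]) else []
if len(targets_for_input) == 0:
    target_text = "N/A"
elif len(targets_for_input) == 1:
    target_text = targets_for_input[0]
else:
    target_text = "\n".join(f"{i}. {t}" for i, t in enumerate(targets_for_input, 1))

judge_prompt = test_record["judge_prompt"].format(
    input=test_record["inputs"][idx],
    prediction="The capital of France is Paris.",
    target=target_text,
    guidelines=test_record["instructions"],
)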
@@ -324,7 +322,7 @@ def save_results(results: List[TestEvalResult], output_path: str, output_format:
                 f.write(json.dumps(result.to_dict()) + "\n")
     else:  # json
         summary = {
-            "total_unit_tests": len(results),
+            "total_tests": len(results),
             "total_inputs": total_inputs,
             "passed_inputs": passed_inputs,
             "failed_inputs": total_inputs - passed_inputs,
@@ -348,11 +346,11 @@ def summary_stats(results: List[TestEvalResult]):
 
     console.print(f"Total number of inputs across tests: {total_inputs}")
     console.print(f"Number of inputs passed across tests: {passed_inputs}")
-    console.print(f"Cumulative Pass Rate: {overall_pass_rate}")
+    console.print(f"Cumulative Pass Rate: {overall_pass_rate * 100:.1f}%")
 
     if len(results) > 1:
         console.print("Per-Test Breakdown:")
         for result in results:
             console.print(
-                f"{result.test_eval.conversation_id}:\n\t{result.passed_count}/{result.total_count} ({result.pass_rate * 100:.1f}%)\n\n"
+                f"{result.test_eval.name}:\n\t{result.passed_count}/{result.total_count} ({result.pass_rate * 100:.1f}%)\n\n"
             )
