Commit 90b8f49

Author: Keshav Ramji [email protected]

Commit message: PR update: jinja, pydantic

1 parent 6c7ad29 · commit 90b8f49

File tree

5 files changed: +127 -78 lines changed


cli/eval/commands.py

Lines changed: 3 additions & 0 deletions
@@ -1,3 +1,6 @@
+"""Use the eval command for LLM-as-a-judge evaluation, given a (set of) test file(s) consisting of prompts, instructions, and optionally, targets.
+Instantiate a generator model to produce candidate responses, and a judge model to determine whether the instructions have been followed."""
+
 import typer
 
 eval_app = typer.Typer(name="eval")

cli/eval/runner.py

Lines changed: 18 additions & 35 deletions
@@ -15,7 +15,7 @@
 
 
 class InputEvalResult:
-    """Store results of a single input evaluation (within a unit test)"""
+    """Store results of a single input evaluation (within a unit test)."""
 
     def __init__(
         self,
@@ -42,7 +42,7 @@ def to_dict(self):
 
 
 class TestEvalResult:
-    """Store results of a single test evaluation"""
+    """Store results of a single test evaluation."""
 
     def __init__(self, test_eval: TestBasedEval, input_results: list[InputEvalResult]):
         self.test_eval = test_eval
@@ -77,7 +77,7 @@ def pass_rate(self) -> float:
 def create_session(
     backend: str, model: str | None, max_tokens: int | None
 ) -> mellea.MelleaSession:
-    """Create a mellea session with the specified backend and model."""
+    """Create a mellea session with the specified backend and model."""
 
     model_id = None
     if model:
@@ -164,7 +164,15 @@ def run_evaluations(
     output_format: str,
     continue_on_error: bool,
 ):
-    """Run all 'unit test' evaluations"""
+    """Run all 'unit test' evaluations.
+
+    Each test file should be a JSON file containing:
+        "id": an id that is unique to this test file
+        "source": the origin of the evaluation prompts, else "N/A"
+        "name": the instruction-following attribute that the user intends to evaluate through this test
+        "instructions": a set of requirements (in string form) which the generation should follow; the judge evaluates whether these are satisfied
+        "examples": a list of entries, each containing an input_id, an input (prompt), and a list of targets. An input may have multiple (or no) targets; inputs and targets are in messages format.
+    """
     all_test_evals: List[TestBasedEval] = []
 
     for test_file in test_files:
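
For reference, a minimal test file matching the schema described in this docstring might look like the sketch below. All field values are illustrative (not taken from the repository); the only assumptions are the JSON keys listed above and the TestBasedEval.from_json_file loader shown later in this diff.

    import json
    import tempfile

    from mellea.stdlib.test_based_eval import TestBasedEval

    # Illustrative test file: one instruction-following test with a single example.
    test_file_contents = {
        "id": "formatting-001",  # hypothetical, unique to this test file
        "source": "N/A",  # origin of the evaluation prompts
        "name": "bullet-point formatting",  # attribute being evaluated
        "instructions": "Answer using exactly three bullet points.",
        "examples": [
            {
                "input_id": "ex-1",
                "input": [{"role": "user", "content": "List three benefits of unit tests."}],
                "targets": [
                    {
                        "role": "assistant",
                        "content": "- catch regressions\n- document behavior\n- enable safe refactoring",
                    }
                ],
            }
        ],
    }

    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(test_file_contents, f)
        path = f.name

    # Parse and validate the file with the classmethod touched in this diff.
    test_evals = TestBasedEval.from_json_file(path)
    print(test_evals[0].test_id, test_evals[0].input_ids)  # formatting-001 ['ex-1']
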
@@ -230,7 +238,7 @@ def execute_test_eval(
 ) -> TestEvalResult:
     """Execute a single test evaluation
     For each input in the test, generate a response using generation_session
-    Then, after all inputs are processed, validate using judge_session
+    Then, after all inputs are processed, validate using judge_session.
     """
 
     input_results = []
@@ -245,10 +253,12 @@
         )
 
         # query the judge
-        judge_prompt = create_judge_requirement(
-            test_eval, input_text, model_output, targets_for_input
+        test_eval.set_judge_context(
+            input_text=input_text,
+            prediction=model_output,
+            targets_for_input=targets_for_input,
         )
-        judge_output_thunk = judge_session.act(judge_prompt)
+        judge_output_thunk = judge_session.act(test_eval)
         judge_output = str(judge_output_thunk)
         score, justification = parse_judge_output(judge_output)
         passed = score == 1 if score is not None else False
@@ -270,33 +280,6 @@ def execute_test_eval(
     return test_result
 
 
-def create_judge_requirement(
-    test_eval: TestBasedEval,
-    input_text: str,
-    model_output: str,
-    targets_for_input: list[str],
-):
-    """Create judge requirement description"""
-
-    if len(targets_for_input) == 0:  # no reference
-        target_text = "N/A"
-    elif len(targets_for_input) == 1:
-        target_text = targets_for_input[0]
-    else:  # enumerate when there are multiple targets
-        target_text = "\n".join(
-            [f"{i}. {target}" for i, target in enumerate(targets_for_input, 1)]
-        )
-
-    judge_prompt = test_eval.judge_prompt.format(
-        input=input_text,
-        prediction=model_output,
-        target=target_text,
-        guidelines=test_eval.instructions,
-    )
-
-    return judge_prompt
-
-
 def parse_judge_output(judge_output: str):
     try:
         json_match = re.search(r'\{[^}]*"score"[^}]*\}', judge_output, re.DOTALL)
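
For context on what parse_judge_output consumes: the judge is asked, via the new Jinja template at the end of this diff, to answer with a JSON object of the form {"score": ..., "justification": ...}. Below is a rough, self-contained re-implementation sketch of the parsing step using the same regex shown above; the real function may differ beyond the lines visible in this hunk.

    import json
    import re


    def parse_judge_output_sketch(judge_output: str) -> tuple[int | None, str]:
        """Hypothetical re-implementation for illustration: extract the first
        {...} object containing a "score" key from free-form judge text."""
        json_match = re.search(r'\{[^}]*"score"[^}]*\}', judge_output, re.DOTALL)
        if not json_match:
            return None, ""
        try:
            rating = json.loads(json_match.group(0))
        except json.JSONDecodeError:
            return None, ""
        return rating.get("score"), rating.get("justification", "")


    print(parse_judge_output_sketch('Rating: {"score": 1, "justification": "follows all guidelines"}'))
    # (1, 'follows all guidelines')
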

mellea/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -3,6 +3,5 @@
 import mellea.backends.model_ids as model_ids
 from mellea.stdlib.genslot import generative
 from mellea.stdlib.session import MelleaSession, start_session
-from mellea.stdlib.test_based_eval import TestBasedEval
 
-__all__ = ["MelleaSession", "TestBasedEval", "generative", "model_ids", "start_session"]
+__all__ = ["MelleaSession", "generative", "model_ids", "start_session"]

mellea/stdlib/test_based_eval.py

Lines changed: 78 additions & 41 deletions
@@ -4,7 +4,42 @@
 from pathlib import Path
 from typing import Any
 
-from mellea.stdlib.base import Component
+from pydantic import BaseModel, Field, field_validator
+
+from mellea.stdlib.base import CBlock, Component, TemplateRepresentation
+
+
+class Message(BaseModel):
+    """Schema for a message in the test data."""
+
+    role: str
+    content: str
+
+
+class Example(BaseModel):
+    """Schema for an example in the test data."""
+
+    input: list[Message]
+    targets: list[Message] = Field(default_factory=list)
+    input_id: str = ""
+
+
+class TestData(BaseModel):
+    """Schema for test data loaded from json."""
+
+    source: str
+    name: str
+    instructions: str
+    examples: list[Example] = Field(default_factory=list)
+    id: str
+
+    @field_validator("examples")
+    @classmethod
+    def validate_examples(cls, v):
+        """Ensure examples list is not empty."""
+        if not v:
+            raise ValueError("examples list cannot be empty")
+        return v
 
 
 class TestBasedEval(Component):
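
These pydantic models replace the ad-hoc dict.get(...) parsing further down and fail fast on malformed test files. A small sketch of how the validation behaves, assuming the classes are importable from this module; all values are made up.

    from pydantic import ValidationError

    from mellea.stdlib.test_based_eval import Example, Message, TestData

    # A well-formed record: targets may be omitted and default to an empty list.
    record = TestData(
        id="t-1",
        source="N/A",
        name="politeness",
        instructions="Respond politely.",
        examples=[Example(input=[Message(role="user", content="Hi there")], input_id="i-1")],
    )
    print(record.examples[0].targets)  # []

    # The field validator rejects a test with no examples.
    try:
        TestData(id="t-2", source="N/A", name="politeness", instructions="Respond politely.", examples=[])
    except ValidationError as err:
        print(err)  # mentions "examples list cannot be empty"
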
@@ -29,32 +64,37 @@ def __init__(
         self.test_id = test_id
         self.input_ids = input_ids or []
 
-        self.judge_prompt = """**Input to the model**
-
-{input}
-
-**Model output to be rated**
-
-{prediction}
-
-**Ground truth text**
+    def parts(self) -> list[Component | CBlock]:
+        """The set of constituent parts of the Component."""
+        return []
 
-{target}
+    def format_for_llm(self) -> TemplateRepresentation:
+        """Formats the test for judge evaluation."""
+        return TemplateRepresentation(
+            obj=self,
+            args=self._judge_context if hasattr(self, "_judge_context") else {},
+            template_order=["*"],
+        )
 
-**Rating Guidelines**
-The model output should adhere to the following guidelines:
-{guidelines}
-
-**Scoring Criteria**
-* Score 0: The model output violates any of the guidelines.
-* Score 1: The model output is well aligned with the ground truth - if it exists, the input to the model, and adheres to all guidelines.
-
-**Return Your Rating**
-Return your rating in the following format:
-{{\"score\": your_score, \"justification\": \"your_justification\"}}
+    def set_judge_context(
+        self, input_text: str, prediction: str, targets_for_input: list[str]
+    ):
+        """Set context for judge evaluation."""
+        if len(targets_for_input) == 0:  # no reference
+            target_text = "N/A"
+        elif len(targets_for_input) == 1:
+            target_text = targets_for_input[0]
+        else:  # enumerate when there are multiple targets
+            target_text = "\n".join(
+                [f"{i}. {target}" for i, target in enumerate(targets_for_input, 1)]
+            )
 
-Your rating:
-"""
+        self._judge_context: dict[str, Any] = {
+            "input": input_text,
+            "prediction": prediction,
+            "target": target_text,
+            "guidelines": self.instructions,
+        }
 
     @classmethod
     def from_json_file(cls, filepath: str) -> list["TestBasedEval"]:
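
As a quick illustration of what set_judge_context stages for the template, the sketch below constructs a TestBasedEval with the same keyword arguments used by from_json_file in the next hunk and inspects _judge_context; every value is invented for the example.

    from mellea.stdlib.test_based_eval import TestBasedEval

    # Hypothetical test with two reference targets for a single input.
    test = TestBasedEval(
        source="N/A",
        name="conciseness",
        instructions="Keep the answer under two sentences.",
        inputs=["Summarize the plot of Hamlet."],
        targets=[["A prince avenges his father.", "Hamlet seeks revenge on Claudius."]],
        test_id="t-42",
        input_ids=["i-0"],
    )

    test.set_judge_context(
        input_text="Summarize the plot of Hamlet.",
        prediction="Hamlet is a long play set in Denmark.",
        targets_for_input=["A prince avenges his father.", "Hamlet seeks revenge on Claudius."],
    )

    # Multiple targets are enumerated as "1. ...\n2. ..." before being handed to the judge.
    print(test._judge_context["target"])
    print(test._judge_context["guidelines"])  # the test's instructions
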
@@ -68,38 +108,35 @@ def from_json_file(cls, filepath: str) -> list["TestBasedEval"]:
             data = [data]
 
         test_evals = []
-        for test_data in data:
-            examples = test_data.get("examples", [])
+        for test_data_dict in data:
+            try:
+                test_data = TestData(**test_data_dict)
+            except Exception as e:
+                raise ValueError(f"Invalid test data in {filepath}: {e}")
 
             inputs = []
             targets = []
             input_ids = []
 
-            for example in examples:
-                input_messages = example.get("input", [])
-                user_messages = [
-                    msg for msg in input_messages if msg.get("role") == "user"
-                ]
+            for example in test_data.examples:
+                user_messages = [msg for msg in example.input if msg.role == "user"]
                 if user_messages:
-                    inputs.append(user_messages[-1].get("content", ""))
+                    inputs.append(user_messages[-1].content)
 
-                target_messages = example.get("targets", [])
                 targets_for_input = [
-                    msg.get("content", "")
-                    for msg in target_messages
-                    if msg.get("role") == "assistant"
+                    msg.content for msg in example.targets if msg.role == "assistant"
                 ]
                 targets.append(targets_for_input)
 
-                input_ids.append(example.get("input_id", ""))
+                input_ids.append(example.input_id)
 
             test_eval = cls(
-                source=test_data.get("source", "unknown"),
-                name=test_data.get("name", ""),
-                instructions=test_data.get("instructions", ""),
+                source=test_data.source,
+                name=test_data.name,
+                instructions=test_data.instructions,
                 inputs=inputs,
                 targets=targets,
-                test_id=test_data.get("id", ""),
+                test_id=test_data.id,
                 input_ids=input_ids,
             )
             test_evals.append(test_eval)

New file (Jinja judge-prompt template)

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+**Input to the model**
+
+{{ input }}
+
+**Model output to be rated**
+
+{{ prediction }}
+
+{% if target and target != "N/A" %}
+**Ground truth text**
+
+{{ target }}
+{% endif %}
+
+**Rating Guidelines**
+The model output should adhere to the following guidelines:
+{{ guidelines }}
+
+**Scoring Criteria**
+* Score 0: The model output violates any of the guidelines.
+* Score 1: The model output is well aligned with the ground truth{% if target and target != "N/A" %} - if it exists{% endif %}, the input to the model, and adheres to all guidelines.
+
+**Return Your Rating**
+Return your rating in the following format:
+{"score": your_score, "justification": "your_justification"}
+
+Your rating:
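
The {{ ... }} placeholders correspond to the keys staged in _judge_context (input, prediction, target, guidelines). As a rough preview of the rendered prompt, the sketch below pushes sample values through jinja2 directly; this sidesteps mellea's TemplateRepresentation plumbing and only approximates the text the judge would see. The template string is abbreviated and the values are invented.

    from jinja2 import Template

    # Abbreviated copy of the judge-prompt template above.
    judge_template = Template(
        "**Input to the model**\n\n{{ input }}\n\n"
        "**Model output to be rated**\n\n{{ prediction }}\n\n"
        "{% if target and target != \"N/A\" %}**Ground truth text**\n\n{{ target }}\n\n{% endif %}"
        "**Rating Guidelines**\n"
        "The model output should adhere to the following guidelines:\n{{ guidelines }}\n"
    )

    print(
        judge_template.render(
            input="Summarize the plot of Hamlet.",
            prediction="Hamlet is a long play set in Denmark.",
            target="N/A",  # with no reference, the ground-truth block is skipped
            guidelines="Keep the answer under two sentences.",
        )
    )
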
