Skip to content

Commit fc0a946

Browse files
authored
feat: add automatic multimodal scoring to llm_judge scorer (#302)
Enables llm_judge to automatically detect and score Message outputs containing images and audio alongside text. When a Message with images/audio is provided, they are automatically included in the evaluation using vision-capable models. Key changes: - Automatic multimodal detection via Message.image_parts/audio_parts - Zero API changes - backward compatible with text-only scoring - Single combined score for text + images + audio - Extract helper functions to improve code quality - Add observability attributes (has_multimodal, num_images, num_audio) - Example notebook demonstrating text-only, image-only, and multimodal scoring
1 parent fb8ba04 commit fc0a946

File tree

2 files changed

+378
-28
lines changed

2 files changed

+378
-28
lines changed

dreadnode/scorers/judge.py

Lines changed: 129 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
import typing as t
23

34
import rigging as rg
@@ -8,6 +9,99 @@
89
from dreadnode.metric import Metric
910
from dreadnode.scorers import Scorer
1011

12+
if t.TYPE_CHECKING:
13+
from dreadnode.data_types.message import Message
14+
15+
16+
def _build_multimodal_content(
    data: "Message", output_text: str, rubric: str
) -> list[rg.ContentText | rg.ContentImageUrl | rg.ContentAudioInput]:
    """
    Build rigging content parts from a Message with images/audio.

    Args:
        data: The multimodal Message whose image/audio parts will be attached.
        output_text: The extracted text output presented to the judge.
        rubric: The rubric the judge will grade against.

    Returns:
        A list of content parts: one leading text part combining the output and
        rubric, followed by one image part per image and one audio part per clip.
    """
    rg_content: list[rg.ContentText | rg.ContentImageUrl | rg.ContentAudioInput] = [
        rg.ContentText(text=f"Output: {output_text}\n\nRubric: {rubric}")
    ]

    # Attach images as base64 data URLs so vision-capable models can see them.
    for img in data.image_parts:
        base64_str = img.to_base64()
        _, meta = img.to_serializable()
        img_format = meta.get("format", "png")  # fall back to png if metadata is missing
        data_url = f"data:image/{img_format};base64,{base64_str}"
        rg_content.append(rg.ContentImageUrl.from_url(data_url))

    # Attach audio clips directly from their raw bytes. (Previously the bytes
    # were base64-encoded and then immediately decoded again before being
    # handed to from_bytes - a wasteful no-op round trip.)
    for audio in data.audio_parts:
        audio_bytes, audio_meta = audio.to_serializable()
        audio_format = audio_meta.get("extension", "wav")  # fall back to wav if missing
        rg_content.append(
            rg.ContentAudioInput.from_bytes(
                audio_bytes,
                format=audio_format,
            )
        )

    return rg_content
45+
46+
47+
def _create_judge_pipeline(
    generator: rg.Generator,
    data: "Message",
    output_text: str,
    rubric: str,
    system_prompt: str | None,
    *,
    has_multimodal: bool,
) -> rg.ChatPipeline:
    """
    Construct the judge chat pipeline, attaching multimodal content when present.

    Args:
        generator: The rigging generator used to create the chat pipeline.
        data: The Message whose image/audio parts are attached when multimodal.
        output_text: The extracted text output presented to the judge.
        rubric: The rubric the judge will grade against.
        system_prompt: Optional system content injected into the pipeline.
        has_multimodal: Whether `data` carries image/audio parts to include.

    Returns:
        A chat pipeline ready to be bound to the judge prompt.
    """
    # Start from an empty conversation; add a single user message carrying the
    # text + image + audio parts only when multimodal content is present.
    messages: list[rg.Message] = []
    if has_multimodal:
        parts = _build_multimodal_content(data, output_text, rubric)
        messages.append(rg.Message(role="user", content=parts))

    pipeline = generator.chat(messages)

    if system_prompt:
        pipeline.chat.inject_system_content(system_prompt)

    return pipeline
68+
69+
70+
def _create_judge_metrics(
    judgement: "Judgement",
    *,
    has_multimodal: bool,
    num_images: int,
    num_audio: int,
    used_fallback: bool,
    fallback_model: str | rg.Generator | None,
    primary_error: str | None,
    name: str,
) -> list[Metric]:
    """
    Create score and pass metrics from a judgement.

    Args:
        judgement: The judge's result (score, passing flag, reason).
        has_multimodal: Whether multimodal content was included in the evaluation.
        num_images: Number of image parts scored.
        num_audio: Number of audio parts scored.
        used_fallback: Whether the fallback model produced this judgement.
        fallback_model: The configured fallback model, recorded when used.
        primary_error: The primary model's error message, if it failed.
        name: Base scorer name; the pass metric is named `{name}_pass`.

    Returns:
        A two-element list: the score metric and the pass/fail metric.
    """
    # Only record the fallback model identity when it was actually used.
    fallback_label: str | None = None
    if used_fallback:
        if isinstance(fallback_model, str):
            fallback_label = str(fallback_model)
        else:
            fallback_label = type(fallback_model).__name__

    attributes = {
        "reason": judgement.reason,
        "has_multimodal": has_multimodal,
        "num_images": num_images,
        "num_audio": num_audio,
        "used_fallback": used_fallback,
        "fallback_model": fallback_label,
        "primary_error": primary_error,
    }
    score_metric = Metric(value=judgement.score, attributes=attributes)

    pass_metric = Metric(value=float(judgement.passing))
    # Distinguish the pass/fail metric from the main score metric by name.
    pass_metric._scorer_name = f"{name}_pass"  # type: ignore[attr-defined] # noqa: SLF001

    return [score_metric, pass_metric]
104+
11105

12106
class JudgeInput(rg.Model):
13107
input: str | None = rg.element(default=None)
@@ -31,6 +125,7 @@ def judge(input: JudgeInput) -> Judgement: # type: ignore [empty-body]
31125
"""
32126
You are grading output according to a user-specified rubric.
33127
128+
Evaluate all provided content (text, images, audio, etc.) when making your judgment.
34129
If the statement in the rubric is true for the provided input and output, then the output passes the test.
35130
Assign a score based on the rubric, where applicable, otherwise 1.0 for passing and 0.0 for failing.
36131
"""
@@ -53,8 +148,12 @@ def llm_judge(
53148
"""
54149
Score the output of a task using an LLM to judge it against a rubric.
55150
151+
Automatically handles multimodal outputs (text + images + audio). When the output is a Message
152+
containing images or audio, they will be included in the evaluation. Use vision-capable models
153+
(e.g., "gpt-4o") when scoring multimodal content.
154+
56155
Args:
57-
model: The model to use for judging.
156+
model: The model to use for judging. Use vision-capable models for multimodal outputs.
58157
rubric: The rubric to use for judging.
59158
input: The input which produced the output for context, if applicable.
60159
expected_output: The expected output to compare against, if applicable.
@@ -102,22 +201,33 @@ def _create_generator(
102201

103202
generator = _create_generator(model, model_params)
104203

204+
# Check if data is a multimodal Message
205+
from dreadnode.data_types.message import Message
206+
207+
is_message = isinstance(data, Message)
208+
has_multimodal = is_message and bool(data.image_parts or data.audio_parts)
209+
210+
# Extract text output
211+
output_text = data.text if is_message else str(data)
212+
105213
input_data = JudgeInput(
106214
input=str(input) if input is not None else None,
107215
expected_output=str(expected_output) if expected_output is not None else None,
108-
output=str(data),
216+
output=output_text,
109217
rubric=rubric,
110218
)
111219

112-
# Track fallback usage for observability
220+
# Track fallback usage and multimodal content for observability
113221
used_fallback = False
114222
primary_error: str | None = None
223+
num_images = len(data.image_parts) if has_multimodal else 0
224+
num_audio = len(data.audio_parts) if has_multimodal else 0
115225

116226
# Try primary model, fallback if needed
117227
try:
118-
pipeline = generator.chat([])
119-
if system_prompt:
120-
pipeline.chat.inject_system_content(system_prompt)
228+
pipeline = _create_judge_pipeline(
229+
generator, data, output_text, rubric, system_prompt, has_multimodal=has_multimodal
230+
)
121231
judgement = await judge.bind(pipeline)(input_data)
122232
except Exception as e:
123233
if fallback_model is None:
@@ -133,11 +243,11 @@ def _create_generator(
133243
f"Primary model '{primary_model_name}' failed with {primary_error}. "
134244
f"Using fallback model '{fallback_model_name}'."
135245
)
136-
# Use fallback model
246+
# Use fallback model with same multimodal content
137247
generator = _create_generator(fallback_model, model_params)
138-
pipeline = generator.chat([])
139-
if system_prompt:
140-
pipeline.chat.inject_system_content(system_prompt)
248+
pipeline = _create_judge_pipeline(
249+
generator, data, output_text, rubric, system_prompt, has_multimodal=has_multimodal
250+
)
141251
judgement = await judge.bind(pipeline)(input_data)
142252

143253
if min_score is not None:
@@ -148,24 +258,15 @@ def _create_generator(
148258
if passing is not None:
149259
judgement.passing = passing(judgement.score)
150260

151-
score_metric = Metric(
152-
value=judgement.score,
153-
attributes={
154-
"reason": judgement.reason,
155-
"used_fallback": used_fallback,
156-
"fallback_model": (
157-
str(fallback_model)
158-
if isinstance(fallback_model, str)
159-
else type(fallback_model).__name__
160-
)
161-
if used_fallback
162-
else None,
163-
"primary_error": primary_error,
164-
},
261+
return _create_judge_metrics(
262+
judgement,
263+
has_multimodal=has_multimodal,
264+
num_images=num_images,
265+
num_audio=num_audio,
266+
used_fallback=used_fallback,
267+
fallback_model=fallback_model,
268+
primary_error=primary_error,
269+
name=name,
165270
)
166-
pass_metric = Metric(value=float(judgement.passing))
167-
pass_metric._scorer_name = f"{name}_pass" # type: ignore[attr-defined] # noqa: SLF001
168-
169-
return [score_metric, pass_metric]
170271

171272
return Scorer(evaluate, name=name)

0 commit comments

Comments
 (0)