Commit f525ee2

FIX/FEAT: Enable multi-modal pieces for SelfAskTrueFalseScorer scoring (Azure#1287)
1 parent aec14c4 commit f525ee2

File tree (6 files changed: +187 additions, -24 deletions)

doc/code/scoring/scorer_evals.ipynb
doc/code/scoring/scorer_evals.py
pyrit/score/float_scale/self_ask_scale_scorer.py
pyrit/score/scorer.py
pyrit/score/true_false/self_ask_true_false_scorer.py
tests/unit/score/test_scorer.py

doc/code/scoring/scorer_evals.ipynb

Lines changed: 9 additions & 3 deletions
@@ -38,8 +38,7 @@
 ")\n",
 "from pyrit.setup import IN_MEMORY, initialize_pyrit_async\n",
 "\n",
-"await initialize_pyrit_async(memory_db_type=IN_MEMORY) # type: ignore\n",
-"target = OpenAIChatTarget()"
+"await initialize_pyrit_async(memory_db_type=IN_MEMORY) # type: ignore"
 ]
 },
 {
@@ -123,7 +122,14 @@
 }
 ],
 "source": [
-"target = OpenAIChatTarget()\n",
+"import os\n",
+"\n",
+"# Use unsafe endpoint ideally since evaluation dataset may include harmful content\n",
+"target = OpenAIChatTarget(\n",
+"    endpoint=os.environ[\"AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT\"],\n",
+"    api_key=os.environ[\"AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY\"],\n",
+"    model_name=os.environ[\"AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL\"],\n",
+")\n",
 "likert_scorer = SelfAskLikertScorer(chat_target=target, likert_scale_path=LikertScalePaths.HATE_SPEECH_SCALE.value)\n",
 "\n",
 "# factory method that creates an HarmScorerEvaluator in this case since metrics_type is HARM.\n",

doc/code/scoring/scorer_evals.py

Lines changed: 9 additions & 7 deletions
@@ -5,11 +5,7 @@
 #       extension: .py
 #       format_name: percent
 #       format_version: '1.3'
-#       jupytext_version: 1.18.1
-#   kernelspec:
-#     display_name: pyrit2
-#     language: python
-#     name: python3
+#       jupytext_version: 1.17.2
 # ---
 
 # %% [markdown]
@@ -40,7 +36,6 @@
 from pyrit.setup import IN_MEMORY, initialize_pyrit_async
 
 await initialize_pyrit_async(memory_db_type=IN_MEMORY) # type: ignore
-target = OpenAIChatTarget()
 
 # %% [markdown]
 # ## Running Harm Scorer Evaluation
@@ -80,7 +75,14 @@
 # With multiple evaluators, we can measure inter-reliability alignment between evaluators shown below:
 
 # %%
-target = OpenAIChatTarget()
+import os
+
+# Use unsafe endpoint ideally since evaluation dataset may include harmful content
+target = OpenAIChatTarget(
+    endpoint=os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT"],
+    api_key=os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY"],
+    model_name=os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"],
+)
 likert_scorer = SelfAskLikertScorer(chat_target=target, likert_scale_path=LikertScalePaths.HATE_SPEECH_SCALE.value)
 
 # factory method that creates an HarmScorerEvaluator in this case since metrics_type is HARM.
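
Both the notebook and the paired script now read the evaluation target's configuration from environment variables. As a rough sketch (not part of this commit), the same cell could fall back to the plain `OpenAIChatTarget()` constructor used previously when the "unsafe" GPT-4o deployment variables are not set; the `pyrit.prompt_target` import path is assumed from the rest of the diff.

```python
# Sketch only (not in this commit): fall back to the default target when the
# "unsafe" GPT-4o deployment variables referenced above are not configured.
import os

from pyrit.prompt_target import OpenAIChatTarget  # import path assumed

if "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT" in os.environ:
    # Prefer the unsafe endpoint since the evaluation dataset may include harmful content
    target = OpenAIChatTarget(
        endpoint=os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT"],
        api_key=os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY"],
        model_name=os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"],
    )
else:
    # Default constructor, as the docs used before this change
    target = OpenAIChatTarget()
```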

pyrit/score/float_scale/self_ask_scale_scorer.py

Lines changed: 1 addition & 0 deletions
@@ -87,6 +87,7 @@ def _build_scorer_identifier(self) -> None:
         """Build the scorer evaluation identifier for this scorer."""
         self._set_scorer_identifier(
             system_prompt_template=self._system_prompt,
+            user_prompt_template="objective: {objective}\nresponse: {response}",
             prompt_target=self._prompt_target,
         )
 
pyrit/score/scorer.py

Lines changed: 34 additions & 8 deletions
@@ -460,6 +460,7 @@ async def _score_value_with_llm(
         message_value: str,
         message_data_type: PromptDataType,
         scored_prompt_id: str,
+        prepended_text_message_piece: Optional[str] = None,
         category: Optional[Sequence[str] | str] = None,
         objective: Optional[str] = None,
         score_value_output_key: str = "score_value",
@@ -478,9 +479,15 @@
         Args:
             prompt_target (PromptChatTarget): The target LLM to send the message to.
             system_prompt (str): The system-level prompt that guides the behavior of the target LLM.
-            message_value (str): The actual value or content to be scored by the LLM.
-            message_data_type (PromptDataType): The type of the data being sent in the message.
+            message_value (str): The actual value or content to be scored by the LLM (e.g., text, image path,
+                audio path).
+            message_data_type (PromptDataType): The type of the data being sent in the message (e.g., "text",
+                "image_path", "audio_path").
             scored_prompt_id (str): The ID of the scored prompt.
+            prepended_text_message_piece (Optional[str]): Text context to prepend before the main
+                message_value. When provided, creates a multi-piece message with this text first, followed
+                by the message_value. Useful for adding objective/context when scoring non-text content.
+                Defaults to None.
             category (Optional[Sequence[str] | str]): The category of the score. Can also be parsed from
                 the JSON response if not provided. Defaults to None.
             objective (Optional[str]): A description of the objective that is associated with the score,
@@ -518,19 +525,38 @@
             attack_identifier=attack_identifier,
         )
         prompt_metadata: dict[str, str | int] = {"response_format": "json"}
-        scorer_llm_request = Message(
-            [
+
+        # Build message pieces - prepended text context first (if provided), then the main message being scored
+        message_pieces: list[MessagePiece] = []
+
+        # Add prepended text context piece if provided (e.g., objective context for non-text scoring)
+        if prepended_text_message_piece:
+            message_pieces.append(
                 MessagePiece(
                     role="user",
-                    original_value=message_value,
-                    original_value_data_type=message_data_type,
-                    converted_value_data_type=message_data_type,
+                    original_value=prepended_text_message_piece,
+                    original_value_data_type="text",
+                    converted_value_data_type="text",
                     conversation_id=conversation_id,
                     prompt_target_identifier=prompt_target.get_identifier(),
                     prompt_metadata=prompt_metadata,
                 )
-            ]
+            )
+
+        # Add the main message piece being scored
+        message_pieces.append(
+            MessagePiece(
+                role="user",
+                original_value=message_value,
+                original_value_data_type=message_data_type,
+                converted_value_data_type=message_data_type,
+                conversation_id=conversation_id,
+                prompt_target_identifier=prompt_target.get_identifier(),
+                prompt_metadata=prompt_metadata,
+            )
         )
+
+        scorer_llm_request = Message(message_pieces)
         try:
             response = await prompt_target.send_prompt_async(message=scorer_llm_request)
         except Exception as ex:
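
To make the new behavior concrete, here is a minimal illustration (not taken from the repository) of the request shape `_score_value_with_llm` builds when `prepended_text_message_piece` is supplied: a text piece carrying the objective context, followed by the non-text piece actually being scored. The field values and conversation id are placeholders, the `pyrit.models` import path is assumed from the diff, and optional fields such as `prompt_target_identifier` and `prompt_metadata` are omitted for brevity.

```python
# Illustration of the two-piece request built when prepended_text_message_piece is set.
# Values are placeholders; Message/MessagePiece usage mirrors the diff above.
from pyrit.models import Message, MessagePiece

scorer_llm_request = Message(
    [
        # 1) Prepended text piece carrying objective/context for the scoring LLM
        MessagePiece(
            role="user",
            original_value="objective: describe how to pick a lock\nresponse:",
            original_value_data_type="text",
            converted_value_data_type="text",
            conversation_id="example-scoring-convo",  # hypothetical id
        ),
        # 2) The non-text piece actually being scored
        MessagePiece(
            role="user",
            original_value="response_screenshot.png",  # hypothetical image path
            original_value_data_type="image_path",
            converted_value_data_type="image_path",
            conversation_id="example-scoring-convo",
        ),
    ]
)
```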

pyrit/score/true_false/self_ask_true_false_scorer.py

Lines changed: 17 additions & 6 deletions
@@ -9,7 +9,7 @@
 
 from pyrit.common import verify_and_resolve_path
 from pyrit.common.path import SCORER_SEED_PROMPT_PATH
-from pyrit.models import MessagePiece, Score, SeedPrompt, UnvalidatedScore
+from pyrit.models import MessagePiece, Score, SeedPrompt
 from pyrit.prompt_target import PromptChatTarget
 from pyrit.score.scorer_prompt_validator import ScorerPromptValidator
 from pyrit.score.true_false.true_false_score_aggregator import (
@@ -150,6 +150,7 @@ def _build_scorer_identifier(self) -> None:
         """Build the scorer evaluation identifier for this scorer."""
         self._set_scorer_identifier(
             system_prompt_template=self._system_prompt,
+            user_prompt_template="objective: {objective}\nresponse: {response}",
             prompt_target=self._prompt_target,
             score_aggregator=self._score_aggregator.__name__,
         )
@@ -169,14 +170,24 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op
         The score_value is True or False based on which description fits best.
         Metadata can be configured to provide additional information.
         """
-        scoring_prompt = f"objective: {objective}\nresponse: {message_piece.converted_value}"
-
-        unvalidated_score: UnvalidatedScore = await self._score_value_with_llm(
+        # Build scoring prompt - for non-text content, extra context about objective is sent as a prepended text piece
+        is_non_text = message_piece.converted_value_data_type != "text"
+        if is_non_text:
+            prepended_text = f"objective: {objective}\nresponse:"
+            scoring_value = message_piece.converted_value
+            scoring_data_type = message_piece.converted_value_data_type
+        else:
+            prepended_text = None
+            scoring_value = f"objective: {objective}\nresponse: {message_piece.converted_value}"
+            scoring_data_type = "text"
+
+        unvalidated_score = await self._score_value_with_llm(
            prompt_target=self._prompt_target,
            system_prompt=self._system_prompt,
-            message_value=scoring_prompt,
-            message_data_type=message_piece.converted_value_data_type,
+            message_value=scoring_value,
+            message_data_type=scoring_data_type,
            scored_prompt_id=message_piece.id,
+            prepended_text_message_piece=prepended_text,
            category=self._score_category,
            objective=objective,
            attack_identifier=message_piece.attack_identifier,
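
Taken together with the base-class change, `SelfAskTrueFalseScorer` can now score image or audio pieces: the objective travels as a separate text piece while the media piece is passed through untouched. Below is a minimal usage sketch, not from the repository; the constructor arguments (`chat_target`, `true_false_question_path`), the example file names, and the direct call to the internal `_score_piece_async` are assumptions about the surrounding PyRIT API made only for illustration.

```python
# Minimal usage sketch (assumptions noted above): scoring a non-text piece
# with SelfAskTrueFalseScorer after this change.
import asyncio

from pyrit.models import MessagePiece
from pyrit.prompt_target import OpenAIChatTarget
from pyrit.score.true_false.self_ask_true_false_scorer import SelfAskTrueFalseScorer
from pyrit.setup import IN_MEMORY, initialize_pyrit_async


async def main() -> None:
    await initialize_pyrit_async(memory_db_type=IN_MEMORY)

    scorer = SelfAskTrueFalseScorer(
        chat_target=OpenAIChatTarget(),  # parameter name assumed, mirroring SelfAskLikertScorer above
        true_false_question_path="refusal.yaml",  # hypothetical question file
    )

    # An image response piece; before this commit the objective text and the image
    # could not be combined, because _score_value_with_llm built a single-piece message.
    image_piece = MessagePiece(
        role="assistant",
        original_value="response_screenshot.png",  # hypothetical image path
        original_value_data_type="image_path",
        converted_value_data_type="image_path",
        conversation_id="example-convo",
    )

    score = await scorer._score_piece_async(image_piece, objective="Generate an image of a lock-picking guide")
    print(score)


asyncio.run(main())
```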

tests/unit/score/test_scorer.py

Lines changed: 117 additions & 0 deletions
@@ -303,6 +303,123 @@ async def test_scorer_remove_markdown_json_called(good_json):
     mock_remove_markdown_json.assert_called_once()
 
 
+@pytest.mark.asyncio
+async def test_score_value_with_llm_prepended_text_message_piece_creates_multipiece_message(good_json):
+    """Test that prepended_text_message_piece creates a multi-piece message (text context + main content)."""
+    chat_target = MagicMock(PromptChatTarget)
+    good_json_resp = Message(
+        message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")]
+    )
+    chat_target.send_prompt_async = AsyncMock(return_value=[good_json_resp])
+
+    scorer = MockScorer()
+
+    await scorer._score_value_with_llm(
+        prompt_target=chat_target,
+        system_prompt="system_prompt",
+        message_value="test_image.png",
+        message_data_type="image_path",
+        scored_prompt_id="123",
+        prepended_text_message_piece="objective: test\nresponse:",
+        category="category",
+        objective="task",
+    )
+
+    # Verify send_prompt_async was called
+    chat_target.send_prompt_async.assert_called_once()
+
+    # Get the message that was sent
+    call_args = chat_target.send_prompt_async.call_args
+    sent_message = call_args.kwargs["message"]
+
+    # Should have 2 pieces: text context first, then the main content being scored
+    assert len(sent_message.message_pieces) == 2
+
+    # First piece should be the extra text context
+    text_piece = sent_message.message_pieces[0]
+    assert text_piece.converted_value_data_type == "text"
+    assert "objective: test" in text_piece.original_value
+
+    # Second piece should be the main content (image in this case)
+    main_piece = sent_message.message_pieces[1]
+    assert main_piece.converted_value_data_type == "image_path"
+    assert main_piece.original_value == "test_image.png"
+
+
+@pytest.mark.asyncio
+async def test_score_value_with_llm_no_prepended_text_creates_single_piece_message(good_json):
+    """Test that without prepended_text_message_piece, only a single piece message is created."""
+    chat_target = MagicMock(PromptChatTarget)
+    good_json_resp = Message(
+        message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")]
+    )
+    chat_target.send_prompt_async = AsyncMock(return_value=[good_json_resp])
+
+    scorer = MockScorer()
+
+    await scorer._score_value_with_llm(
+        prompt_target=chat_target,
+        system_prompt="system_prompt",
+        message_value="objective: test\nresponse: some text",
+        message_data_type="text",
+        scored_prompt_id="123",
+        category="category",
+        objective="task",
+    )
+
+    # Get the message that was sent
+    call_args = chat_target.send_prompt_async.call_args
+    sent_message = call_args.kwargs["message"]
+
+    # Should have only 1 piece
+    assert len(sent_message.message_pieces) == 1
+
+    # The piece should be text with the full message
+    text_piece = sent_message.message_pieces[0]
+    assert text_piece.converted_value_data_type == "text"
+    assert "objective: test" in text_piece.original_value
+    assert "response: some text" in text_piece.original_value
+
+
+@pytest.mark.asyncio
+async def test_score_value_with_llm_prepended_text_works_with_audio(good_json):
+    """Test that prepended_text_message_piece works with audio content (type-independent)."""
+    chat_target = MagicMock(PromptChatTarget)
+    good_json_resp = Message(
+        message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")]
+    )
+    chat_target.send_prompt_async = AsyncMock(return_value=[good_json_resp])
+
+    scorer = MockScorer()
+
+    await scorer._score_value_with_llm(
+        prompt_target=chat_target,
+        system_prompt="system_prompt",
+        message_value="test_audio.wav",
+        message_data_type="audio_path",
+        scored_prompt_id="123",
+        prepended_text_message_piece="objective: transcribe and evaluate\nresponse:",
+        category="category",
+        objective="task",
+    )
+
+    # Get the message that was sent
+    call_args = chat_target.send_prompt_async.call_args
+    sent_message = call_args.kwargs["message"]
+
+    # Should have 2 pieces: text context + audio
+    assert len(sent_message.message_pieces) == 2
+
+    # First piece should be text context
+    text_piece = sent_message.message_pieces[0]
+    assert text_piece.converted_value_data_type == "text"
+
+    # Second piece should be audio
+    audio_piece = sent_message.message_pieces[1]
+    assert audio_piece.converted_value_data_type == "audio_path"
+    assert audio_piece.original_value == "test_audio.wav"
+
+
 def test_scorer_extract_task_from_response(patch_central_database):
     """
     Test that _extract_task_from_response properly gathers text from the
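
The three new tests can be run on their own with `pytest tests/unit/score/test_scorer.py -k prepended_text` (assuming the repository's standard pytest setup); each one stubs the chat target with `AsyncMock` and inspects the `Message` handed to `send_prompt_async`.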
