Skip to content

Commit 3497eff

Browse files
authored
Multimodal Adv Simulator for Image-Gen-Understanding (Azure#38584)
* sim-multi-modal * fix * unit test fix * adding tests recording * test recording * fix lint * assets * skip-test * asset * asset * refactor-after-nag-comments * refactor-after-nag-comments * test fix * asset * Fix with comments * refactor * conf-test-fix * removing logs
1 parent 48b0e8a commit 3497eff

File tree

12 files changed

+622
-31
lines changed

12 files changed

+622
-31
lines changed

sdk/evaluation/azure-ai-evaluation/assets.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
"AssetsRepo": "Azure/azure-sdk-assets",
33
"AssetsRepoPrefixPath": "python",
44
"TagPrefix": "python/evaluation/azure-ai-evaluation",
5-
"Tag": "python/evaluation/azure-ai-evaluation_fdb88346b8"
5+
"Tag": "python/evaluation/azure-ai-evaluation_a63b4a27cf"
66
}

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_utils.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -91,23 +91,40 @@ def _store_multimodal_content(messages, tmpdir: str):
9191
for message in messages:
9292
if isinstance(message.get("content", []), list):
9393
for content in message.get("content", []):
94-
if content.get("type") == "image_url":
95-
image_url = content.get("image_url")
96-
if image_url and "url" in image_url and image_url["url"].startswith("data:image/jpg;base64,"):
97-
# Extract the base64 string
98-
base64image = image_url["url"].replace("data:image/jpg;base64,", "")
99-
100-
# Generate a unique filename
101-
image_file_name = f"{str(uuid.uuid4())}.jpg"
102-
image_url["url"] = f"images/{image_file_name}" # Replace the base64 URL with the file path
103-
104-
# Decode the base64 string to binary image data
105-
image_data_binary = base64.b64decode(base64image)
106-
107-
# Write the binary image data to the file
108-
image_file_path = os.path.join(images_folder_path, image_file_name)
109-
with open(image_file_path, "wb") as f:
110-
f.write(image_data_binary)
94+
process_message_content(content, images_folder_path)
95+
96+
97+
def process_message_content(content, images_folder_path):
    """Persist an inline base64 image from one chat content part to disk.

    When ``content`` is an ``image_url`` part whose URL is a base64 ``data:image/...`` URI,
    the payload is decoded, written into ``images_folder_path`` under a freshly generated
    UUID file name (extension taken from the MIME subtype), and ``content["image_url"]["url"]``
    is rewritten in place to the relative path ``images/<file name>``. Every other kind of
    content (text parts, remote URLs, non-base64 data URLs) is left untouched.

    :param content: A single content part of a chat message.
    :type content: dict
    :param images_folder_path: Directory in which the image file is written.
    :type images_folder_path: str
    :return: Always ``None``; the effect is the in-place update of ``content`` and the written file.
    :rtype: None
    :raises binascii.Error: If the payload of a base64 data URL cannot be decoded. The URL in
        ``content`` is left unchanged in that case.
    """
    if content.get("type", "") != "image_url":
        return None

    image_url = content.get("image_url")
    if not image_url or "url" not in image_url:
        return None

    url = image_url["url"]
    if not isinstance(url, str) or not url.startswith("data:image/"):
        return None

    # Parse the data-URL header instead of string-replacing it: this tolerates extra
    # parameters (e.g. ";charset=utf-8") before ";base64," and skips data URLs that are not
    # base64-encoded at all. The subtype is restricted to a safe character set because it
    # becomes part of a file name on disk.
    header = re.match(r"data:image/([A-Za-z0-9.+-]+)(?:;[^;,]*)*?;base64,", url)
    if not header:
        return None

    ext = header.group(1)

    # Decode before touching ``content`` so a malformed payload cannot leave the message
    # pointing at a file that was never written.
    image_data_binary = base64.b64decode(url[header.end():])

    # Generate a unique filename
    image_file_name = f"{str(uuid.uuid4())}.{ext}"

    # Write the binary image data to the file
    image_file_path = os.path.join(images_folder_path, image_file_name)
    with open(image_file_path, "wb") as f:
        f.write(image_data_binary)

    # Replace the base64 URL with the file path
    image_url["url"] = f"images/{image_file_name}"
    return None
111128

112129

113130
def _log_metrics_and_instance_results(

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_adversarial_scenario.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ class AdversarialScenario(Enum):
2828
ADVERSARIAL_CONTENT_GEN_UNGROUNDED = "adv_content_gen_ungrounded"
2929
ADVERSARIAL_CONTENT_GEN_GROUNDED = "adv_content_gen_grounded"
3030
ADVERSARIAL_CONTENT_PROTECTED_MATERIAL = "adv_content_protected_material"
31+
ADVERSARIAL_IMAGE_GEN = "adv_image_gen"
32+
ADVERSARIAL_IMAGE_UNDERSTANDING = "adv_image_understanding"
3133

3234

3335
@experimental

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_adversarial_simulator.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,19 @@
1616
from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
1717
from azure.ai.evaluation._http_utils import get_async_http_client
1818
from azure.ai.evaluation._model_configurations import AzureAIProject
19-
from azure.ai.evaluation.simulator import AdversarialScenario
19+
from azure.ai.evaluation.simulator import AdversarialScenario, AdversarialScenarioJailbreak
2020
from azure.ai.evaluation.simulator._adversarial_scenario import _UnstableAdversarialScenario
2121
from azure.core.credentials import TokenCredential
2222
from azure.core.pipeline.policies import AsyncRetryPolicy, RetryMode
2323

2424
from ._constants import SupportedLanguages
25-
from ._conversation import CallbackConversationBot, ConversationBot, ConversationRole, ConversationTurn
25+
from ._conversation import (
26+
CallbackConversationBot,
27+
MultiModalConversationBot,
28+
ConversationBot,
29+
ConversationRole,
30+
ConversationTurn,
31+
)
2632
from ._conversation._conversation import simulate_conversation
2733
from ._model_tools import (
2834
AdversarialTemplateHandler,
@@ -231,6 +237,7 @@ async def __call__(
231237
api_call_delay_sec=api_call_delay_sec,
232238
language=language,
233239
semaphore=semaphore,
240+
scenario=scenario,
234241
)
235242
)
236243
)
@@ -292,10 +299,13 @@ async def _simulate_async(
292299
api_call_delay_sec: int,
293300
language: SupportedLanguages,
294301
semaphore: asyncio.Semaphore,
302+
scenario: Union[AdversarialScenario, AdversarialScenarioJailbreak],
295303
) -> List[Dict]:
296-
user_bot = self._setup_bot(role=ConversationRole.USER, template=template, parameters=parameters)
304+
user_bot = self._setup_bot(
305+
role=ConversationRole.USER, template=template, parameters=parameters, scenario=scenario
306+
)
297307
system_bot = self._setup_bot(
298-
target=target, role=ConversationRole.ASSISTANT, template=template, parameters=parameters
308+
target=target, role=ConversationRole.ASSISTANT, template=template, parameters=parameters, scenario=scenario
299309
)
300310
bots = [user_bot, system_bot]
301311
session = get_async_http_client().with_policies(
@@ -341,6 +351,7 @@ def _setup_bot(
341351
template: AdversarialTemplate,
342352
parameters: TemplateParameters,
343353
target: Optional[Callable] = None,
354+
scenario: Union[AdversarialScenario, AdversarialScenarioJailbreak],
344355
) -> ConversationBot:
345356
if role is ConversationRole.USER:
346357
model = self._get_user_proxy_completion_model(
@@ -372,6 +383,21 @@ def __init__(self):
372383
def __call__(self) -> None:
373384
pass
374385

386+
if scenario in [
387+
AdversarialScenario.ADVERSARIAL_IMAGE_GEN,
388+
AdversarialScenario.ADVERSARIAL_IMAGE_UNDERSTANDING,
389+
]:
390+
return MultiModalConversationBot(
391+
callback=target,
392+
role=role,
393+
model=DummyModel(),
394+
user_template=str(template),
395+
user_template_parameters=parameters,
396+
rai_client=self.rai_client,
397+
conversation_template="",
398+
instantiation_parameters={},
399+
)
400+
375401
return CallbackConversationBot(
376402
callback=target,
377403
role=role,

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_conversation/__init__.py

Lines changed: 120 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
from dataclasses import dataclass
1010
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
1111

12+
import re
1213
import jinja2
1314

1415
from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
1516
from azure.ai.evaluation._http_utils import AsyncHttpPipeline
16-
17-
from .._model_tools import LLMBase, OpenAIChatCompletionsModel
17+
from .._model_tools import LLMBase, OpenAIChatCompletionsModel, RAIClient
1818
from .._model_tools._template_handler import TemplateParameters
1919
from .constants import ConversationRole
2020

@@ -271,8 +271,6 @@ async def generate_response(
271271
"id": None,
272272
"template_parameters": {},
273273
}
274-
self.logger.info("Using user provided callback returning response.")
275-
276274
time_taken = end_time - start_time
277275
try:
278276
response = {
@@ -290,8 +288,6 @@ async def generate_response(
290288
blame=ErrorBlame.USER_ERROR,
291289
) from exc
292290

293-
self.logger.info("Parsed callback response")
294-
295291
return response, {}, time_taken, result
296292

297293
# Bug 3354264: template is unused in the method - is this intentional?
@@ -308,9 +304,127 @@ def _to_chat_protocol(self, template, conversation_history, template_parameters)
308304
}
309305

310306

307+
class MultiModalConversationBot(ConversationBot):
    """MultiModal Conversation bot that uses a user provided callback to generate responses.

    Prompts may contain ``{image:<path>}`` tags. Before the callback is invoked, tags whose
    path starts with ``image_understanding/`` are replaced with inline image content that is
    downloaded through ``rai_client``.

    :param callback: The callback function to use to generate responses.
    :type callback: Callable
    :param user_template: The template to use for the request.
    :type user_template: str
    :param user_template_parameters: The template parameters to use for the request.
    :type user_template_parameters: Dict
    :param rai_client: Client used to download the image data referenced by a prompt.
    :type rai_client: RAIClient
    :param args: Optional arguments to pass to the parent class.
    :type args: Any
    :param kwargs: Optional keyword arguments to pass to the parent class.
    :type kwargs: Any
    """

    def __init__(
        self,
        callback: Callable,
        user_template: str,
        user_template_parameters: TemplateParameters,
        rai_client: RAIClient,
        *args,
        **kwargs,
    ) -> None:
        self.callback = callback
        self.user_template = user_template
        self.user_template_parameters = user_template_parameters
        self.rai_client = rai_client

        super().__init__(*args, **kwargs)

    async def generate_response(
        self,
        session: AsyncHttpPipeline,
        conversation_history: List[Any],
        max_history: int,
        turn_number: int = 0,
    ) -> Tuple[dict, dict, float, dict]:
        """Send the conversation to the callback and return its latest message.

        The last turn of ``conversation_history`` is replaced in place by a turn whose
        message is the expanded (multi-modal) form of that same turn.

        :param session: Unused by this bot.
        :type session: AsyncHttpPipeline
        :param conversation_history: The turns so far; must contain at least one turn.
        :type conversation_history: List[Any]
        :param max_history: Unused by this bot.
        :type max_history: int
        :param turn_number: Unused by this bot.
        :type turn_number: int
        :return: The parsed response, the chat protocol request, the callback duration in
            seconds, and the raw callback result.
        :rtype: Tuple[dict, dict, float, dict]
        :raises EvaluationException: If the callback result does not conform to the chat protocol.
        """
        previous_prompt = conversation_history[-1]
        chat_protocol_message = await self._to_chat_protocol(conversation_history, self.user_template_parameters)

        # Replace the prompt containing {image:...} tags with its image content data.
        # Messages are built in history order, so the entry matching the last turn is the
        # LAST message; indexing [0] would record the first turn's content on every
        # subsequent turn of a multi-turn conversation.
        conversation_history[-1] = ConversationTurn(
            role=previous_prompt.role,
            name=previous_prompt.name,
            message=chat_protocol_message["messages"][-1]["content"],
            full_response=previous_prompt.full_response,
            request=chat_protocol_message,
        )

        # The callback gets a copy so it cannot mutate the request returned to the caller.
        msg_copy = copy.deepcopy(chat_protocol_message)
        start_time = time.time()
        result = await self.callback(msg_copy)
        end_time = time.time()
        if not result:
            result = {
                "messages": [{"content": "Callback did not return a response.", "role": "assistant"}],
                "finish_reason": ["stop"],
                "id": None,
                "template_parameters": {},
            }

        time_taken = end_time - start_time
        try:
            response = {
                "samples": [result["messages"][-1]["content"]],
                "finish_reason": ["stop"],
                "id": None,
            }
        except Exception as exc:
            # The callback is user code, so any failure while reading its result is
            # reported as a protocol violation attributed to the user.
            msg = "User provided callback does not conform to chat protocol standard."
            raise EvaluationException(
                message=msg,
                internal_message=msg,
                target=ErrorTarget.CALLBACK_CONVERSATION_BOT,
                category=ErrorCategory.INVALID_VALUE,
                blame=ErrorBlame.USER_ERROR,
            ) from exc

        return response, chat_protocol_message, time_taken, result

    async def _to_chat_protocol(self, conversation_history, template_parameters):  # pylint: disable=unused-argument
        """Convert the conversation turns into a chat protocol request.

        :param conversation_history: The turns to convert, in order.
        :param template_parameters: Returned unchanged under the ``template_parameters`` key.
        :return: A dict with ``template_parameters``, ``messages`` and ``$schema`` keys.
        :rtype: dict
        """
        messages = []

        for m in conversation_history:
            # Turns already expanded on an earlier call carry list content, not a prompt
            # string; only string prompts can contain image tags.
            if isinstance(m.message, str) and "image:" in m.message:
                content = await self._to_multi_modal_content(m.message)
                messages.append({"content": content, "role": m.role.value})
            else:
                messages.append({"content": m.message, "role": m.role.value})

        return {
            "template_parameters": template_parameters,
            "messages": messages,
            "$schema": "http://azureml/sdk-2-0/ChatConversation.json",
        }

    async def _to_multi_modal_content(self, text: str) -> list:
        """Split a prompt into text parts and inline image parts.

        The prompt is cut into ``{...}`` tags and the plain text between them. A tag has its
        braces and ``image:`` prefix removed; if the remainder starts with
        ``image_understanding/`` the image is downloaded and emitted as an ``image_url``
        part, otherwise the remainder is emitted as text.

        :param text: The prompt, possibly containing ``{image:<path>}`` tags.
        :type text: str
        :return: Content parts in their original order.
        :rtype: list
        """
        segments = re.findall(r"[^{}]+|\{[^{}]*\}", text)
        contents = []
        for segment in segments:
            if segment.startswith("{"):
                segment = segment.strip("{}").replace("image:", "").strip()

            if segment.startswith("image_understanding/"):
                encoded_image = await self.rai_client.get_image_data(segment)
                # NOTE(review): the MIME type is fixed to PNG regardless of the file's actual
                # format - confirm the service only serves PNG images.
                contents.append(
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}},
                )
            else:
                contents.append({"type": "text", "text": segment})
        return contents
422+
423+
311424
__all__ = [
312425
"ConversationRole",
313426
"ConversationBot",
314427
"CallbackConversationBot",
428+
"MultiModalConversationBot",
315429
"ConversationTurn",
316430
]

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_conversation/_conversation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
1010
from azure.ai.evaluation.simulator._constants import SupportedLanguages
1111
from azure.ai.evaluation.simulator._helpers._language_suffix_mapping import SUPPORTED_LANGUAGES_MAPPING
12-
1312
from ..._http_utils import AsyncHttpPipeline
1413
from . import ConversationBot, ConversationTurn
1514

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_indirect_attack_simulator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ async def __call__(
189189
api_call_delay_sec=api_call_delay_sec,
190190
language=language,
191191
semaphore=semaphore,
192+
scenario=scenario,
192193
)
193194
)
194195
)

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/_rai_client.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
from typing import Any
66
from urllib.parse import urljoin, urlparse
7+
import base64
78

89
from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
910
from azure.ai.evaluation._http_utils import AsyncHttpPipeline, get_async_http_client, get_http_client
@@ -57,6 +58,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential
5758
# add a "/" at the end of the url
5859
self.api_url = self.api_url.rstrip("/") + "/"
5960
self.parameter_json_endpoint = urljoin(self.api_url, "simulation/template/parameters")
61+
self.parameter_image_endpoint = urljoin(self.api_url, "simulation/template/parameters/image")
6062
self.jailbreaks_json_endpoint = urljoin(self.api_url, "simulation/jailbreak")
6163
self.simulation_submit_endpoint = urljoin(self.api_url, "simulation/chat/completions/submit")
6264
self.xpia_jailbreaks_json_endpoint = urljoin(self.api_url, "simulation/jailbreak/xpia")
@@ -166,3 +168,41 @@ async def get(self, url: str) -> Any:
166168
category=ErrorCategory.UNKNOWN,
167169
blame=ErrorBlame.USER_ERROR,
168170
)
171+
172+
async def get_image_data(self, path: str) -> Any:
    """Download an image from the image parameter endpoint and return it base64-encoded.

    :param path: The image path, sent to the service as the ``path`` query parameter.
    :type path: str
    :raises EvaluationException: If the service responds with any status code other than 200.
    :return: The response body, base64-encoded and decoded to a UTF-8 string.
    :rtype: Any
    """
    request_headers = {
        "Authorization": f"Bearer {self.token_manager.get_token()}",
        "Content-Type": "application/json",
        "User-Agent": USER_AGENT,
    }

    client = self._create_async_client()
    async with client:
        response = await client.get(  # pylint: disable=unexpected-keyword-arg
            url=self.parameter_image_endpoint, params={"path": path}, headers=request_headers
        )

    if response.status_code != 200:
        # Every non-success status is reported with the same region-availability message.
        msg = (
            "Azure safety evaluation service is not available in your current region, "
            + "please go to https://aka.ms/azureaistudiosafetyeval to see which regions are supported"
        )
        raise EvaluationException(
            message=msg,
            internal_message=msg,
            target=ErrorTarget.RAI_CLIENT,
            category=ErrorCategory.UNKNOWN,
            blame=ErrorBlame.USER_ERROR,
        )

    return base64.b64encode(response.content).decode("utf-8")

0 commit comments

Comments
 (0)