Skip to content

Commit b89c0a2

Browse files
authored
Simulator quiet run and progress bar (#34809)
* Reverting models to make sure calls to the simulator work * quotes * Spellcheck fixes * ignore the models for doc generation * Fixed the quotes on f strings * pylint skip file * Support for summarization * Adding a limit of 2 conversation turns for all but conversation simulators * exclude synthetic from mypy * Another lint fix * Skip the file causing linting issues * Bugfix on output to json_qa_lines and empty response from callbacks * Skip pylint * Add if/else on message to eval json util * Remove the verbose logs, add a progress bar and display a message when requested responses > rai service responses * Added back missing param * Spellcheck fix for pbar * Adding parameter to the spellcheck * Simulator logging initialized on the top of the page and reused * Clean up the warning log message * Removing logs from AsyncHttpClientWithRetry * Using the right param
1 parent ae4b125 commit b89c0a2

File tree

5 files changed

+117
-45
lines changed

5 files changed

+117
-45
lines changed

sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_conversation/conversation.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,6 @@ async def simulate_conversation(
9090
else:
9191
conversation_id = None
9292
first_prompt = first_response["samples"][0]
93-
logger.info(f"First turn: {first_prompt}")
94-
9593
# Add all generated turns into array to pass for each bot while generating
9694
# new responses. We add generated response and the person generating it.
9795
# in the case of the first turn, it is supposed to be the user search query
@@ -115,7 +113,6 @@ async def simulate_conversation(
115113
current_character_idx = current_turn % len(bots)
116114
current_bot = bots[current_character_idx]
117115
# invoke Bot to generate response given the input request
118-
logger.info("-- Sending to %s", current_bot.role.value)
119116
# pass only the last generated turn without passing the bot name.
120117
response, request, time_taken, full_response = await current_bot.generate_response(
121118
session=session,
@@ -137,7 +134,6 @@ async def simulate_conversation(
137134
request=request,
138135
)
139136
)
140-
logger.info("Last turn: %s", conversation_history[-1])
141137
if mlflow_logger is not None:
142138
logger_tasks.append( # schedule logging but don't get blocked by it
143139
asyncio.create_task(mlflow_logger.log_successful_response(time_taken))

sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_conversation/conversation_bot.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,21 +57,13 @@ def __init__(
5757
self.model = model
5858

5959
self.logger = logging.getLogger(repr(self))
60-
6160
self.conversation_starter = None # can either be a dictionary or jinja template
6261
if role == ConversationRole.USER:
6362
if "conversation_starter" in self.persona_template_args:
6463
conversation_starter_content = self.persona_template_args["conversation_starter"]
6564
if isinstance(conversation_starter_content, dict):
66-
msg = f"{conversation_starter_content} instead of generating a turn using a LLM"
67-
self.logger.info(
68-
"This simulated bot will use the provided conversation starter (passed in as dictionary): %s",
69-
msg,
70-
)
7165
self.conversation_starter = conversation_starter_content
7266
else:
73-
msg = f"{repr(conversation_starter_content)[:400]} instead of generating a turn using a LLM"
74-
self.logger.info("This simulated bot will use the provided conversation starter %s", msg)
7567
self.conversation_starter = jinja2.Template(
7668
conversation_starter_content, undefined=jinja2.StrictUndefined
7769
)
@@ -107,12 +99,8 @@ async def generate_response(
10799
if turn_number == 0 and self.conversation_starter is not None:
108100
# if conversation_starter is a dictionary, pass it into samples as is
109101
if isinstance(self.conversation_starter, dict):
110-
self.logger.info("Returning conversation starter: %s", self.conversation_starter)
111102
samples = [self.conversation_starter]
112103
else:
113-
self.logger.info(
114-
"Returning conversation starter: %s", repr(self.persona_template_args["conversation_starter"])[:400]
115-
)
116104
samples = [self.conversation_starter.render(**self.persona_template_args)] # type: ignore[attr-defined]
117105
time_taken = 0
118106

sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_model_tools/models.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ def __init__(self, n_retry, retry_timeout, logger, retry_options=None):
4747
# Set up async HTTP client with retry
4848

4949
trace_config = TraceConfig() # set up request logging
50-
trace_config.on_request_start.append(self.on_request_start)
51-
trace_config.on_request_end.append(self.on_request_end)
50+
trace_config.on_request_end.append(self.delete_auth_header)
51+
# trace_config.on_request_start.append(self.on_request_start)
52+
# trace_config.on_request_end.append(self.on_request_end)
5253
if retry_options is None:
5354
retry_options = RandomRetry( # set up retry configuration
5455
statuses=[104, 408, 409, 424, 429, 500, 502,
@@ -67,6 +68,13 @@ async def on_request_start(self, session, trace_config_ctx, params):
6768
current_attempt, params.method, params.url
6869
))
6970

71+
async def delete_auth_header(self, session, trace_config_ctx, params):
72+
request_headers = dict(params.response.request_info.headers)
73+
if "Authorization" in request_headers:
74+
del request_headers["Authorization"]
75+
if "api-key" in request_headers:
76+
del request_headers["api-key"]
77+
7078
async def on_request_end(self, session, trace_config_ctx, params):
7179
current_attempt = trace_config_ctx.trace_request_ctx["current_attempt"]
7280
request_headers = dict(params.response.request_info.headers)

sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/simulator/simulator.py

Lines changed: 106 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# ---------------------------------------------------------
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# ---------------------------------------------------------
4-
# pylint: disable=E0401
4+
# pylint: skip-file
55
# needed for 'list' type annotations on 3.8
66
from __future__ import annotations
77

@@ -12,6 +12,9 @@
1212
import threading
1313
import json
1414
import random
15+
from tqdm import tqdm
16+
17+
logger = logging.getLogger(__name__)
1518

1619
from azure.ai.generative.synthetic.simulator._conversation import (
1720
ConversationBot,
@@ -78,7 +81,9 @@ def __init__(
7881
if (ai_client is None and simulator_connection is None) or (
7982
ai_client is not None and simulator_connection is not None
8083
):
81-
raise ValueError("One and only one of the parameters [ai_client, simulator_connection] has to be set.")
84+
raise ValueError(
85+
"One and only one of the parameters [ai_client, simulator_connection] has to be set."
86+
)
8287

8388
if simulate_callback is None:
8489
raise ValueError("Callback cannot be None.")
@@ -87,7 +92,9 @@ def __init__(
8792
raise ValueError("Callback has to be an async function.")
8893

8994
self.ai_client = ai_client
90-
self.simulator_connection = self._to_openai_chat_completion_model(simulator_connection)
95+
self.simulator_connection = self._to_openai_chat_completion_model(
96+
simulator_connection
97+
)
9198
self.adversarial = False
9299
self.rai_client = None
93100
if ai_client:
@@ -168,19 +175,33 @@ def _create_bot(
168175
instantiation_parameters=instantiation_parameters,
169176
)
170177

171-
def _setup_bot(self, role: Union[str, ConversationRole], template: "Template", parameters: dict):
178+
def _setup_bot(
179+
self,
180+
role: Union[str, ConversationRole],
181+
template: "Template",
182+
parameters: dict,
183+
):
172184
if role == ConversationRole.ASSISTANT:
173185
return self._create_bot(role, str(template), parameters)
174186
if role == ConversationRole.USER:
175187
if template.content_harm:
176-
return self._create_bot(role, str(template), parameters, template.template_name)
188+
return self._create_bot(
189+
role, str(template), parameters, template.template_name
190+
)
177191

178-
return self._create_bot(role, str(template), parameters, model=self.simulator_connection)
192+
return self._create_bot(
193+
role,
194+
str(template),
195+
parameters,
196+
model=self.simulator_connection,
197+
)
179198
return None
180199

181200
def _ensure_service_dependencies(self):
182201
if self.rai_client is None:
183-
raise ValueError("Simulation options require rai services but ai client is not provided.")
202+
raise ValueError(
203+
"Simulation options require rai services but ai client is not provided."
204+
)
184205

185206
def _join_conversation_starter(self, parameters, to_join):
186207
key = "conversation_starter"
@@ -236,30 +257,56 @@ async def simulate_async(
236257
if parameters is None:
237258
parameters = []
238259
if not isinstance(template, Template):
239-
raise ValueError(f"Please use simulator to construct template. Found {type(template)}")
260+
raise ValueError(
261+
f"Please use simulator to construct template. Found {type(template)}"
262+
)
240263

241264
if not isinstance(parameters, list):
242-
raise ValueError(f"Expect parameters to be a list of dictionary, but found {type(parameters)}")
265+
raise ValueError(
266+
f"Expect parameters to be a list of dictionary, but found {type(parameters)}"
267+
)
243268
if "conversation" not in template.template_name:
244269
max_conversation_turns = 2
245270
if template.content_harm:
246271
self._ensure_service_dependencies()
247272
self.adversarial = True
248273
# pylint: disable=protected-access
249-
templates = await self.template_handler._get_ch_template_collections(template.template_name)
274+
templates = await self.template_handler._get_ch_template_collections(
275+
template.template_name
276+
)
250277
else:
251278
template.template_parameters = parameters
252279
templates = [template]
253280

254281
semaphore = asyncio.Semaphore(concurrent_async_task)
255282
sim_results = []
256283
tasks = []
284+
total_tasks = sum(len(t.template_parameters) for t in templates)
285+
286+
if simulation_result_limit > total_tasks and self.adversarial:
287+
logger.warning(
288+
"Cannot provide %s results due to maximum number of adversarial simulations that can be generated: %s."
289+
"\n %s simulations will be generated.",
290+
simulation_result_limit,
291+
total_tasks,
292+
total_tasks,
293+
)
294+
total_tasks = min(total_tasks, simulation_result_limit)
295+
progress_bar = tqdm(
296+
total=total_tasks,
297+
desc="generating simulations",
298+
ncols=100,
299+
unit="simulations",
300+
)
301+
257302
for t in templates:
258303
for p in t.template_parameters:
259304
if jailbreak:
260305
self._ensure_service_dependencies()
261306
jailbreak_dataset = await self.rai_client.get_jailbreaks_dataset() # type: ignore[union-attr]
262-
p = self._join_conversation_starter(p, random.choice(jailbreak_dataset))
307+
p = self._join_conversation_starter(
308+
p, random.choice(jailbreak_dataset)
309+
)
263310

264311
tasks.append(
265312
asyncio.create_task(
@@ -280,7 +327,15 @@ async def simulate_async(
280327
if len(tasks) >= simulation_result_limit:
281328
break
282329

283-
sim_results = await asyncio.gather(*tasks)
330+
sim_results = []
331+
332+
# Use asyncio.as_completed to update the progress bar when a task is complete
333+
for task in asyncio.as_completed(tasks):
334+
result = await task
335+
sim_results.append(result) # Store the result
336+
progress_bar.update(1)
337+
338+
progress_bar.close()
284339

285340
return JsonLineList(sim_results)
286341

@@ -319,7 +374,9 @@ async def _simulate_async(
319374
parameters = {}
320375
# create user bot
321376
user_bot = self._setup_bot(ConversationRole.USER, template, parameters)
322-
system_bot = self._setup_bot(ConversationRole.ASSISTANT, template, parameters)
377+
system_bot = self._setup_bot(
378+
ConversationRole.ASSISTANT, template, parameters
379+
)
323380

324381
bots = [user_bot, system_bot]
325382

@@ -328,7 +385,7 @@ async def _simulate_async(
328385
asyncHttpClient = AsyncHTTPClientWithRetry(
329386
n_retry=api_call_retry_limit,
330387
retry_timeout=api_call_retry_sleep_sec,
331-
logger=logging.getLogger(),
388+
logger=logger,
332389
)
333390
async with sem:
334391
async with asyncHttpClient.client as session:
@@ -357,7 +414,9 @@ def _get_citations(self, parameters, context_keys, turn_num=None):
357414
else:
358415
for k, v in parameters[c_key].items():
359416
if k not in ["callback_citations", "callback_citation_key"]:
360-
citations.append({"id": k, "content": self._to_citation_content(v)})
417+
citations.append(
418+
{"id": k, "content": self._to_citation_content(v)}
419+
)
361420
else:
362421
citations.append(
363422
{
@@ -373,7 +432,9 @@ def _to_citation_content(self, obj):
373432
return obj
374433
return json.dumps(obj)
375434

376-
def _get_callback_citations(self, callback_citations: dict, turn_num: Optional[int] = None):
435+
def _get_callback_citations(
436+
self, callback_citations: dict, turn_num: Optional[int] = None
437+
):
377438
if turn_num is None:
378439
return []
379440
current_turn_citations = []
@@ -382,7 +443,9 @@ def _get_callback_citations(self, callback_citations: dict, turn_num: Optional[i
382443
citations = callback_citations[current_turn_str]
383444
if isinstance(citations, dict):
384445
for k, v in citations.items():
385-
current_turn_citations.append({"id": k, "content": self._to_citation_content(v)})
446+
current_turn_citations.append(
447+
{"id": k, "content": self._to_citation_content(v)}
448+
)
386449
else:
387450
current_turn_citations.append(
388451
{
@@ -397,13 +460,15 @@ def _to_chat_protocol(self, template, conversation_history, template_parameters)
397460
for i, m in enumerate(conversation_history):
398461
message = {"content": m.message, "role": m.role.value}
399462
if len(template.context_key) > 0:
400-
citations = self._get_citations(template_parameters, template.context_key, i)
463+
citations = self._get_citations(
464+
template_parameters, template.context_key, i
465+
)
401466
message["context"] = citations
402467
elif "context" in m.full_response:
403468
# adding context for adv_qa
404469
message["context"] = m.full_response["context"]
405470
messages.append(message)
406-
template_parameters['metadata'] = {}
471+
template_parameters["metadata"] = {}
407472
if "ch_template_placeholder" in template_parameters:
408473
del template_parameters["ch_template_placeholder"]
409474

@@ -524,8 +589,13 @@ def from_fn(
524589
if hasattr(fn, "__wrapped__"):
525590
func_module = fn.__wrapped__.__module__
526591
func_name = fn.__wrapped__.__name__
527-
if func_module == "openai.resources.chat.completions" and func_name == "create":
528-
return Simulator._from_openai_chat_completions(fn, simulator_connection, ai_client, **kwargs)
592+
if (
593+
func_module == "openai.resources.chat.completions"
594+
and func_name == "create"
595+
):
596+
return Simulator._from_openai_chat_completions(
597+
fn, simulator_connection, ai_client, **kwargs
598+
)
529599

530600
return Simulator(
531601
simulator_connection=simulator_connection,
@@ -534,7 +604,9 @@ def from_fn(
534604
)
535605

536606
@staticmethod
537-
def _from_openai_chat_completions(fn: Callable[[Any], dict], simulator_connection=None, ai_client=None, **kwargs):
607+
def _from_openai_chat_completions(
608+
fn: Callable[[Any], dict], simulator_connection=None, ai_client=None, **kwargs
609+
):
538610
return Simulator(
539611
simulator_connection=simulator_connection,
540612
ai_client=ai_client,
@@ -625,7 +697,9 @@ async def callback(chat_protocol_message):
625697
input_data[chat_history_key] = all_messages
626698

627699
response = flow.invoke(input_data).output
628-
chat_protocol_message["messages"].append({"role": "assistant", "content": response[chat_output_key]})
700+
chat_protocol_message["messages"].append(
701+
{"role": "assistant", "content": response[chat_output_key]}
702+
)
629703

630704
return chat_protocol_message
631705

@@ -657,8 +731,12 @@ def create_template(
657731
One of 'template' or 'template_path' must be provided to create a template. If 'template' is provided,
658732
it is used directly; if 'template_path' is provided, the content is read from the file at that path.
659733
"""
660-
if (template is None and template_path is None) or (template is not None and template_path is not None):
661-
raise ValueError("One and only one of the parameters [template, template_path] has to be set.")
734+
if (template is None and template_path is None) or (
735+
template is not None and template_path is not None
736+
):
737+
raise ValueError(
738+
"One and only one of the parameters [template, template_path] has to be set."
739+
)
662740

663741
if template is not None:
664742
return Template(template_name=name, text=template, context_key=context_key)
@@ -669,7 +747,9 @@ def create_template(
669747

670748
return Template(template_name=name, text=tc, context_key=context_key)
671749

672-
raise ValueError("Condition not met for creating template, please check examples and parameter list.")
750+
raise ValueError(
751+
"Condition not met for creating template, please check examples and parameter list."
752+
)
673753

674754
@staticmethod
675755
def get_template(template_name: str):
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
{
2-
"ignoreWords": ["cmpl", "uqkvl", "redef", "datas", "unbatched", "endofprompt", "unlabel", "pydash", "raisvc", "tkey", "tparam", "punc"],
2+
"ignoreWords": ["cmpl", "uqkvl", "redef", "datas", "unbatched", "endofprompt", "unlabel", "pydash", "raisvc", "tkey", "tparam", "punc", "ncols"],
33
"ignorePaths": ["sdk/ai/azure-ai-generative/azure/ai/generative/evaluate/pf_templates/**/*"]
44
}

0 commit comments

Comments
 (0)