1
1
# ---------------------------------------------------------
2
2
# Copyright (c) Microsoft Corporation. All rights reserved.
3
3
# ---------------------------------------------------------
4
- # pylint: disable=E0401
4
+ # pylint: skip-file
5
5
# needed for 'list' type annotations on 3.8
6
6
from __future__ import annotations
7
7
12
12
import threading
13
13
import json
14
14
import random
15
+ from tqdm import tqdm
16
+
17
+ logger = logging .getLogger (__name__ )
15
18
16
19
from azure .ai .generative .synthetic .simulator ._conversation import (
17
20
ConversationBot ,
@@ -78,7 +81,9 @@ def __init__(
78
81
if (ai_client is None and simulator_connection is None ) or (
79
82
ai_client is not None and simulator_connection is not None
80
83
):
81
- raise ValueError ("One and only one of the parameters [ai_client, simulator_connection] has to be set." )
84
+ raise ValueError (
85
+ "One and only one of the parameters [ai_client, simulator_connection] has to be set."
86
+ )
82
87
83
88
if simulate_callback is None :
84
89
raise ValueError ("Callback cannot be None." )
@@ -87,7 +92,9 @@ def __init__(
87
92
raise ValueError ("Callback has to be an async function." )
88
93
89
94
self .ai_client = ai_client
90
- self .simulator_connection = self ._to_openai_chat_completion_model (simulator_connection )
95
+ self .simulator_connection = self ._to_openai_chat_completion_model (
96
+ simulator_connection
97
+ )
91
98
self .adversarial = False
92
99
self .rai_client = None
93
100
if ai_client :
@@ -168,19 +175,33 @@ def _create_bot(
168
175
instantiation_parameters = instantiation_parameters ,
169
176
)
170
177
171
- def _setup_bot (self , role : Union [str , ConversationRole ], template : "Template" , parameters : dict ):
178
+ def _setup_bot (
179
+ self ,
180
+ role : Union [str , ConversationRole ],
181
+ template : "Template" ,
182
+ parameters : dict ,
183
+ ):
172
184
if role == ConversationRole .ASSISTANT :
173
185
return self ._create_bot (role , str (template ), parameters )
174
186
if role == ConversationRole .USER :
175
187
if template .content_harm :
176
- return self ._create_bot (role , str (template ), parameters , template .template_name )
188
+ return self ._create_bot (
189
+ role , str (template ), parameters , template .template_name
190
+ )
177
191
178
- return self ._create_bot (role , str (template ), parameters , model = self .simulator_connection )
192
+ return self ._create_bot (
193
+ role ,
194
+ str (template ),
195
+ parameters ,
196
+ model = self .simulator_connection ,
197
+ )
179
198
return None
180
199
181
200
def _ensure_service_dependencies (self ):
182
201
if self .rai_client is None :
183
- raise ValueError ("Simulation options require rai services but ai client is not provided." )
202
+ raise ValueError (
203
+ "Simulation options require rai services but ai client is not provided."
204
+ )
184
205
185
206
def _join_conversation_starter (self , parameters , to_join ):
186
207
key = "conversation_starter"
@@ -236,30 +257,56 @@ async def simulate_async(
236
257
if parameters is None :
237
258
parameters = []
238
259
if not isinstance (template , Template ):
239
- raise ValueError (f"Please use simulator to construct template. Found { type (template )} " )
260
+ raise ValueError (
261
+ f"Please use simulator to construct template. Found { type (template )} "
262
+ )
240
263
241
264
if not isinstance (parameters , list ):
242
- raise ValueError (f"Expect parameters to be a list of dictionary, but found { type (parameters )} " )
265
+ raise ValueError (
266
+ f"Expect parameters to be a list of dictionary, but found { type (parameters )} "
267
+ )
243
268
if "conversation" not in template .template_name :
244
269
max_conversation_turns = 2
245
270
if template .content_harm :
246
271
self ._ensure_service_dependencies ()
247
272
self .adversarial = True
248
273
# pylint: disable=protected-access
249
- templates = await self .template_handler ._get_ch_template_collections (template .template_name )
274
+ templates = await self .template_handler ._get_ch_template_collections (
275
+ template .template_name
276
+ )
250
277
else :
251
278
template .template_parameters = parameters
252
279
templates = [template ]
253
280
254
281
semaphore = asyncio .Semaphore (concurrent_async_task )
255
282
sim_results = []
256
283
tasks = []
284
+ total_tasks = sum (len (t .template_parameters ) for t in templates )
285
+
286
+ if simulation_result_limit > total_tasks and self .adversarial :
287
+ logger .warning (
288
+ "Cannot provide %s results due to maximum number of adversarial simulations that can be generated: %s."
289
+ "\n %s simulations will be generated." ,
290
+ simulation_result_limit ,
291
+ total_tasks ,
292
+ total_tasks ,
293
+ )
294
+ total_tasks = min (total_tasks , simulation_result_limit )
295
+ progress_bar = tqdm (
296
+ total = total_tasks ,
297
+ desc = "generating simulations" ,
298
+ ncols = 100 ,
299
+ unit = "simulations" ,
300
+ )
301
+
257
302
for t in templates :
258
303
for p in t .template_parameters :
259
304
if jailbreak :
260
305
self ._ensure_service_dependencies ()
261
306
jailbreak_dataset = await self .rai_client .get_jailbreaks_dataset () # type: ignore[union-attr]
262
- p = self ._join_conversation_starter (p , random .choice (jailbreak_dataset ))
307
+ p = self ._join_conversation_starter (
308
+ p , random .choice (jailbreak_dataset )
309
+ )
263
310
264
311
tasks .append (
265
312
asyncio .create_task (
@@ -280,7 +327,15 @@ async def simulate_async(
280
327
if len (tasks ) >= simulation_result_limit :
281
328
break
282
329
283
- sim_results = await asyncio .gather (* tasks )
330
+ sim_results = []
331
+
332
+ # Use asyncio.as_completed to update the progress bar when a task is complete
333
+ for task in asyncio .as_completed (tasks ):
334
+ result = await task
335
+ sim_results .append (result ) # Store the result
336
+ progress_bar .update (1 )
337
+
338
+ progress_bar .close ()
284
339
285
340
return JsonLineList (sim_results )
286
341
@@ -319,7 +374,9 @@ async def _simulate_async(
319
374
parameters = {}
320
375
# create user bot
321
376
user_bot = self ._setup_bot (ConversationRole .USER , template , parameters )
322
- system_bot = self ._setup_bot (ConversationRole .ASSISTANT , template , parameters )
377
+ system_bot = self ._setup_bot (
378
+ ConversationRole .ASSISTANT , template , parameters
379
+ )
323
380
324
381
bots = [user_bot , system_bot ]
325
382
@@ -328,7 +385,7 @@ async def _simulate_async(
328
385
asyncHttpClient = AsyncHTTPClientWithRetry (
329
386
n_retry = api_call_retry_limit ,
330
387
retry_timeout = api_call_retry_sleep_sec ,
331
- logger = logging . getLogger () ,
388
+ logger = logger ,
332
389
)
333
390
async with sem :
334
391
async with asyncHttpClient .client as session :
@@ -357,7 +414,9 @@ def _get_citations(self, parameters, context_keys, turn_num=None):
357
414
else :
358
415
for k , v in parameters [c_key ].items ():
359
416
if k not in ["callback_citations" , "callback_citation_key" ]:
360
- citations .append ({"id" : k , "content" : self ._to_citation_content (v )})
417
+ citations .append (
418
+ {"id" : k , "content" : self ._to_citation_content (v )}
419
+ )
361
420
else :
362
421
citations .append (
363
422
{
@@ -373,7 +432,9 @@ def _to_citation_content(self, obj):
373
432
return obj
374
433
return json .dumps (obj )
375
434
376
- def _get_callback_citations (self , callback_citations : dict , turn_num : Optional [int ] = None ):
435
+ def _get_callback_citations (
436
+ self , callback_citations : dict , turn_num : Optional [int ] = None
437
+ ):
377
438
if turn_num is None :
378
439
return []
379
440
current_turn_citations = []
@@ -382,7 +443,9 @@ def _get_callback_citations(self, callback_citations: dict, turn_num: Optional[i
382
443
citations = callback_citations [current_turn_str ]
383
444
if isinstance (citations , dict ):
384
445
for k , v in citations .items ():
385
- current_turn_citations .append ({"id" : k , "content" : self ._to_citation_content (v )})
446
+ current_turn_citations .append (
447
+ {"id" : k , "content" : self ._to_citation_content (v )}
448
+ )
386
449
else :
387
450
current_turn_citations .append (
388
451
{
@@ -397,13 +460,15 @@ def _to_chat_protocol(self, template, conversation_history, template_parameters)
397
460
for i , m in enumerate (conversation_history ):
398
461
message = {"content" : m .message , "role" : m .role .value }
399
462
if len (template .context_key ) > 0 :
400
- citations = self ._get_citations (template_parameters , template .context_key , i )
463
+ citations = self ._get_citations (
464
+ template_parameters , template .context_key , i
465
+ )
401
466
message ["context" ] = citations
402
467
elif "context" in m .full_response :
403
468
# adding context for adv_qa
404
469
message ["context" ] = m .full_response ["context" ]
405
470
messages .append (message )
406
- template_parameters [' metadata' ] = {}
471
+ template_parameters [" metadata" ] = {}
407
472
if "ch_template_placeholder" in template_parameters :
408
473
del template_parameters ["ch_template_placeholder" ]
409
474
@@ -524,8 +589,13 @@ def from_fn(
524
589
if hasattr (fn , "__wrapped__" ):
525
590
func_module = fn .__wrapped__ .__module__
526
591
func_name = fn .__wrapped__ .__name__
527
- if func_module == "openai.resources.chat.completions" and func_name == "create" :
528
- return Simulator ._from_openai_chat_completions (fn , simulator_connection , ai_client , ** kwargs )
592
+ if (
593
+ func_module == "openai.resources.chat.completions"
594
+ and func_name == "create"
595
+ ):
596
+ return Simulator ._from_openai_chat_completions (
597
+ fn , simulator_connection , ai_client , ** kwargs
598
+ )
529
599
530
600
return Simulator (
531
601
simulator_connection = simulator_connection ,
@@ -534,7 +604,9 @@ def from_fn(
534
604
)
535
605
536
606
@staticmethod
537
- def _from_openai_chat_completions (fn : Callable [[Any ], dict ], simulator_connection = None , ai_client = None , ** kwargs ):
607
+ def _from_openai_chat_completions (
608
+ fn : Callable [[Any ], dict ], simulator_connection = None , ai_client = None , ** kwargs
609
+ ):
538
610
return Simulator (
539
611
simulator_connection = simulator_connection ,
540
612
ai_client = ai_client ,
@@ -625,7 +697,9 @@ async def callback(chat_protocol_message):
625
697
input_data [chat_history_key ] = all_messages
626
698
627
699
response = flow .invoke (input_data ).output
628
- chat_protocol_message ["messages" ].append ({"role" : "assistant" , "content" : response [chat_output_key ]})
700
+ chat_protocol_message ["messages" ].append (
701
+ {"role" : "assistant" , "content" : response [chat_output_key ]}
702
+ )
629
703
630
704
return chat_protocol_message
631
705
@@ -657,8 +731,12 @@ def create_template(
657
731
One of 'template' or 'template_path' must be provided to create a template. If 'template' is provided,
658
732
it is used directly; if 'template_path' is provided, the content is read from the file at that path.
659
733
"""
660
- if (template is None and template_path is None ) or (template is not None and template_path is not None ):
661
- raise ValueError ("One and only one of the parameters [template, template_path] has to be set." )
734
+ if (template is None and template_path is None ) or (
735
+ template is not None and template_path is not None
736
+ ):
737
+ raise ValueError (
738
+ "One and only one of the parameters [template, template_path] has to be set."
739
+ )
662
740
663
741
if template is not None :
664
742
return Template (template_name = name , text = template , context_key = context_key )
@@ -669,7 +747,9 @@ def create_template(
669
747
670
748
return Template (template_name = name , text = tc , context_key = context_key )
671
749
672
- raise ValueError ("Condition not met for creating template, please check examples and parameter list." )
750
+ raise ValueError (
751
+ "Condition not met for creating template, please check examples and parameter list."
752
+ )
673
753
674
754
@staticmethod
675
755
def get_template (template_name : str ):
0 commit comments