@@ -92,7 +92,8 @@ async def __call__(
9292 api_call_delay_sec : float = 1 ,
9393 query_response_generating_prompty_kwargs : Dict [str , Any ] = {},
9494 user_simulator_prompty_kwargs : Dict [str , Any ] = {},
95- conversation_turns : List [List [str ]] = [],
95+ conversation_turns : List [List [Union [str , Dict [str , Any ]]]] = [],
96+ concurrent_async_tasks : int = 5 ,
9697 ** kwargs ,
9798 ) -> List [JsonLineChatProtocol ]:
9899 """
@@ -119,7 +120,10 @@ async def __call__(
119120 :keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
120121 :paramtype user_simulator_prompty_kwargs: Dict[str, Any]
121122 :keyword conversation_turns: Predefined conversation turns to simulate.
122- :paramtype conversation_turns: List[List[str]]
123+ :paramtype conversation_turns: List[List[Union[str, Dict[str, Any]]]]
124+ :keyword concurrent_async_tasks: The number of asynchronous tasks to run concurrently during the simulation.
125+ Defaults to 5.
126+ :paramtype concurrent_async_tasks: int
123127 :return: A list of simulated conversations represented as JsonLineChatProtocol objects.
124128 :rtype: List[JsonLineChatProtocol]
125129
@@ -134,12 +138,12 @@ async def __call__(
134138 if conversation_turns and (text or tasks ):
135139 raise ValueError ("Cannot specify both conversation_turns and text/tasks" )
136140
137- if num_queries > len (tasks ):
141+ if text and num_queries > len (tasks ):
138142 warnings .warn (
139143 f"You have specified 'num_queries' > len('tasks') ({ num_queries } > { len (tasks )} ). "
140144 f"All tasks will be used for generation and the remaining { num_queries - len (tasks )} lines will be simulated in task-free mode"
141145 )
142- elif num_queries < len (tasks ):
146+ elif text and num_queries < len (tasks ):
143147 warnings .warn (
144148 f"You have specified 'num_queries' < len('tasks') ({ num_queries } < { len (tasks )} ). "
145149 f"Only the first { num_queries } lines of the specified tasks will be simulated."
@@ -157,6 +161,7 @@ async def __call__(
157161 user_simulator_prompty_kwargs = user_simulator_prompty_kwargs ,
158162 api_call_delay_sec = api_call_delay_sec ,
159163 prompty_model_config = prompty_model_config ,
164+ concurrent_async_tasks = concurrent_async_tasks ,
160165 )
161166
162167 query_responses = await self ._generate_query_responses (
@@ -183,11 +188,12 @@ async def _simulate_with_predefined_turns(
183188 * ,
184189 target : Callable ,
185190 max_conversation_turns : int ,
186- conversation_turns : List [List [str ]],
191+ conversation_turns : List [List [Union [ str , Dict [ str , Any ]] ]],
187192 user_simulator_prompty : Optional [str ],
188193 user_simulator_prompty_kwargs : Dict [str , Any ],
189194 api_call_delay_sec : float ,
190195 prompty_model_config : Any ,
196+ concurrent_async_tasks : int ,
191197 ) -> List [JsonLineChatProtocol ]:
192198 """
193199 Simulates conversations using predefined conversation turns.
@@ -197,7 +203,7 @@ async def _simulate_with_predefined_turns(
197203 :keyword max_conversation_turns: Maximum number of turns for the simulation.
198204 :paramtype max_conversation_turns: int
199205 :keyword conversation_turns: A list of predefined conversation turns.
200- :paramtype conversation_turns: List[List[str]]
206+ :paramtype conversation_turns: List[List[Union[ str, Dict[str, Any]] ]]
201207 :keyword user_simulator_prompty: Path to the user simulator prompty file.
202208 :paramtype user_simulator_prompty: Optional[str]
203209 :keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
@@ -206,53 +212,68 @@ async def _simulate_with_predefined_turns(
206212 :paramtype api_call_delay_sec: float
207213 :keyword prompty_model_config: The configuration for the prompty model.
208214 :paramtype prompty_model_config: Any
215+ :keyword concurrent_async_tasks: The number of asynchronous tasks to run concurrently during the simulation.
216+ :paramtype concurrent_async_tasks: int
209217 :return: A list of simulated conversations represented as JsonLineChatProtocol objects.
210218 :rtype: List[JsonLineChatProtocol]
211219 """
212- simulated_conversations = []
213220 progress_bar = tqdm (
214221 total = int (len (conversation_turns ) * (max_conversation_turns / 2 )),
215222 desc = "Simulating with predefined conversation turns: " ,
216223 ncols = 100 ,
217224 unit = "messages" ,
218225 )
219-
220- for simulation in conversation_turns :
221- current_simulation = ConversationHistory ()
222- for simulated_turn in simulation :
223- user_turn = Turn (role = ConversationRole .USER , content = simulated_turn )
224- current_simulation .add_to_history (user_turn )
225- assistant_response , assistant_context = await self ._get_target_response (
226- target = target , api_call_delay_sec = api_call_delay_sec , conversation_history = current_simulation
227- )
228- assistant_turn = Turn (role = ConversationRole .ASSISTANT , content = assistant_response , context = assistant_context )
229- current_simulation .add_to_history (assistant_turn )
230- progress_bar .update (1 ) # Update progress bar for both user and assistant turns
231-
232- if len (current_simulation ) < max_conversation_turns :
233- await self ._extend_conversation_with_simulator (
234- current_simulation = current_simulation ,
235- max_conversation_turns = max_conversation_turns ,
236- user_simulator_prompty = user_simulator_prompty ,
237- user_simulator_prompty_kwargs = user_simulator_prompty_kwargs ,
238- api_call_delay_sec = api_call_delay_sec ,
239- prompty_model_config = prompty_model_config ,
240- target = target ,
241- progress_bar = progress_bar ,
242- )
243- simulated_conversations .append (
244- JsonLineChatProtocol (
226+ semaphore = asyncio .Semaphore (concurrent_async_tasks )
227+ progress_bar_lock = asyncio .Lock ()
228+
async def run_simulation(simulation: List[Union[str, Dict[str, Any]]]) -> JsonLineChatProtocol:
    """
    Simulate one conversation from a single list of predefined user turns.

    Each predefined turn is sent to the target as a user message and the target's
    reply is appended to the history. If the conversation is still shorter than
    ``max_conversation_turns`` afterwards, it is extended with the user simulator.
    The whole simulation runs while holding ``semaphore`` from the enclosing scope.

    :param simulation: The predefined user turns. Each item is either a plain string
        (the message content) or a dict with a ``content`` key and an optional
        ``context`` key.
    :type simulation: List[Union[str, Dict[str, Any]]]
    :return: The simulated conversation.
    :rtype: JsonLineChatProtocol
    :raises ValueError: If a turn is neither a string nor a dict, or if a dict turn
        has no ``content`` value.
    """
    async with semaphore:
        current_simulation = ConversationHistory()
        for simulated_turn in simulation:
            if isinstance(simulated_turn, str):
                user_turn = Turn(role=ConversationRole.USER, content=simulated_turn)
            elif isinstance(simulated_turn, dict):
                turn_content = simulated_turn.get("content")
                if turn_content is None:
                    # str(None) would otherwise send the literal text "None" to the target.
                    raise ValueError("Each simulated turn given as a dict must have a 'content' value")
                turn_context = simulated_turn.get("context")
                if turn_context is None:
                    # No context supplied: build the turn the same way as for a plain string,
                    # instead of storing the literal text "None" as the context.
                    user_turn = Turn(role=ConversationRole.USER, content=str(turn_content))
                else:
                    user_turn = Turn(
                        role=ConversationRole.USER,
                        content=str(turn_content),
                        context=str(turn_context),
                    )
            else:
                raise ValueError("Each simulated turn must be a string or a dict with 'content' and 'context' keys")
            current_simulation.add_to_history(user_turn)
            assistant_response, assistant_context = await self._get_target_response(
                target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=current_simulation
            )
            assistant_turn = Turn(role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context)
            current_simulation.add_to_history(assistant_turn)
            # One progress tick per user/assistant exchange; the lock is shared with the
            # simulator extension step below.
            async with progress_bar_lock:
                progress_bar.update(1)

        if len(current_simulation) < max_conversation_turns:
            await self._extend_conversation_with_simulator(
                current_simulation=current_simulation,
                max_conversation_turns=max_conversation_turns,
                user_simulator_prompty=user_simulator_prompty,
                user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
                api_call_delay_sec=api_call_delay_sec,
                prompty_model_config=prompty_model_config,
                target=target,
                progress_bar=progress_bar,
                progress_bar_lock=progress_bar_lock,
            )
        return JsonLineChatProtocol(
            {
                "messages": current_simulation.to_list(),
                "finish_reason": ["stop"],
                "context": {},
                "$schema": "http://azureml/sdk-2-0/ChatConversation.json",
            }
        )
252- )
253272
273+ tasks = [asyncio .create_task (run_simulation (simulation )) for simulation in conversation_turns ]
274+ results = await asyncio .gather (* tasks )
254275 progress_bar .close ()
255- return simulated_conversations
276+ return results
256277
257278 async def _extend_conversation_with_simulator (
258279 self ,
@@ -265,6 +286,7 @@ async def _extend_conversation_with_simulator(
265286 prompty_model_config : Dict [str , Any ],
266287 target : Callable ,
267288 progress_bar : tqdm ,
289+ progress_bar_lock : asyncio .Lock
268290 ):
269291 """
270292 Extends an ongoing conversation using a user simulator until the maximum number of turns is reached.
@@ -285,6 +307,8 @@ async def _extend_conversation_with_simulator(
285307 :paramtype target: Callable
286308 :keyword progress_bar: Progress bar for tracking simulation progress.
287309 :paramtype progress_bar: tqdm
310+ :keyword progress_bar_lock: Lock for updating the progress bar safely.
311+ :paramtype progress_bar_lock: asyncio.Lock
288312 """
289313 user_flow = self ._load_user_simulation_flow (
290314 user_simulator_prompty = user_simulator_prompty , # type: ignore
@@ -307,7 +331,8 @@ async def _extend_conversation_with_simulator(
307331 )
308332 assistant_turn = Turn (role = ConversationRole .ASSISTANT , content = assistant_response , context = assistant_context )
309333 current_simulation .add_to_history (assistant_turn )
310- progress_bar .update (1 )
334+ async with progress_bar_lock :
335+ progress_bar .update (1 )
311336
312337 def _load_user_simulation_flow (
313338 self ,
0 commit comments