Skip to content

Commit 3607859

Browse files
nagkumar91Nagkumar ArkalgudNagkumar ArkalgudNagkumar Arkalgud
authored
Groundedness detection support in Non adversarial simulator (Azure#38087)
* Update task_query_response.prompty remove required keys * Update task_simulate.prompty * Update task_query_response.prompty * Update task_simulate.prompty * Fix the api_key needed * Update for release * Black fix for file * Add original text in global context * Update test * Update the indirect attack simulator * Black suggested fixes * Update simulator prompty * Update adversarial scenario enum to exclude XPIA * Update changelog * Black fixes * Remove duplicate import * Fix the mypy error * Mypy please be happy * Updates to non adv simulator * accept context from assistant messages, exclude them when using them for conversation * update changelog * pylint fixes * pylint fixes * remove redundant quotes * Fix typo * pylint fix * Update broken tests * Non adv simulator accepts object for conversation starter * Fix warning message being displayed incoorectly * Adding the grounding json in package resource and update changelog * Fix typo * Added the grouding file * Update test, add concurrent async runs to conversation_starter based simualtor * Add grounding json to cspell ignore --------- Co-authored-by: Nagkumar Arkalgud <[email protected]> Co-authored-by: Nagkumar Arkalgud <[email protected]> Co-authored-by: Nagkumar Arkalgud <[email protected]>
1 parent e6590e0 commit 3607859

File tree

6 files changed

+1234
-36
lines changed

6 files changed

+1234
-36
lines changed

sdk/evaluation/azure-ai-evaluation/CHANGELOG.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,24 @@
44
## 1.0.0b5 (Unreleased)
55

66
### Features Added
7+
- Groundedness detection in Non Adversarial Simulator via query/context pairs
8+
```python
9+
import importlib.resources as pkg_resources
10+
package = "azure.ai.evaluation.simulator._data_sources"
11+
resource_name = "grounding.json"
12+
custom_simulator = Simulator(model_config=model_config)
13+
conversation_turns = []
14+
with pkg_resources.path(package, resource_name) as grounding_file:
15+
with open(grounding_file, "r") as file:
16+
data = json.load(file)
17+
for item in data:
18+
conversation_turns.append([item])
19+
outputs = asyncio.run(custom_simulator(
20+
target=callback,
21+
conversation_turns=conversation_turns,
22+
max_conversation_turns=1,
23+
))
24+
```
725

826
### Breaking Changes
927
- Renamed environment variable `PF_EVALS_BATCH_USE_ASYNC` to `AI_EVALS_BATCH_USE_ASYNC`.

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_data_sources/grounding.json

Lines changed: 1150 additions & 0 deletions
Large diffs are not rendered by default.

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_simulator.py

Lines changed: 61 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ async def __call__(
9292
api_call_delay_sec: float = 1,
9393
query_response_generating_prompty_kwargs: Dict[str, Any] = {},
9494
user_simulator_prompty_kwargs: Dict[str, Any] = {},
95-
conversation_turns: List[List[str]] = [],
95+
conversation_turns: List[List[Union[str, Dict[str, Any]]]] = [],
96+
concurrent_async_tasks: int = 5,
9697
**kwargs,
9798
) -> List[JsonLineChatProtocol]:
9899
"""
@@ -119,7 +120,10 @@ async def __call__(
119120
:keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
120121
:paramtype user_simulator_prompty_kwargs: Dict[str, Any]
121122
:keyword conversation_turns: Predefined conversation turns to simulate.
122-
:paramtype conversation_turns: List[List[str]]
123+
:paramtype conversation_turns: List[List[Union[str, Dict[str, Any]]]]
124+
:keyword concurrent_async_tasks: The number of asynchronous tasks to run concurrently during the simulation.
125+
Defaults to 5.
126+
:paramtype concurrent_async_tasks: int
123127
:return: A list of simulated conversations represented as JsonLineChatProtocol objects.
124128
:rtype: List[JsonLineChatProtocol]
125129
@@ -134,12 +138,12 @@ async def __call__(
134138
if conversation_turns and (text or tasks):
135139
raise ValueError("Cannot specify both conversation_turns and text/tasks")
136140

137-
if num_queries > len(tasks):
141+
if text and num_queries > len(tasks):
138142
warnings.warn(
139143
f"You have specified 'num_queries' > len('tasks') ({num_queries} > {len(tasks)}). "
140144
f"All tasks will be used for generation and the remaining {num_queries - len(tasks)} lines will be simulated in task-free mode"
141145
)
142-
elif num_queries < len(tasks):
146+
elif text and num_queries < len(tasks):
143147
warnings.warn(
144148
f"You have specified 'num_queries' < len('tasks') ({num_queries} < {len(tasks)}). "
145149
f"Only the first {num_queries} lines of the specified tasks will be simulated."
@@ -157,6 +161,7 @@ async def __call__(
157161
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
158162
api_call_delay_sec=api_call_delay_sec,
159163
prompty_model_config=prompty_model_config,
164+
concurrent_async_tasks=concurrent_async_tasks,
160165
)
161166

162167
query_responses = await self._generate_query_responses(
@@ -183,11 +188,12 @@ async def _simulate_with_predefined_turns(
183188
*,
184189
target: Callable,
185190
max_conversation_turns: int,
186-
conversation_turns: List[List[str]],
191+
conversation_turns: List[List[Union[str, Dict[str, Any]]]],
187192
user_simulator_prompty: Optional[str],
188193
user_simulator_prompty_kwargs: Dict[str, Any],
189194
api_call_delay_sec: float,
190195
prompty_model_config: Any,
196+
concurrent_async_tasks: int,
191197
) -> List[JsonLineChatProtocol]:
192198
"""
193199
Simulates conversations using predefined conversation turns.
@@ -197,7 +203,7 @@ async def _simulate_with_predefined_turns(
197203
:keyword max_conversation_turns: Maximum number of turns for the simulation.
198204
:paramtype max_conversation_turns: int
199205
:keyword conversation_turns: A list of predefined conversation turns.
200-
:paramtype conversation_turns: List[List[str]]
206+
:paramtype conversation_turns: List[List[Union[str, Dict[str, Any]]]]
201207
:keyword user_simulator_prompty: Path to the user simulator prompty file.
202208
:paramtype user_simulator_prompty: Optional[str]
203209
:keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
@@ -206,53 +212,68 @@ async def _simulate_with_predefined_turns(
206212
:paramtype api_call_delay_sec: float
207213
:keyword prompty_model_config: The configuration for the prompty model.
208214
:paramtype prompty_model_config: Any
215+
:keyword concurrent_async_tasks: The number of asynchronous tasks to run concurrently during the simulation.
216+
:paramtype concurrent_async_tasks: int
209217
:return: A list of simulated conversations represented as JsonLineChatProtocol objects.
210218
:rtype: List[JsonLineChatProtocol]
211219
"""
212-
simulated_conversations = []
213220
progress_bar = tqdm(
214221
total=int(len(conversation_turns) * (max_conversation_turns / 2)),
215222
desc="Simulating with predefined conversation turns: ",
216223
ncols=100,
217224
unit="messages",
218225
)
219-
220-
for simulation in conversation_turns:
221-
current_simulation = ConversationHistory()
222-
for simulated_turn in simulation:
223-
user_turn = Turn(role=ConversationRole.USER, content=simulated_turn)
224-
current_simulation.add_to_history(user_turn)
225-
assistant_response, assistant_context = await self._get_target_response(
226-
target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=current_simulation
227-
)
228-
assistant_turn = Turn(role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context)
229-
current_simulation.add_to_history(assistant_turn)
230-
progress_bar.update(1) # Update progress bar for both user and assistant turns
231-
232-
if len(current_simulation) < max_conversation_turns:
233-
await self._extend_conversation_with_simulator(
234-
current_simulation=current_simulation,
235-
max_conversation_turns=max_conversation_turns,
236-
user_simulator_prompty=user_simulator_prompty,
237-
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
238-
api_call_delay_sec=api_call_delay_sec,
239-
prompty_model_config=prompty_model_config,
240-
target=target,
241-
progress_bar=progress_bar,
242-
)
243-
simulated_conversations.append(
244-
JsonLineChatProtocol(
226+
semaphore = asyncio.Semaphore(concurrent_async_tasks)
227+
progress_bar_lock = asyncio.Lock()
228+
229+
async def run_simulation(simulation: List[Union[str, Dict[str, Any]]]) -> JsonLineChatProtocol:
230+
async with semaphore:
231+
current_simulation = ConversationHistory()
232+
for simulated_turn in simulation:
233+
if isinstance(simulated_turn, str):
234+
user_turn = Turn(role=ConversationRole.USER, content=simulated_turn)
235+
elif isinstance(simulated_turn, dict):
236+
user_turn = Turn(
237+
role=ConversationRole.USER,
238+
content=str(simulated_turn.get("content")),
239+
context=str(simulated_turn.get("context"))
240+
)
241+
else:
242+
raise ValueError("Each simulated turn must be a string or a dict with 'content' and 'context' keys")
243+
current_simulation.add_to_history(user_turn)
244+
assistant_response, assistant_context = await self._get_target_response(
245+
target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=current_simulation
246+
)
247+
assistant_turn = Turn(role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context)
248+
current_simulation.add_to_history(assistant_turn)
249+
async with progress_bar_lock:
250+
progress_bar.update(1)
251+
252+
if len(current_simulation) < max_conversation_turns:
253+
await self._extend_conversation_with_simulator(
254+
current_simulation=current_simulation,
255+
max_conversation_turns=max_conversation_turns,
256+
user_simulator_prompty=user_simulator_prompty,
257+
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
258+
api_call_delay_sec=api_call_delay_sec,
259+
prompty_model_config=prompty_model_config,
260+
target=target,
261+
progress_bar=progress_bar,
262+
progress_bar_lock=progress_bar_lock,
263+
)
264+
return JsonLineChatProtocol(
245265
{
246266
"messages": current_simulation.to_list(),
247267
"finish_reason": ["stop"],
248268
"context": {},
249269
"$schema": "http://azureml/sdk-2-0/ChatConversation.json",
250270
}
251271
)
252-
)
253272

273+
tasks = [asyncio.create_task(run_simulation(simulation)) for simulation in conversation_turns]
274+
results = await asyncio.gather(*tasks)
254275
progress_bar.close()
255-
return simulated_conversations
276+
return results
256277

257278
async def _extend_conversation_with_simulator(
258279
self,
@@ -265,6 +286,7 @@ async def _extend_conversation_with_simulator(
265286
prompty_model_config: Dict[str, Any],
266287
target: Callable,
267288
progress_bar: tqdm,
289+
progress_bar_lock: asyncio.Lock
268290
):
269291
"""
270292
Extends an ongoing conversation using a user simulator until the maximum number of turns is reached.
@@ -285,6 +307,8 @@ async def _extend_conversation_with_simulator(
285307
:paramtype target: Callable,
286308
:keyword progress_bar: Progress bar for tracking simulation progress.
287309
:paramtype progress_bar: tqdm,
310+
:keyword progress_bar_lock: Lock for updating the progress bar safely.
311+
:paramtype progress_bar_lock: asyncio.Lock
288312
"""
289313
user_flow = self._load_user_simulation_flow(
290314
user_simulator_prompty=user_simulator_prompty, # type: ignore
@@ -307,7 +331,8 @@ async def _extend_conversation_with_simulator(
307331
)
308332
assistant_turn = Turn(role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context)
309333
current_simulation.add_to_history(assistant_turn)
310-
progress_bar.update(1)
334+
async with progress_bar_lock:
335+
progress_bar.update(1)
311336

312337
def _load_user_simulation_flow(
313338
self,
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"ignorePaths": ["sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_data_sources/grounding.json"]
3+
}

sdk/evaluation/azure-ai-evaluation/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,5 +85,6 @@
8585
package_data={
8686
"pytyped": ["py.typed"],
8787
"azure.ai.evaluation.simulator._prompty": ["*.prompty"],
88+
"azure.ai.evaluation.simulator._data_sources": ["*.json"],
8889
},
8990
)

sdk/evaluation/azure-ai-evaluation/tests/unittests/test_non_adv_simulator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ async def test_simulate_with_predefined_turns(
330330
prompty_model_config={},
331331
user_simulator_prompty=None,
332332
user_simulator_prompty_kwargs={},
333+
concurrent_async_tasks=1,
333334
)
334335

335336
assert len(result) == 1

0 commit comments

Comments
 (0)