Skip to content

Commit b5c5845

Browse files
authored
Simulator update to current spec. (#34261)
* Simulator update * Fix spelling
1 parent a482344 commit b5c5845

File tree

17 files changed

+1285
-190
lines changed

17 files changed

+1285
-190
lines changed

sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,5 @@
55
_template_dir = os.path.join(os.path.dirname(__file__), 'templates')
66

77
from .simulator.simulator import Simulator
8-
from .templates.simulator_templates import SimulatorTemplates
98

10-
__all__ = ["Simulator", "SimulatorTemplates"]
9+
__all__ = ["Simulator"]

sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_conversation/conversation.py

Lines changed: 39 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,26 @@
1212
from .constants import ConversationRole
1313

1414

15-
def is_closing_message(response: str):
15+
def is_closing_message(response: Any, recursion_depth: int = 0):
    """Return True if *response* contains a closing (goodbye-style) message.

    :param response: either a plain string, or an arbitrarily nested
        JSON-style dict whose values are searched recursively. Any other
        type yields False.
    :param recursion_depth: internal nesting counter; callers leave it at 0.
    :raises RecursionError: when nesting exceeds 10 levels, guarding
        against cyclic or pathological structures.
    """
    if recursion_depth > 10:
        # RecursionError is still an Exception subclass, so existing
        # `except Exception` callers are unaffected.
        raise RecursionError("Exceeded max call depth in is_closing_message")

    # Recursively search every value of a JSON-style dict; one closing
    # message anywhere is enough.
    if isinstance(response, dict):
        return any(
            is_closing_message(value, recursion_depth=recursion_depth + 1)
            for value in response.values()
        )
    if isinstance(response, str):
        return is_closing_message_helper(response)

    return False
28+
29+
def is_closing_message_helper(response: str):
1630
message = response.lower()
1731
if "?" in message.lower():
1832
return False
19-
punctuation = [".", ",", "!", ";", ":"]
20-
for p in punctuation:
33+
punc = [".", ",", "!", ";", ":"]
34+
for p in punc:
2135
message = message.replace(p, "")
2236
if (
2337
"bye" not in message.lower().split()
@@ -36,9 +50,7 @@ async def simulate_conversation(
3650
history_limit: int = 5,
3751
api_call_delay_sec: float = 0,
3852
logger: logging.Logger = logging.getLogger(__name__),
39-
mlflow_logger: Optional[Any] = None,
40-
template_paramaters: Optional[dict] = None,
41-
simulate_callback: Optional[Callable[[str, Sequence[Union[Dict, ConversationTurn]], Optional[dict]], str]] = None,
53+
mlflow_logger=None,
4254
):
4355
"""
4456
Simulate a conversation between the given bots.
@@ -82,45 +94,30 @@ async def simulate_conversation(
8294
(current_turn < turn_limit)
8395
):
8496
try:
85-
current_character_idx = current_turn % 2
86-
# if there is only one bot, means using customized simulate callback
87-
# in the customer bot turn, instead of using the customer bot, need to invoke the simulate callback
88-
if len(bots) < 2 and current_character_idx == 1:
89-
question = conversation_history[-1].message
90-
# TODO: Fix Bug 2816997
91-
response = await simulate_callback(question, conversation_history, template_paramaters) # type: ignore[misc]
92-
# add the generated response to the list of generated responses
93-
conversation_history.append(
94-
ConversationTurn(
95-
role=ConversationRole.ASSISTANT,
96-
name="ChatBot",
97-
message=response,
98-
))
99-
else:
100-
current_bot = bots[current_character_idx]
101-
# invoke Bot to generate response given the input request
102-
logger.info(f"-- Sending to {current_bot.role.value}")
103-
# pass only the last generated turn without passing the bot name.
104-
response, request, time_taken, full_response = await current_bot.generate_response(
105-
session=session,
106-
conversation_history=conversation_history,
107-
max_history=history_limit,
108-
turn_number=current_turn,
109-
)
110-
# add the generated response to the list of generated responses
111-
conversation_history.append(
112-
ConversationTurn(
113-
role=current_bot.role,
114-
name=current_bot.name,
115-
message=response["samples"][0],
116-
full_response=full_response,
117-
request=request,
118-
))
97+
current_character_idx = current_turn % len(bots)
98+
current_bot = bots[current_character_idx]
99+
# invoke Bot to generate response given the input request
100+
logger.info(f"-- Sending to {current_bot.role.value}")
101+
# pass only the last generated turn without passing the bot name.
102+
response, request, time_taken, full_response = await current_bot.generate_response(
103+
session=session,
104+
conversation_history=conversation_history,
105+
max_history=history_limit,
106+
turn_number=current_turn,
107+
)
119108

120109
# check if conversation id is null, which means conversation starter was used. use id from next turn
121110
if conversation_id is None and 'id' in response:
122111
conversation_id = response["id"]
123-
112+
# add the generated response to the list of generated responses
113+
conversation_history.append(
114+
ConversationTurn(
115+
role=current_bot.role,
116+
name=current_bot.name,
117+
message=response["samples"][0],
118+
full_response=full_response,
119+
request=request,
120+
))
124121
logger.info(f"Last turn: {conversation_history[-1]}")
125122
if mlflow_logger is not None:
126123
logger_tasks.append( # schedule logging but don't get blocked by it
@@ -129,8 +126,7 @@ async def simulate_conversation(
129126
)
130127
)
131128
except Exception as e:
132-
logger.warning(f"Error: {e}")
133-
raise e
129+
logger.warning("Error:" + str(e))
134130
if mlflow_logger is not None:
135131
logger_tasks.append( # schedule logging but don't get blocked by it
136132
asyncio.create_task(mlflow_logger.log_error())

sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_conversation/conversation_bot.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(
3838

3939

4040
self.role = role
41+
self.conversation_template_orig = conversation_template
4142
self.conversation_template: jinja2.Template = jinja2.Template(
4243
conversation_template, undefined=jinja2.StrictUndefined
4344
)
@@ -50,17 +51,25 @@ def __init__(
5051

5152
self.logger = logging.getLogger(repr(self))
5253

54+
self.conversation_starter = None # can either be a dictionary or jinja template
5355
if role == ConversationRole.USER:
5456
if "conversation_starter" in self.persona_template_args:
55-
self.logger.info(
56-
'This simulated bot will use the provided conversation starter '
57-
f'"{repr(self.persona_template_args["conversation_starter"])[:400]}"'
58-
'instead of generating a turn using a LLM'
59-
)
60-
self.conversation_starter = self.persona_template_args["conversation_starter"]
57+
conversation_starter_content = self.persona_template_args["conversation_starter"]
58+
if type(conversation_starter_content) is dict:
59+
self.logger.info(f'This simulated bot will use the provided conversation starter (passed in as dictionary): {conversation_starter_content} instead of generating a turn using a LLM')
60+
self.conversation_starter = conversation_starter_content
61+
else:
62+
self.logger.info(
63+
'This simulated bot will use the provided conversation starter '
64+
f'{repr(conversation_starter_content)[:400]}'
65+
' instead of generating a turn using a LLM'
66+
)
67+
self.conversation_starter = jinja2.Template(
68+
conversation_starter_content, undefined=jinja2.StrictUndefined
69+
)
6170
else:
6271
self.logger.info('This simulated bot will generate the first turn as no conversation starter is provided')
63-
self.conversation_starter = ""
72+
6473

6574

6675
async def generate_response(
@@ -88,11 +97,16 @@ async def generate_response(
8897

8998
# check if this is the first turn and the conversation_starter is not None,
9099
# return the conversations starter rather than generating turn using LLM
91-
if turn_number == 0 and self.conversation_starter is not None and self.conversation_starter != "":
92-
self.logger.info(f"Returning conversation starter: {self.conversation_starter}")
100+
if turn_number == 0 and self.conversation_starter is not None:
101+
# if conversation_starter is a dictionary, pass it into samples as is
102+
if type(self.conversation_starter) is dict:
103+
self.logger.info(f"Returning conversation starter: {self.conversation_starter}")
104+
samples = [self.conversation_starter]
105+
else:
106+
self.logger.info(f"Returning conversation starter: {repr(self.persona_template_args['conversation_starter'])[:400]}")
107+
samples = [self.conversation_starter.render(**self.persona_template_args)] # type: ignore[attr-defined]
93108
time_taken = 0
94109

95-
samples = [self.conversation_starter]
96110
finish_reason = ["stop"]
97111

98112
parsed_response = {
@@ -103,11 +117,15 @@ async def generate_response(
103117
full_response = parsed_response
104118
return parsed_response, {}, time_taken, full_response
105119

106-
prompt = self.conversation_template.render(
107-
conversation_turns=conversation_history[-max_history:],
108-
role=self.role.value,
109-
**self.persona_template_args
110-
)
120+
try:
    # Render the system prompt from the persona template; StrictUndefined
    # makes missing template parameters raise instead of silently emitting
    # empty strings.
    prompt = self.conversation_template.render(
        conversation_turns=conversation_history[-max_history:],
        role=self.role.value,
        **self.persona_template_args
    )
except Exception as exc:
    # The previous bare `except:` dropped into an interactive debugger
    # (`code.interact(local=locals())`) — that hangs any non-interactive
    # run, swallows the real error, and then fails with NameError on
    # `prompt`. Fail loudly with the original cause instead.
    raise ValueError(
        f"Failed to render conversation template: {exc}"
    ) from exc
111129

112130
messages = [{"role": "system", "content": prompt}]
113131

sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/simulator/_model_tools/models.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def get_model_class_from_url(endpoint_url: str):
4040

4141
# ===================== HTTP Retry ======================
4242
class AsyncHTTPClientWithRetry:
43-
def __init__(self, n_retry, retry_timeout, logger):
43+
def __init__(self, n_retry, retry_timeout, logger, retry_options=None):
4444
self.attempts = n_retry
4545
self.logger = logger
4646

@@ -49,14 +49,14 @@ def __init__(self, n_retry, retry_timeout, logger):
4949
trace_config = TraceConfig() # set up request logging
5050
trace_config.on_request_start.append(self.on_request_start)
5151
trace_config.on_request_end.append(self.on_request_end)
52-
53-
retry_options = RandomRetry( # set up retry configuration
54-
statuses=[104, 408, 409, 424, 429, 500, 502,
55-
503, 504], # on which statuses to retry
56-
attempts=n_retry,
57-
min_timeout=retry_timeout,
58-
max_timeout=retry_timeout,
59-
)
52+
if retry_options is None:
53+
retry_options = RandomRetry( # set up retry configuration
54+
statuses=[104, 408, 409, 424, 429, 500, 502,
55+
503, 504], # on which statuses to retry
56+
attempts=n_retry,
57+
min_timeout=retry_timeout,
58+
max_timeout=retry_timeout,
59+
)
6060

6161
self.client = RetryClient(
6262
trace_configs=[trace_config], retry_options=retry_options)
@@ -641,6 +641,7 @@ def _parse_response(self, response_data: dict, request_data: Optional[dict] = No
641641
# https://platform.openai.com/docs/api-reference/chat
642642
samples = []
643643
finish_reason = []
644+
644645
for choice in response_data["choices"]:
645646
if 'message' in choice and 'content' in choice['message']:
646647
samples.append(choice['message']['content'])
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
from azure.ai.generative.synthetic.simulator._model_tools.models import (
5+
AsyncHTTPClientWithRetry,
6+
)
7+
from aiohttp_retry import JitterRetry
8+
import logging
9+
10+
import os
11+
12+
api_url = None
13+
if "rai_svc_url" in os.environ:
14+
api_url = os.environ["rai_svc_url"]
15+
api_url = api_url.rstrip("/")
16+
print(
17+
f"Found rai_svc_url in environment variable, using {api_url} for rai service endpoint."
18+
)
19+
20+
21+
class RAIClient:
    """Thin async client for the Responsible AI (RAI) service endpoints.

    Fetched payloads are cached on the instance, so each accessor hits the
    service at most once per client lifetime.
    """

    def __init__(self, ml_client, token_manager):
        """
        :param ml_client: Azure ML client; supplies workspace coordinates
            and, when the module-level ``api_url`` override is absent, the
            service host (via ``ml_client.jobs._api_url``).
        :param token_manager: async token provider used for Bearer auth.
        """
        self.ml_client = ml_client
        self.token_manager = token_manager

        # Lazily-populated caches (see the get_* accessors below).
        self.contentharm_parameters = None
        self.jailbreaks_dataset = None

        # The rai_svc_url environment override (module-level api_url) wins
        # over the host derived from the ML client.
        host = api_url if api_url is not None else self.ml_client.jobs._api_url

        self.api_url = (
            f"{host}/"
            + f"raisvc/v1.0/subscriptions/{self.ml_client.subscription_id}/"
            + f"resourceGroups/{self.ml_client.resource_group_name}/"
            + f"providers/Microsoft.MachineLearningServices/workspaces/{self.ml_client.workspace_name}/"
        )

        self.parameter_json_endpoint = self.api_url + "simulation/template/parameters"
        self.jailbreaks_json_endpoint = self.api_url + "simulation/jailbreak"
        self.simulation_submit_endpoint = (
            self.api_url + "simulation/chat/completions/submit"
        )

    def _create_async_client(self):
        """Build a retrying aiohttp client (6 attempts, 5s between retries)."""
        return AsyncHTTPClientWithRetry(
            n_retry=6, retry_timeout=5, logger=logging.getLogger()
        )

    async def get_contentharm_parameters(self):
        """Fetch (and cache) the content-harm template parameters."""
        if self.contentharm_parameters is None:
            self.contentharm_parameters = await self.get(self.parameter_json_endpoint)

        return self.contentharm_parameters

    async def get_jailbreaks_dataset(self):
        """Fetch (and cache) the jailbreaks dataset."""
        if self.jailbreaks_dataset is None:
            self.jailbreaks_dataset = await self.get(self.jailbreaks_json_endpoint)

        return self.jailbreaks_dataset

    async def get(self, url):
        """GET *url* with a bearer token and return the parsed JSON body.

        :raises ValueError: when the service responds with a non-200 status.
        """
        token = await self.token_manager.get_token()
        headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }

        async with self._create_async_client().client as session:
            async with session.get(url=url, headers=headers) as response:
                if response.status == 200:
                    # Return directly rather than rebinding `response`,
                    # which shadowed the aiohttp response object.
                    return await response.json()
                # Surface the status code so callers can tell auth failures
                # (401/403) from throttling (429) or outages.
                raise ValueError(
                    "Unable to retrieve requested resource from rai service. "
                    f"Status code: {response.status}"
                )
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
5+
from azure.ai.generative.synthetic.simulator._conversation import (
6+
ConversationBot,
7+
ConversationRole,
8+
ConversationTurn,
9+
simulate_conversation,
10+
)
11+
12+
import copy
13+
from typing import List, Tuple
14+
15+
16+
class CallbackConversationBot(ConversationBot):
    """ConversationBot whose turns come from a user-supplied async callback
    speaking the chat protocol, instead of an LLM call."""

    def __init__(
        self, callback, user_template, user_template_parameters, *args, **kwargs
    ):
        """
        :param callback: async callable that accepts and returns a
            chat-protocol dict (see ``_to_chat_protocol``).
        :param user_template: template forwarded to the chat-protocol payload.
        :param user_template_parameters: parameters forwarded alongside it.
        """
        self.callback = callback
        self.user_template = user_template
        self.user_template_parameters = user_template_parameters

        super().__init__(*args, **kwargs)

    async def generate_response(
        self,
        session: "RetryClient",  # type: ignore[name-defined]
        conversation_history: List[ConversationTurn],
        max_history: int,
        turn_number: int = 0,
    ) -> Tuple[dict, dict, int, dict]:
        """Produce the next turn by invoking the user callback.

        Returns ``(response, request, time_taken, full_response)`` to match
        the ConversationBot.generate_response contract; ``request`` is always
        ``{}`` and ``time_taken`` 0 since no HTTP call is made here.

        :raises TypeError: when the callback's result lacks the expected
            chat-protocol ``messages`` structure.
        """
        chat_protocol_message = self._to_chat_protocol(
            self.user_template, conversation_history, self.user_template_parameters
        )
        # Deep-copy so a misbehaving callback cannot mutate our payload.
        msg_copy = copy.deepcopy(chat_protocol_message)
        result = await self.callback(msg_copy)

        self.logger.info("Using user provided callback returning response.")

        time_taken = 0
        try:
            response = {
                "samples": [result["messages"][-1]["content"]],
                "finish_reason": ["stop"],
                "id": None,
            }
        except (KeyError, IndexError, TypeError) as exc:
            # Narrow catch (the bare `except:` also trapped KeyboardInterrupt)
            # and chain the cause so the original failure isn't lost.
            raise TypeError(
                "User provided callback does not conform to the chat protocol standard."
            ) from exc

        self.logger.info("Parsed callback response")

        return response, {}, time_taken, response

    def _to_chat_protocol(self, template, conversation_history, template_parameters):
        """Convert the ConversationTurn history into a chat-protocol dict."""
        messages = [
            {"content": turn.message, "role": turn.role.value}
            for turn in conversation_history
        ]

        return {
            "template_parameters": template_parameters,
            "messages": messages,
            "$schema": "http://azureml/sdk-2-0/ChatConversation.json",
        }

0 commit comments

Comments
 (0)