Skip to content

Commit e5d42fc

Browse files
nagkumar91 (Nagkumar Arkalgud)
authored
Context and session_state support for seval (#41488)
* Prepare evals SDK Release
* Fix bug
* Fix for ADV_CONV for FDP projects
* Update release date
* Add support for simulator styled callback for safety eval
* restore changelog
* restore changelog
* Update CHANGELOG.md

---------

Co-authored-by: Nagkumar Arkalgud <[email protected]>
Co-authored-by: Nagkumar Arkalgud <[email protected]>
1 parent 092e4d0 commit e5d42fc

File tree

2 files changed

+73
-44
lines changed

2 files changed

+73
-44
lines changed

sdk/evaluation/azure-ai-evaluation/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## 1.9.0 (Unreleased)
44

5+
### Features Added
6+
57
### Bugs Fixed
68
- Fixed MeteorScoreEvaluator and other threshold-based evaluators returning incorrect binary results due to integer conversion of decimal scores. Previously, decimal scores like 0.9375 were incorrectly converted to integers (0) before threshold comparison, causing them to fail even when above the threshold. [#41415](https://github.com/Azure/azure-sdk-for-python/issues/41415)
79

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py

Lines changed: 71 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -185,47 +185,57 @@ async def _simulate(
185185
:type direct_attack: bool
186186
"""
187187

188-
## Define callback
189-
async def callback(
190-
messages: List[Dict],
191-
stream: bool = False,
192-
session_state: Optional[str] = None,
193-
context: Optional[Dict] = None,
194-
) -> dict:
195-
messages_list = messages["messages"] # type: ignore
196-
latest_message = messages_list[-1]
197-
application_input = latest_message["content"]
198-
context = latest_message.get("context", None)
199-
latest_context = None
200-
try:
201-
is_async = self._is_async_function(target)
202-
if self._check_target_returns_context(target):
203-
if is_async:
204-
response, latest_context = await target(query=application_input)
205-
else:
206-
response, latest_context = target(query=application_input)
207-
else:
208-
if is_async:
209-
response = await target(query=application_input)
188+
## Check if target is already a callback-style function
189+
if self._check_target_is_callback(target):
190+
# Use the target directly as it's already a callback
191+
callback = target
192+
else:
193+
# Define callback wrapper for simple targets
194+
async def callback(
195+
messages: List[Dict],
196+
stream: bool = False,
197+
session_state: Optional[str] = None,
198+
context: Optional[Dict] = None,
199+
) -> dict:
200+
messages_list = messages["messages"] # type: ignore
201+
latest_message = messages_list[-1]
202+
application_input = latest_message["content"]
203+
context = latest_message.get("context", None)
204+
latest_context = None
205+
try:
206+
is_async = self._is_async_function(target)
207+
if self._check_target_returns_context(target):
208+
if is_async:
209+
response, latest_context = await target(
210+
query=application_input
211+
)
212+
else:
213+
response, latest_context = target(
214+
query=application_input
215+
)
210216
else:
211-
response = target(query=application_input)
212-
except Exception as e:
213-
response = f"Something went wrong {e!s}"
214-
215-
## We format the response to follow the openAI chat protocol format
216-
formatted_response = {
217-
"content": response,
218-
"role": "assistant",
219-
"context": latest_context if latest_context else context,
220-
}
221-
## NOTE: In the future, instead of appending to messages we should just return `formatted_response`
222-
messages["messages"].append(formatted_response) # type: ignore
223-
return {
224-
"messages": messages_list,
225-
"stream": stream,
226-
"session_state": session_state,
227-
"context": latest_context if latest_context else context,
228-
}
217+
if is_async:
218+
response = await target(query=application_input)
219+
else:
220+
response = target(query=application_input)
221+
except Exception as e:
222+
response = f"Something went wrong {e!s}"
223+
224+
## We format the response to follow the openAI chat protocol
225+
formatted_response = {
226+
"content": response,
227+
"role": "assistant",
228+
"context": latest_context if latest_context else context,
229+
}
230+
## NOTE: In the future, instead of appending to messages we
231+
## should just return `formatted_response`
232+
messages["messages"].append(formatted_response) # type: ignore
233+
return {
234+
"messages": messages_list,
235+
"stream": stream,
236+
"session_state": session_state,
237+
"context": latest_context if latest_context else context,
238+
}
229239

230240
## Run simulator
231241
simulator = None
@@ -564,7 +574,7 @@ def _is_async_function(target: Callable) -> bool:
564574
def _check_target_is_callback(target: Callable) -> bool:
565575
sig = inspect.signature(target)
566576
param_names = list(sig.parameters.keys())
567-
return 'messages' in param_names and 'stream' in param_names and 'session_state' in param_names and 'context' in param_names
577+
return 'messages' in param_names and 'session_state' in param_names and 'context' in param_names
568578

569579
def _validate_inputs(
570580
self,
@@ -589,9 +599,26 @@ def _validate_inputs(
589599
"""
590600
if not callable(target):
591601
self._validate_model_config(target)
592-
elif not self._check_target_returns_str(target):
593-
self.logger.error(f"Target function {target} does not return a string.")
594-
msg = f"Target function {target} does not return a string."
602+
elif (not self._check_target_is_callback(target) and
603+
not self._check_target_returns_str(target)):
604+
msg = (
605+
f"Invalid target function signature. The target function must be either:\n\n"
606+
f"1. A simple function that takes a 'query' parameter and returns a string:\n"
607+
f" def my_target(query: str) -> str:\n"
608+
f" return f'Response to: {{query}}'\n\n"
609+
f"2. A callback-style function with these exact parameters:\n"
610+
f" async def my_callback(\n"
611+
f" messages: List[Dict],\n"
612+
f" stream: bool = False,\n"
613+
f" session_state: Any = None,\n"
614+
f" context: Any = None\n"
615+
f" ) -> dict:\n"
616+
f" # Process messages and return dict with 'messages', 'stream', 'session_state', 'context'\n"
617+
f" return {{'messages': messages['messages'], 'stream': stream, 'session_state': session_state, 'context': context}}\n\n"
618+
f"Your function '{target.__name__}' does not match either pattern. "
619+
f"Please check the function signature and return type."
620+
)
621+
self.logger.error(msg)
595622
raise EvaluationException(
596623
message=msg,
597624
internal_message=msg,

0 commit comments

Comments (0)