@@ -185,47 +185,57 @@ async def _simulate(
185
185
:type direct_attack: bool
186
186
"""
187
187
188
- ## Define callback
189
- async def callback (
190
- messages : List [Dict ],
191
- stream : bool = False ,
192
- session_state : Optional [str ] = None ,
193
- context : Optional [Dict ] = None ,
194
- ) -> dict :
195
- messages_list = messages ["messages" ] # type: ignore
196
- latest_message = messages_list [- 1 ]
197
- application_input = latest_message ["content" ]
198
- context = latest_message .get ("context" , None )
199
- latest_context = None
200
- try :
201
- is_async = self ._is_async_function (target )
202
- if self ._check_target_returns_context (target ):
203
- if is_async :
204
- response , latest_context = await target (query = application_input )
205
- else :
206
- response , latest_context = target (query = application_input )
207
- else :
208
- if is_async :
209
- response = await target (query = application_input )
188
+ ## Check if target is already a callback-style function
189
+ if self ._check_target_is_callback (target ):
190
+ # Use the target directly as it's already a callback
191
+ callback = target
192
+ else :
193
+ # Define callback wrapper for simple targets
194
+ async def callback (
195
+ messages : List [Dict ],
196
+ stream : bool = False ,
197
+ session_state : Optional [str ] = None ,
198
+ context : Optional [Dict ] = None ,
199
+ ) -> dict :
200
+ messages_list = messages ["messages" ] # type: ignore
201
+ latest_message = messages_list [- 1 ]
202
+ application_input = latest_message ["content" ]
203
+ context = latest_message .get ("context" , None )
204
+ latest_context = None
205
+ try :
206
+ is_async = self ._is_async_function (target )
207
+ if self ._check_target_returns_context (target ):
208
+ if is_async :
209
+ response , latest_context = await target (
210
+ query = application_input
211
+ )
212
+ else :
213
+ response , latest_context = target (
214
+ query = application_input
215
+ )
210
216
else :
211
- response = target (query = application_input )
212
- except Exception as e :
213
- response = f"Something went wrong { e !s} "
214
-
215
- ## We format the response to follow the openAI chat protocol format
216
- formatted_response = {
217
- "content" : response ,
218
- "role" : "assistant" ,
219
- "context" : latest_context if latest_context else context ,
220
- }
221
- ## NOTE: In the future, instead of appending to messages we should just return `formatted_response`
222
- messages ["messages" ].append (formatted_response ) # type: ignore
223
- return {
224
- "messages" : messages_list ,
225
- "stream" : stream ,
226
- "session_state" : session_state ,
227
- "context" : latest_context if latest_context else context ,
228
- }
217
+ if is_async :
218
+ response = await target (query = application_input )
219
+ else :
220
+ response = target (query = application_input )
221
+ except Exception as e :
222
+ response = f"Something went wrong { e !s} "
223
+
224
+ ## We format the response to follow the openAI chat protocol
225
+ formatted_response = {
226
+ "content" : response ,
227
+ "role" : "assistant" ,
228
+ "context" : latest_context if latest_context else context ,
229
+ }
230
+ ## NOTE: In the future, instead of appending to messages we
231
+ ## should just return `formatted_response`
232
+ messages ["messages" ].append (formatted_response ) # type: ignore
233
+ return {
234
+ "messages" : messages_list ,
235
+ "stream" : stream ,
236
+ "session_state" : session_state ,
237
+ "context" : latest_context if latest_context else context ,
238
+ }
229
239
230
240
## Run simulator
231
241
simulator = None
@@ -564,7 +574,7 @@ def _is_async_function(target: Callable) -> bool:
564
574
def _check_target_is_callback (target : Callable ) -> bool :
565
575
sig = inspect .signature (target )
566
576
param_names = list (sig .parameters .keys ())
567
- return 'messages' in param_names and 'stream' in param_names and ' session_state' in param_names and 'context' in param_names
577
+ return 'messages' in param_names and 'session_state' in param_names and 'context' in param_names
568
578
569
579
def _validate_inputs (
570
580
self ,
@@ -589,9 +599,26 @@ def _validate_inputs(
589
599
"""
590
600
if not callable (target ):
591
601
self ._validate_model_config (target )
592
- elif not self ._check_target_returns_str (target ):
593
- self .logger .error (f"Target function { target } does not return a string." )
594
- msg = f"Target function { target } does not return a string."
602
+ elif (not self ._check_target_is_callback (target ) and
603
+ not self ._check_target_returns_str (target )):
604
+ msg = (
605
+ f"Invalid target function signature. The target function must be either:\n \n "
606
+ f"1. A simple function that takes a 'query' parameter and returns a string:\n "
607
+ f" def my_target(query: str) -> str:\n "
608
+ f" return f'Response to: {{query}}'\n \n "
609
+ f"2. A callback-style function with these exact parameters:\n "
610
+ f" async def my_callback(\n "
611
+ f" messages: List[Dict],\n "
612
+ f" stream: bool = False,\n "
613
+ f" session_state: Any = None,\n "
614
+ f" context: Any = None\n "
615
+ f" ) -> dict:\n "
616
+ f" # Process messages and return dict with 'messages', 'stream', 'session_state', 'context'\n "
617
+ f" return {{'messages': messages['messages'], 'stream': stream, 'session_state': session_state, 'context': context}}\n \n "
618
+ f"Your function '{ target .__name__ } ' does not match either pattern. "
619
+ f"Please check the function signature and return type."
620
+ )
621
+ self .logger .error (msg )
595
622
raise EvaluationException (
596
623
message = msg ,
597
624
internal_message = msg ,
0 commit comments