@@ -185,47 +185,57 @@ async def _simulate(
185185 :type direct_attack: bool
186186 """
187187
188- ## Define callback
189- async def callback (
190- messages : List [Dict ],
191- stream : bool = False ,
192- session_state : Optional [str ] = None ,
193- context : Optional [Dict ] = None ,
194- ) -> dict :
195- messages_list = messages ["messages" ] # type: ignore
196- latest_message = messages_list [- 1 ]
197- application_input = latest_message ["content" ]
198- context = latest_message .get ("context" , None )
199- latest_context = None
200- try :
201- is_async = self ._is_async_function (target )
202- if self ._check_target_returns_context (target ):
203- if is_async :
204- response , latest_context = await target (query = application_input )
205- else :
206- response , latest_context = target (query = application_input )
207- else :
208- if is_async :
209- response = await target (query = application_input )
188+ ## Check if target is already a callback-style function
189+ if self ._check_target_is_callback (target ):
190+ # Use the target directly as it's already a callback
191+ callback = target
192+ else :
193+ # Define callback wrapper for simple targets
194+ async def callback (
195+ messages : List [Dict ],
196+ stream : bool = False ,
197+ session_state : Optional [str ] = None ,
198+ context : Optional [Dict ] = None ,
199+ ) -> dict :
200+ messages_list = messages ["messages" ] # type: ignore
201+ latest_message = messages_list [- 1 ]
202+ application_input = latest_message ["content" ]
203+ context = latest_message .get ("context" , None )
204+ latest_context = None
205+ try :
206+ is_async = self ._is_async_function (target )
207+ if self ._check_target_returns_context (target ):
208+ if is_async :
209+ response , latest_context = await target (
210+ query = application_input
211+ )
212+ else :
213+ response , latest_context = target (
214+ query = application_input
215+ )
210216 else :
211- response = target (query = application_input )
212- except Exception as e :
213- response = f"Something went wrong { e !s} "
214-
215- ## We format the response to follow the openAI chat protocol format
216- formatted_response = {
217- "content" : response ,
218- "role" : "assistant" ,
219- "context" : latest_context if latest_context else context ,
220- }
221- ## NOTE: In the future, instead of appending to messages we should just return `formatted_response`
222- messages ["messages" ].append (formatted_response ) # type: ignore
223- return {
224- "messages" : messages_list ,
225- "stream" : stream ,
226- "session_state" : session_state ,
227- "context" : latest_context if latest_context else context ,
228- }
217+ if is_async :
218+ response = await target (query = application_input )
219+ else :
220+ response = target (query = application_input )
221+ except Exception as e :
222+ response = f"Something went wrong { e !s} "
223+
224+ ## We format the response to follow the openAI chat protocol
225+ formatted_response = {
226+ "content" : response ,
227+ "role" : "assistant" ,
228+ "context" : latest_context if latest_context else context ,
229+ }
230+ ## NOTE: In the future, instead of appending to messages we
231+ ## should just return `formatted_response`
232+ messages ["messages" ].append (formatted_response ) # type: ignore
233+ return {
234+ "messages" : messages_list ,
235+ "stream" : stream ,
236+ "session_state" : session_state ,
237+ "context" : latest_context if latest_context else context ,
238+ }
229239
230240 ## Run simulator
231241 simulator = None
@@ -564,7 +574,7 @@ def _is_async_function(target: Callable) -> bool:
564574 def _check_target_is_callback (target : Callable ) -> bool :
565575 sig = inspect .signature (target )
566576 param_names = list (sig .parameters .keys ())
567- return 'messages' in param_names and 'stream' in param_names and ' session_state' in param_names and 'context' in param_names
577+ return 'messages' in param_names and 'session_state' in param_names and 'context' in param_names
568578
569579 def _validate_inputs (
570580 self ,
@@ -589,9 +599,26 @@ def _validate_inputs(
589599 """
590600 if not callable (target ):
591601 self ._validate_model_config (target )
592- elif not self ._check_target_returns_str (target ):
593- self .logger .error (f"Target function { target } does not return a string." )
594- msg = f"Target function { target } does not return a string."
602+ elif (not self ._check_target_is_callback (target ) and
603+ not self ._check_target_returns_str (target )):
604+ msg = (
605+ f"Invalid target function signature. The target function must be either:\n \n "
606+ f"1. A simple function that takes a 'query' parameter and returns a string:\n "
607+ f" def my_target(query: str) -> str:\n "
608+ f" return f'Response to: {{query}}'\n \n "
609+ f"2. A callback-style function with these exact parameters:\n "
610+ f" async def my_callback(\n "
611+ f" messages: List[Dict],\n "
612+ f" stream: bool = False,\n "
613+ f" session_state: Any = None,\n "
614+ f" context: Any = None\n "
615+ f" ) -> dict:\n "
616+ f" # Process messages and return dict with 'messages', 'stream', 'session_state', 'context'\n "
617+ f" return {{'messages': messages['messages'], 'stream': stream, 'session_state': session_state, 'context': context}}\n \n "
618+ f"Your function '{ target .__name__ } ' does not match either pattern. "
619+ f"Please check the function signature and return type."
620+ )
621+ self .logger .error (msg )
595622 raise EvaluationException (
596623 message = msg ,
597624 internal_message = msg ,
0 commit comments