diff --git a/camel/societies/role_playing.py b/camel/societies/role_playing.py index 60b16fb4d0..699ef10a47 100644 --- a/camel/societies/role_playing.py +++ b/camel/societies/role_playing.py @@ -81,6 +81,12 @@ class RolePlaying: stop_event (Optional[threading.Event], optional): Event to signal termination of the agent's operation. When set, the agent will terminate its execution. (default: :obj:`None`) + assistant_agent (ChatAgent, optional): A pre-configured ChatAgent to use as the + assistant. If provided, this will override the creation of a new assistant agent. + (default: :obj:`None`) + user_agent (ChatAgent, optional): A pre-configured ChatAgent to use as the + user. If provided, this will override the creation of a new user agent. + (default: :obj:`None`) """ def __init__( @@ -106,6 +112,8 @@ def __init__( extend_task_specify_meta_dict: Optional[Dict] = None, output_language: Optional[str] = None, stop_event: Optional[threading.Event] = None, + assistant_agent: Optional[ChatAgent] = None, + user_agent: Optional[ChatAgent] = None, ) -> None: if model is not None: logger.warning( @@ -143,21 +151,27 @@ def __init__( **(sys_msg_generator_kwargs or {}), ) - ( - init_assistant_sys_msg, - init_user_sys_msg, - sys_msg_meta_dicts, - ) = self._get_sys_message_info( - assistant_role_name, - user_role_name, - sys_msg_generator, - extend_sys_msg_meta_dicts=extend_sys_msg_meta_dicts, - ) - self.assistant_agent: ChatAgent self.user_agent: ChatAgent self.assistant_sys_msg: Optional[BaseMessage] self.user_sys_msg: Optional[BaseMessage] + + # Do not regenerate sys_msg when using custom agents + if self.assistant_agent is not None and self.user_agent is not None: + ( + init_assistant_sys_msg, + init_user_sys_msg, + sys_msg_meta_dicts, + ) = self._get_sys_message_info( + assistant_role_name, + user_role_name, + sys_msg_generator, + extend_sys_msg_meta_dicts=extend_sys_msg_meta_dicts, + ) + else: + init_assistant_sys_msg = assistant_agent.system_message + init_user_sys_msg = user_agent.system_message + self._init_agents( init_assistant_sys_msg, init_user_sys_msg, @@ -165,6 +179,8 @@ def __init__( user_agent_kwargs=user_agent_kwargs, output_language=output_language, stop_event=stop_event, + assistant_agent=assistant_agent, + user_agent=user_agent, ) self.critic: Optional[Union[CriticAgent, Human]] = None self.critic_sys_msg: Optional[BaseMessage] = None @@ -326,6 +342,8 @@ def _init_agents( user_agent_kwargs: Optional[Dict] = None, output_language: Optional[str] = None, stop_event: Optional[threading.Event] = None, + assistant_agent: Optional[ChatAgent] = None, + user_agent: Optional[ChatAgent] = None, ) -> None: r"""Initialize assistant and user agents with their system messages. @@ -343,6 +361,12 @@ def _init_agents( stop_event (Optional[threading.Event], optional): Event to signal termination of the agent's operation. When set, the agent will terminate its execution. (default: :obj:`None`) + assistant_agent (ChatAgent, optional): A pre-configured ChatAgent to use as the + assistant. If provided, this will override the creation of a new assistant agent. + (default: :obj:`None`) + user_agent (ChatAgent, optional): A pre-configured ChatAgent to use as the + user. If provided, this will override the creation of a new user agent. + (default: :obj:`None`) """ if self.model is not None: if assistant_agent_kwargs is None: @@ -353,22 +377,36 @@ def _init_agents( user_agent_kwargs = {'model': self.model} elif 'model' not in user_agent_kwargs: user_agent_kwargs.update(dict(model=self.model)) + + # Use provided assistant agent if available, otherwise create a new one + if assistant_agent is not None: + # Ensure functionality consistent + assistant_agent.output_language = output_language + assistant_agent.stop_event = stop_event + self.assistant_agent = assistant_agent + + self.assistant_sys_msg = assistant_agent.system_message + else: + self.assistant_agent = ChatAgent( + init_assistant_sys_msg, + output_language=output_language, + stop_event=stop_event, + **(assistant_agent_kwargs or {}), + ) + self.assistant_sys_msg = self.assistant_agent.system_message - self.assistant_agent = ChatAgent( - init_assistant_sys_msg, - output_language=output_language, - stop_event=stop_event, - **(assistant_agent_kwargs or {}), - ) - self.assistant_sys_msg = self.assistant_agent.system_message - - self.user_agent = ChatAgent( - init_user_sys_msg, - output_language=output_language, - stop_event=stop_event, - **(user_agent_kwargs or {}), - ) - self.user_sys_msg = self.user_agent.system_message + # Use provided user agent if available, otherwise create a new one + if user_agent is not None: + self.user_agent = user_agent + self.user_sys_msg = user_agent.system_message + else: + self.user_agent = ChatAgent( + init_user_sys_msg, + output_language=output_language, + stop_event=stop_event, + **(user_agent_kwargs or {}), + ) + self.user_sys_msg = self.user_agent.system_message def _init_critic( self, @@ -727,4 +765,4 @@ def _is_multi_response(self, agent: ChatAgent) -> bool: and agent.model_backend.model_config_dict['n'] > 1 ): return True - return False + return False \ No newline at end of file diff --git a/docs/key_modules/societies.md b/docs/key_modules/societies.md index 8aa79b35b3..9312e8f021 100644 --- a/docs/key_modules/societies.md +++ b/docs/key_modules/societies.md @@ -70,6 +70,8 @@ icon: people-group extend_sys_msg_meta_dictsList[Dict]Extra metadata for system messages extend_task_specify_meta_dictDictExtra metadata for task specification output_languagestrTarget output language + assistant_agentChatAgentCustom ChatAgent to use as assistant (optional) + user_agentChatAgentCustom ChatAgent to use as user (optional)