Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 44 additions & 16 deletions camel/societies/role_playing.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ class RolePlaying:
stop_event (Optional[threading.Event], optional): Event to signal
termination of the agent's operation. When set, the agent will
terminate its execution. (default: :obj:`None`)
assistant_agent (ChatAgent, optional): A pre-configured ChatAgent to use as the
assistant. If provided, this will override the creation of a new assistant agent.
(default: :obj:`None`)
user_agent (ChatAgent, optional): A pre-configured ChatAgent to use as the
user. If provided, this will override the creation of a new user agent.
(default: :obj:`None`)
"""

def __init__(
Expand All @@ -106,6 +112,8 @@ def __init__(
extend_task_specify_meta_dict: Optional[Dict] = None,
output_language: Optional[str] = None,
stop_event: Optional[threading.Event] = None,
assistant_agent: Optional[ChatAgent] = None,
user_agent: Optional[ChatAgent] = None,
) -> None:
if model is not None:
logger.warning(
Expand Down Expand Up @@ -165,6 +173,8 @@ def __init__(
user_agent_kwargs=user_agent_kwargs,
output_language=output_language,
stop_event=stop_event,
assistant_agent=assistant_agent,
user_agent=user_agent,
)
self.critic: Optional[Union[CriticAgent, Human]] = None
self.critic_sys_msg: Optional[BaseMessage] = None
Expand Down Expand Up @@ -326,6 +336,8 @@ def _init_agents(
user_agent_kwargs: Optional[Dict] = None,
output_language: Optional[str] = None,
stop_event: Optional[threading.Event] = None,
assistant_agent: Optional[ChatAgent] = None,
user_agent: Optional[ChatAgent] = None,
) -> None:
r"""Initialize assistant and user agents with their system messages.

Expand All @@ -343,6 +355,12 @@ def _init_agents(
stop_event (Optional[threading.Event], optional): Event to signal
termination of the agent's operation. When set, the agent will
terminate its execution. (default: :obj:`None`)
assistant_agent (ChatAgent, optional): A pre-configured ChatAgent to use as the
assistant. If provided, this will override the creation of a new assistant agent.
(default: :obj:`None`)
user_agent (ChatAgent, optional): A pre-configured ChatAgent to use as the
user. If provided, this will override the creation of a new user agent.
(default: :obj:`None`)
"""
if self.model is not None:
if assistant_agent_kwargs is None:
Expand All @@ -353,22 +371,32 @@ def _init_agents(
user_agent_kwargs = {'model': self.model}
elif 'model' not in user_agent_kwargs:
user_agent_kwargs.update(dict(model=self.model))

# Use provided assistant agent if available, otherwise create a new one
if assistant_agent is not None:
self.assistant_agent = assistant_agent
self.assistant_sys_msg = assistant_agent.system_message
else:
self.assistant_agent = ChatAgent(
init_assistant_sys_msg,
output_language=output_language,
stop_event=stop_event,
**(assistant_agent_kwargs or {}),
)
self.assistant_sys_msg = self.assistant_agent.system_message

self.assistant_agent = ChatAgent(
init_assistant_sys_msg,
output_language=output_language,
stop_event=stop_event,
**(assistant_agent_kwargs or {}),
)
self.assistant_sys_msg = self.assistant_agent.system_message

self.user_agent = ChatAgent(
init_user_sys_msg,
output_language=output_language,
stop_event=stop_event,
**(user_agent_kwargs or {}),
)
self.user_sys_msg = self.user_agent.system_message
# Use provided user agent if available, otherwise create a new one
if user_agent is not None:
self.user_agent = user_agent
self.user_sys_msg = user_agent.system_message
else:
self.user_agent = ChatAgent(
init_user_sys_msg,
output_language=output_language,
stop_event=stop_event,
**(user_agent_kwargs or {}),
)
self.user_sys_msg = self.user_agent.system_message

def _init_critic(
self,
Expand Down Expand Up @@ -727,4 +755,4 @@ def _is_multi_response(self, agent: ChatAgent) -> bool:
and agent.model_backend.model_config_dict['n'] > 1
):
return True
return False
return False
2 changes: 2 additions & 0 deletions docs/key_modules/societies.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ icon: people-group
<tr><td>extend_sys_msg_meta_dicts</td><td>List[Dict]</td><td>Extra metadata for system messages</td></tr>
<tr><td>extend_task_specify_meta_dict</td><td>Dict</td><td>Extra metadata for task specification</td></tr>
<tr><td>output_language</td><td>str</td><td>Target output language</td></tr>
<tr><td>assistant_agent</td><td>ChatAgent</td><td>Custom ChatAgent to use as assistant (optional)</td></tr>
<tr><td>user_agent</td><td>ChatAgent</td><td>Custom ChatAgent to use as user (optional)</td></tr>
</tbody>
</table>
</Accordion>
Expand Down
Loading