diff --git a/mirix/agent/agent_wrapper.py b/mirix/agent/agent_wrapper.py index 711eb30..27a735d 100644 --- a/mirix/agent/agent_wrapper.py +++ b/mirix/agent/agent_wrapper.py @@ -95,7 +95,17 @@ def get_image_mime_type(image_path): class AgentWrapper(): - def __init__(self, agent_config_file, load_from=None): + def __init__( + self, + agent_config_file, + load_from=None, + user_id: Optional[str] = None, + org_id: Optional[str] = None, + debug: bool = False, + default_llm_config: Optional[LLMConfig] = None, + default_embedding_config: Optional[EmbeddingConfig] = None, + ): + # If load_from is specified, restore the database first before any agent initialization if load_from is not None: @@ -115,11 +125,13 @@ def __init__(self, agent_config_file, load_from=None): self.logger = logging.getLogger(f"Mirix.AgentWrapper.{self.agent_name}") self.logger.setLevel(logging.INFO) - self.client = create_client() - self.client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) - # self.client.set_default_embedding_config(EmbeddingConfig.default_config("text-embedding-3-small")) - self.client.set_default_embedding_config(EmbeddingConfig.default_config("text-embedding-004")) - + self.client = create_client( + user_id=user_id, + org_id=org_id, + debug=debug, + default_llm_config=default_llm_config or LLMConfig.default_config("gpt-4o-mini"), + default_embedding_config=default_embedding_config or EmbeddingConfig.default_config("text-embedding-004"), + ) # Initialize agent states container self.agent_states = AgentStates() diff --git a/mirix/client/client.py b/mirix/client/client.py index 880eb12..2035351 100644 --- a/mirix/client/client.py +++ b/mirix/client/client.py @@ -44,10 +44,20 @@ from mirix.prompts import gpt_persona -def create_client(): - return LocalClient() - - +def create_client( + user_id: Optional[str] = None, + org_id: Optional[str] = None, + debug: bool = False, + default_llm_config: Optional[LLMConfig] = None, + default_embedding_config: Optional[EmbeddingConfig] = None, +): + return LocalClient( + user_id=user_id, + org_id=org_id, + debug=debug, + default_llm_config=default_llm_config, + default_embedding_config=default_embedding_config, + ) class AbstractClient(object): def __init__( self, diff --git a/mirix/functions/function_sets/memory_tools.py b/mirix/functions/function_sets/memory_tools.py index 2452555..0f2e552 100644 --- a/mirix/functions/function_sets/memory_tools.py +++ b/mirix/functions/function_sets/memory_tools.py @@ -415,7 +415,10 @@ def trigger_memory_update_with_instruction(self: "Agent", user_message: object, from mirix import create_client - client = create_client() + client = create_client( + user_id=self.user.id, + org_id=self.user.organization_id + ) agents = client.list_agents() # Validate that user_message is a dictionary @@ -474,7 +477,10 @@ def trigger_memory_update(self: "Agent", user_message: object, memory_types: Lis from mirix import create_client - client = create_client() + client = create_client( + user_id=self.user.id, + org_id=self.user.organization_id + ) agents = client.list_agents() # Validate that user_message is a dictionary