diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py
index 8df383bfb..9e22aa57f 100644
--- a/src/memos/api/routers/server_router.py
+++ b/src/memos/api/routers/server_router.py
@@ -4,10 +4,9 @@
 import socket
 import time
 import traceback
-
 from collections.abc import Iterable
 from datetime import datetime
-from typing import TYPE_CHECKING, Any
+from typing import Any

 from fastapi import APIRouter, HTTPException
 from fastapi.responses import StreamingResponse
@@ -61,14 +60,9 @@
 )
 from memos.reranker.factory import RerankerFactory
 from memos.templates.instruction_completion import instruct_completion
-
-
-if TYPE_CHECKING:
-    from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler
-from memos.types import MOSSearchResult, UserContext
+from memos.types import UserContext
 from memos.vec_dbs.factory import VecDBFactory
-

 logger = get_logger(__name__)

 router = APIRouter(prefix="/product", tags=["Server API"])
@@ -81,10 +75,10 @@ def _to_iter(running: Any) -> Iterable:
         return []
     if isinstance(running, dict):
         return running.values()
-    return running  # assume it's already an iterable (e.g., list)
+    return running


-def _build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
+def _build_graph_db_config(user_id: str = "default"):
     """Build graph database configuration."""
     graph_db_backend_map = {
         "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id),
@@ -102,63 +96,6 @@ def _build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
     )


-def _build_vec_db_config() -> dict[str, Any]:
-    """Build vector database configuration."""
-    return VectorDBConfigFactory.model_validate(
-        {
-            "backend": "milvus",
-            "config": APIConfig.get_milvus_config(),
-        }
-    )
-
-
-def _build_llm_config() -> dict[str, Any]:
-    """Build LLM configuration."""
-    return LLMConfigFactory.model_validate(
-        {
-            "backend": "openai",
-            "config": APIConfig.get_openai_config(),
-        }
-    )
-
-
-def _build_embedder_config() -> dict[str, Any]:
-    """Build embedder configuration."""
-    return EmbedderConfigFactory.model_validate(APIConfig.get_embedder_config())
-
-
-def _build_mem_reader_config() -> dict[str, Any]:
-    """Build memory reader configuration."""
-    return MemReaderConfigFactory.model_validate(
-        APIConfig.get_product_default_config()["mem_reader"]
-    )
-
-
-def _build_reranker_config() -> dict[str, Any]:
-    """Build reranker configuration."""
-    return RerankerConfigFactory.model_validate(APIConfig.get_reranker_config())
-
-
-def _build_internet_retriever_config() -> dict[str, Any]:
-    """Build internet retriever configuration."""
-    return InternetRetrieverConfigFactory.model_validate(APIConfig.get_internet_config())
-
-
-def _build_pref_extractor_config() -> dict[str, Any]:
-    """Build extractor configuration."""
-    return ExtractorConfigFactory.model_validate({"backend": "naive", "config": {}})
-
-
-def _build_pref_adder_config() -> dict[str, Any]:
-    """Build adder configuration."""
-    return AdderConfigFactory.model_validate({"backend": "naive", "config": {}})
-
-
-def _build_pref_retriever_config() -> dict[str, Any]:
-    """Build retriever configuration."""
-    return RetrieverConfigFactory.model_validate({"backend": "naive", "config": {}})
-
-
 def _get_default_memory_size(cube_config) -> dict[str, int]:
     """Get default memory size configuration."""
     return getattr(cube_config.text_mem.config, "memory_size", None) or {
@@ -175,15 +112,27 @@ def init_server():
     # Build component configurations
     graph_db_config = _build_graph_db_config()
-    llm_config = _build_llm_config()
-    embedder_config = _build_embedder_config()
-    mem_reader_config = _build_mem_reader_config()
-    reranker_config = _build_reranker_config()
-    internet_retriever_config = _build_internet_retriever_config()
-    vector_db_config = _build_vec_db_config()
-    pref_extractor_config = _build_pref_extractor_config()
-    pref_adder_config = _build_pref_adder_config()
-    pref_retriever_config = _build_pref_retriever_config()
+    llm_config = LLMConfigFactory.model_validate(
+        {"backend": "openai", "config": APIConfig.get_openai_config()}
+    )
+    embedder_config = EmbedderConfigFactory.model_validate(APIConfig.get_embedder_config())
+    mem_reader_config = MemReaderConfigFactory.model_validate(
+        APIConfig.get_product_default_config()["mem_reader"]
+    )
+    reranker_config = RerankerConfigFactory.model_validate(APIConfig.get_reranker_config())
+    internet_retriever_config = InternetRetrieverConfigFactory.model_validate(
+        APIConfig.get_internet_config()
+    )
+    vector_db_config = VectorDBConfigFactory.model_validate(
+        {"backend": "milvus", "config": APIConfig.get_milvus_config()}
+    )
+    pref_extractor_config = ExtractorConfigFactory.model_validate(
+        {"backend": "naive", "config": {}}
+    )
+    pref_adder_config = AdderConfigFactory.model_validate({"backend": "naive", "config": {}})
+    pref_retriever_config = RetrieverConfigFactory.model_validate(
+        {"backend": "naive", "config": {}}
+    )

     # Create component instances
     graph_db = GraphStoreFactory.from_config(graph_db_config)
@@ -249,7 +198,7 @@ def init_server():
     scheduler_config = SchedulerConfigFactory(
         backend="optimized_scheduler", config=scheduler_config_dict
     )
-    mem_scheduler: OptimizedScheduler = SchedulerFactory.from_config(scheduler_config)
+    mem_scheduler = SchedulerFactory.from_config(scheduler_config)
     mem_scheduler.initialize_modules(
         chat_llm=llm,
         process_llm=mem_reader.llm,
@@ -320,97 +269,38 @@ def _format_memory_item(memory_data: Any) -> dict[str, Any]:
     return memory


-def _post_process_pref_mem(
-    memories_result: list[dict[str, Any]],
-    pref_formatted_mem: list[dict[str, Any]],
-    mem_cube_id: str,
-    include_preference: bool,
-):
-    if include_preference:
-        memories_result["pref_mem"].append(
-            {
-                "cube_id": mem_cube_id,
-                "memories": pref_formatted_mem,
-            }
-        )
-        pref_instruction, pref_note = instruct_completion(pref_formatted_mem)
-        memories_result["pref_string"] = pref_instruction
-        memories_result["pref_note"] = pref_note
-
-    return memories_result
-
-
 @router.post("/search", summary="Search memories", response_model=SearchResponse)
 def search_memories(search_req: APISearchRequest):
     """Search memories for a specific user."""
-    # Create UserContext object - how to assign values
-    user_context = UserContext(
-        user_id=search_req.user_id,
-        mem_cube_id=search_req.mem_cube_id,
-        session_id=search_req.session_id or "default_session",
-    )
     logger.info(f"Search Req is: {search_req}")
-    memories_result: MOSSearchResult = {
-        "text_mem": [],
-        "act_mem": [],
-        "para_mem": [],
-        "pref_mem": [],
-        "pref_note": "",
-    }
-
-    search_mode = search_req.mode
-
-    def _search_text():
-        if search_mode == SearchMode.FAST:
-            formatted_memories = fast_search_memories(
-                search_req=search_req, user_context=user_context
-            )
-        elif search_mode == SearchMode.FINE:
-            formatted_memories = fine_search_memories(
-                search_req=search_req, user_context=user_context
-            )
-        elif search_mode == SearchMode.MIXTURE:
-            formatted_memories = mix_search_memories(
-                search_req=search_req, user_context=user_context
-            )
-        else:
-            logger.error(f"Unsupported search mode: {search_mode}")
-            raise HTTPException(status_code=400, detail=f"Unsupported search mode: {search_mode}")
-        return formatted_memories
-
-    def _search_pref():
-        if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true":
-            return []
-        results = naive_mem_cube.pref_mem.search(
-            query=search_req.query,
-            top_k=search_req.pref_top_k,
-            info={
-                "user_id": search_req.user_id,
-                "session_id": search_req.session_id,
-                "chat_history": search_req.chat_history,
-            },
-        )
-        return [_format_memory_item(data) for data in results]

     with ContextThreadPoolExecutor(max_workers=2) as executor:
-        text_future = executor.submit(_search_text)
-        pref_future = executor.submit(_search_pref)
+        text_future = executor.submit(_search_text_mem, search_req)
+        pref_future = executor.submit(_search_pref_mem, search_req)
         text_formatted_memories = text_future.result()
         pref_formatted_memories = pref_future.result()

-    memories_result["text_mem"].append(
-        {
-            "cube_id": search_req.mem_cube_id,
-            "memories": text_formatted_memories,
-        }
-    )
+    text_mem = [{"cube_id": search_req.mem_cube_id, "memories": text_formatted_memories}]
+    act_mem = []
+    para_mem = []
+    pref_mem = []
+    if search_req.include_preference:
+        pref_instruction, pref_note = instruct_completion(pref_formatted_memories)
+        pref_mem = [
+            {
+                "cube_id": search_req.mem_cube_id,
+                "memories": pref_formatted_memories,
+                "pref_note": pref_note,
+                "pref_string": pref_instruction,
+            }
+        ]

-    memories_result = _post_process_pref_mem(
-        memories_result,
-        pref_formatted_memories,
-        search_req.mem_cube_id,
-        search_req.include_preference,
-    )
+    memories_result = {
+        "text_mem": text_mem,
+        "act_mem": act_mem,
+        "para_mem": para_mem,
+        "pref_mem": pref_mem,
+    }

     logger.info(f"Search memories result: {memories_result}")

@@ -420,89 +310,65 @@ def _search_pref():
     )


-def mix_search_memories(
-    search_req: APISearchRequest,
-    user_context: UserContext,
-):
-    """
-    Mix search memories: fast search + async fine search
-    """
-
-    formatted_memories = mem_scheduler.mix_search_memories(
-        search_req=search_req,
-        user_context=user_context,
-    )
-    return formatted_memories
-
-
-def fine_search_memories(
-    search_req: APISearchRequest,
-    user_context: UserContext,
-):
-    target_session_id = search_req.session_id
-    if not target_session_id:
-        target_session_id = "default_session"
-    search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
-
-    # Create MemCube and perform search
-    search_results = naive_mem_cube.text_mem.search(
-        query=search_req.query,
-        user_name=user_context.mem_cube_id,
-        top_k=search_req.top_k,
-        mode=SearchMode.FINE,
-        manual_close_internet=not search_req.internet_search,
-        moscube=search_req.moscube,
-        search_filter=search_filter,
-        info={
-            "user_id": search_req.user_id,
-            "session_id": target_session_id,
-            "chat_history": search_req.chat_history,
-        },
+def _search_text_mem(search_req: APISearchRequest):
+    search_mode = search_req.mode
+    user_context = UserContext(
+        user_id=search_req.user_id,
+        mem_cube_id=search_req.mem_cube_id,
+        session_id=search_req.session_id or "default_session",
     )
-    formatted_memories = [_format_memory_item(data) for data in search_results]
-
+    if search_mode in [SearchMode.FAST, SearchMode.FINE]:
+        target_session_id = search_req.session_id
+        if not target_session_id:
+            target_session_id = "default_session"
+        search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
+        search_results = naive_mem_cube.text_mem.search(
+            query=search_req.query,
+            user_name=user_context.mem_cube_id,
+            top_k=search_req.top_k,
+            mode=search_mode,
+            manual_close_internet=not search_req.internet_search,
+            moscube=search_req.moscube,
+            search_filter=search_filter,
+            info={
+                "user_id": search_req.user_id,
+                "session_id": target_session_id,
+                "chat_history": search_req.chat_history,
+            },
+        )
+        formatted_memories = [_format_memory_item(data) for data in search_results]
+    elif search_mode == SearchMode.MIXTURE:
+        formatted_memories = mem_scheduler.mix_search_memories(
+            search_req=search_req,
+            user_context=user_context,
+        )
+    else:
+        logger.error(f"Unsupported search mode: {search_mode}")
+        raise HTTPException(status_code=400, detail=f"Unsupported search mode: {search_mode}")
     return formatted_memories


-def fast_search_memories(
-    search_req: APISearchRequest,
-    user_context: UserContext,
-):
-    target_session_id = search_req.session_id
-    if not target_session_id:
-        target_session_id = "default_session"
-    search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
-
-    # Create MemCube and perform search
-    search_results = naive_mem_cube.text_mem.search(
+def _search_pref_mem(search_req):
+    if (
+        os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true"
+        or not search_req.include_preference
+    ):
+        return []
+    results = naive_mem_cube.pref_mem.search(
         query=search_req.query,
-        user_name=user_context.mem_cube_id,
-        top_k=search_req.top_k,
-        mode=SearchMode.FAST,
-        manual_close_internet=not search_req.internet_search,
-        moscube=search_req.moscube,
-        search_filter=search_filter,
+        top_k=search_req.pref_top_k,
         info={
             "user_id": search_req.user_id,
-            "session_id": target_session_id,
+            "session_id": search_req.session_id,
             "chat_history": search_req.chat_history,
         },
     )
-    formatted_memories = [_format_memory_item(data) for data in search_results]
-
-    return formatted_memories
+    return [_format_memory_item(data) for data in results]


 @router.post("/add", summary="Add memories", response_model=MemoryResponse)
 def add_memories(add_req: APIADDRequest):
     """Add memories for a specific user."""
-    # Create UserContext object - how to assign values
-    user_context = UserContext(
-        user_id=add_req.user_id,
-        mem_cube_id=add_req.mem_cube_id,
-        session_id=add_req.session_id or "default_session",
-    )
-
     logger.info(f"Add Req is: {add_req}")

     target_session_id = add_req.session_id
@@ -530,7 +396,7 @@ def _process_text_mem() -> list[dict[str, str]]:
             logger.info(f"Memory extraction completed for user {add_req.user_id}")
             mem_ids_local: list[str] = naive_mem_cube.text_mem.add(
                 flattened_local,
-                user_name=user_context.mem_cube_id,
+                user_name=add_req.mem_cube_id,
             )
             logger.info(
                 f"Added {len(mem_ids_local)} memories for user {add_req.user_id} "
diff --git a/src/memos/mem_reader/memory.py b/src/memos/mem_reader/memory.py
deleted file mode 100644
index 51440e888..000000000
--- a/src/memos/mem_reader/memory.py
+++ /dev/null
@@ -1,298 +0,0 @@
-from datetime import datetime
-from typing import Any
-
-from memos.llms.base import BaseLLM
-
-
-class Memory:
-    """Class representing the memory structure for storing and organizing memory content."""
-
-    def __init__(
-        self,
-        user_id: str,
-        session_id: str,
-        created_at: datetime,
-    ):
-        """
-        Initialize the Memory structure.
-
-        Args:
-            user_id: User identifier
-            session_id: Session identifier
-            created_at: Creation timestamp
-        """
-        self.objective_memory: dict[str, dict[str, Any]] = {}
-        self.subjective_memory: dict[str, dict[str, Any]] = {}
-        self.scene_memory = {
-            "qa_pair": {
-                "section": [],
-                "info": {
-                    "user_id": user_id,
-                    "session_id": session_id,
-                    "created_at": created_at,
-                    "summary": "",
-                    "label": [],
-                },
-            },
-            "document": {
-                "section": [],
-                "info": {
-                    "user_id": user_id,
-                    "session_id": session_id,
-                    "created_at": created_at,
-                    "doc_type": "",  # pdf, txt, etc.
-                    "doc_category": "",  # research_paper, news, etc.
-                    "doc_name": "",
-                    "summary": "",
-                    "label": [],
-                },
-            },
-        }
-
-    def to_dict(self) -> dict[str, Any]:
-        """
-        Convert the Memory object to a dictionary.
-
-        Returns:
-            Dictionary representation of the Memory object
-        """
-        return {
-            "objective_memory": self.objective_memory,
-            "subjective_memory": self.subjective_memory,
-            "scene_memory": self.scene_memory,
-        }
-
-    def update_user_memory(
-        self,
-        memory_type: str,
-        key: str,
-        value: Any,
-        origin_data: str,
-        confidence_score: float = 1.0,
-        timestamp: str | None = None,
-    ) -> None:
-        """
-        Update a memory item in either objective_memory or subjective_memory.
-        If a key already exists, the new memory item's info will replace the existing one,
-        and the values will be connected.
-
-        Args:
-            memory_type: Type of memory to update ('objective' or 'subjective')
-            key: Key for the memory item. Must be one of:
-
-                | Memory Type       | Key                 | Description                                        |
-                |-------------------|---------------------|----------------------------------------------------|
-                | objective_memory  | nickname            | User's preferred name or alias                     |
-                | objective_memory  | gender              | User's gender (male, female, other)                |
-                | objective_memory  | personality         | User's personality traits or MBTI type             |
-                | objective_memory  | birth               | User's birthdate or age information                |
-                | objective_memory  | education           | User's educational background                      |
-                | objective_memory  | work                | User's professional history                        |
-                | objective_memory  | achievement         | User's notable accomplishments                     |
-                | objective_memory  | occupation          | User's current job or role                         |
-                | objective_memory  | residence           | User's home location or living situation           |
-                | objective_memory  | location            | User's current geographical location               |
-                | objective_memory  | income              | User's financial information                       |
-                | objective_memory  | preference          | User's likes and dislikes                          |
-                | objective_memory  | expertise           | User's skills and knowledge areas                  |
-                | objective_memory  | language            | User's language proficiency                        |
-                | objective_memory  | hobby               | User's recreational activities                     |
-                | objective_memory  | goal                | User's long-term aspirations                       |
-                |-------------------|---------------------|----------------------------------------------------|
-                | subjective_memory | current_mood        | User's current emotional state                     |
-                | subjective_memory | response_style      | User's preferred interaction style                 |
-                | subjective_memory | language_style      | User's language patterns and preferences           |
-                | subjective_memory | information_density | User's preference for detail level in responses    |
-                | subjective_memory | interaction_pace    | User's preferred conversation speed and frequency  |
-                | subjective_memory | followed_topic      | Topics the user is currently interested in         |
-                | subjective_memory | current_goal        | User's immediate objectives in the conversation    |
-                | subjective_memory | content_type        | User's preferred field of interest (e.g., technology, finance, etc.) |
-                | subjective_memory | role_preference     | User's preferred assistant role (e.g., domain expert, translation assistant, etc.) |
-
-            value: Value to store
-            origin_data: Original data that led to this memory
-            confidence_score: Confidence score (0.0 to 1.0)
-            timestamp: Timestamp string, if None current time will be used
-        """
-        if timestamp is None:
-            timestamp = datetime.now()
-
-        memory_item = {
-            "value": value,
-            "info": {
-                "timestamp": timestamp,
-                "confidence_score": confidence_score,
-                "origin_data": origin_data,
-            },
-        }
-
-        if memory_type == "objective":
-            memory_dict = self.objective_memory
-        elif memory_type == "subjective":
-            memory_dict = self.subjective_memory
-        else:
-            raise ValueError(
-                f"Invalid memory_type: {memory_type}. Must be 'objective' or 'subjective'."
-            )
-
-        # Check if key already exists
-        if key in memory_dict:
-            existing_item = memory_dict[key]
-
-            # Connect the values (keep history but present as a connected string)
-            combined_value = f"{existing_item['value']} | {value}"
-
-            # Update the memory item with combined value and new info (using the newest info)
-            memory_dict[key] = {
-                "value": combined_value,
-                "info": memory_item["info"],  # Use the new info
-            }
-        else:
-            # If key doesn't exist, simply add the new memory item
-            memory_dict[key] = memory_item
-
-    def add_qa_batch(
-        self, batch_summary: str, pair_summaries: list[dict], themes: list[str], order: int
-    ) -> None:
-        """
-        Add a batch of Q&A pairs to the scene memory as a single subsection.
-
-        Args:
-            batch_summary: The summary of the entire batch
-            pair_summaries: List of dictionaries, each containing:
-                - question: The summarized question for a single pair
-                - summary: The original dialogue for a single pair
-                - prompt: The prompt used for summarization
-                - time: The extracted time information (if any)
-            themes: List of themes associated with the batch
-            order: Order of the batch in the sequence
-        """
-        qa_subsection = {
-            "subsection": {},
-            "info": {
-                "summary": batch_summary,
-                "label": themes,
-                "origin_data": "",
-                "order": order,
-            },
-        }
-
-        for pair in pair_summaries:
-            qa_subsection["subsection"][pair["question"]] = {
-                "summary": pair["summary"],
-                "sources": pair["prompt"].split("\n\n", 1)[-1],
-                "time": pair.get("time", ""),  # Add time field with default empty string
-            }
-
-        self.scene_memory["qa_pair"]["section"].append(qa_subsection)
-
-    def add_document_chunk_group(
-        self, summary: str, label: list[str], order: int, sub_chunks: list
-    ) -> None:
-        """
-        Add a group of document chunks as a single section with multiple facts in the subsection.
-
-        Args:
-            summary: The summary of the large chunk
-            label: List of theme labels for the large chunk
-            order: Order of the large chunk in the sequence
-            sub_chunks: List of dictionaries containing small chunks information,
-                each with keys: 'question', 'chunk_text', 'prompt'
-        """
-        doc_section = {
-            "subsection": {},
-            "info": {
-                "summary": summary,
-                "label": label,
-                "origin_data": "",
-                "order": order,
-            },
-        }
-
-        # Add each small chunk as a fact in the subsection
-        for sub_chunk in sub_chunks:
-            question = sub_chunk["question"]
-            doc_section["subsection"][question] = {
-                "summary": sub_chunk["chunk_text"],
-                "sources": sub_chunk["prompt"].split("\n\n", 1)[-1],
-            }
-
-        self.scene_memory["document"]["section"].append(doc_section)
-
-    def process_qa_pair_summaries(self, llm: BaseLLM | None = None) -> None:
-        """
-        Process all qa_pair subsection summaries to generate a section summary.
-
-        Args:
-            llm: Optional LLM instance to generate summary. If None, concatenates subsection summaries.
-        Returns:
-            The generated section summary
-        """
-        all_summaries = []
-        all_labels = set()
-
-        # Collect all subsection summaries and labels
-        for section in self.scene_memory["qa_pair"]["section"]:
-            if "info" in section and "summary" in section["info"]:
-                all_summaries.append(section["info"]["summary"])
-            if "info" in section and "label" in section["info"]:
-                all_labels.update(section["info"]["label"])
-
-        # Generate summary
-        if llm is not None:
-            # Use LLM to generate a coherent summary
-            all_summaries_str = "\n".join(all_summaries)
-            messages = [
-                {
-                    "role": "user",
-                    "content": f"Summarize this text into a concise and objective sentence that captures its main idea. Provide only the required content directly, without including any additional information.\n\n{all_summaries_str}",
-                }
-            ]
-            section_summary = llm.generate(messages)
-        else:
-            # Simple concatenation of summaries
-            section_summary = " ".join(all_summaries)
-
-        # Update the section info
-        self.scene_memory["qa_pair"]["info"]["summary"] = section_summary
-        self.scene_memory["qa_pair"]["info"]["label"] = list(all_labels)
-
-    def process_document_summaries(self, llm=None) -> str:
-        """
-        Process all document subsection summaries to generate a section summary.
-
-        Args:
-            llm: Optional LLM instance to generate summary. If None, concatenates subsection summaries.
-        Returns:
-            The generated section summary
-        """
-        all_summaries = []
-        all_labels = set()
-
-        # Collect all subsection summaries and labels
-        for section in self.scene_memory["document"]["section"]:
-            if "info" in section and "summary" in section["info"]:
-                all_summaries.append(section["info"]["summary"])
-            if "info" in section and "label" in section["info"]:
-                all_labels.update(section["info"]["label"])
-
-        # Generate summary
-        if llm is not None:
-            # Use LLM to generate a coherent summary
-            all_summaries_str = "\n".join(all_summaries)
-            messages = [
-                {
-                    "role": "user",
-                    "content": f"Summarize this text into a concise and objective sentence that captures its main idea. Provide only the required content directly, without including any additional information.\n\n{all_summaries_str}",
-                }
-            ]
-            section_summary = llm.generate(messages)
-        else:
-            # Simple concatenation of summaries
-            section_summary = " ".join(all_summaries)
-
-        # Update the section info
-        self.scene_memory["document"]["info"]["summary"] = section_summary
-        self.scene_memory["document"]["info"]["label"] = list(all_labels)
-
-        return section_summary
diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py
index 0c41717ea..06e3c2950 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/manager.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py
@@ -6,8 +6,13 @@
 from datetime import datetime

 from memos.context.context import ContextThreadPoolExecutor
+from memos.embedders.base import BaseEmbedder
 from memos.embedders.factory import OllamaEmbedder
+from memos.graph_dbs.base import BaseGraphDB
+from memos.graph_dbs.nebular import NebulaGraphDB
 from memos.graph_dbs.neo4j import Neo4jGraphDB
+from memos.graph_dbs.polardb import PolarDBGraphDB
+from memos.llms.base import BaseLLM
 from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM
 from memos.log import get_logger
 from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
@@ -16,7 +21,6 @@
     QueueMessage,
 )

-
 logger = get_logger(__name__)


@@ -54,9 +58,9 @@ def extract_working_binding_ids(mem_items: list[TextualMemoryItem]) -> set[str]:
 class MemoryManager:
     def __init__(
         self,
-        graph_store: Neo4jGraphDB,
-        embedder: OllamaEmbedder,
-        llm: OpenAILLM | OllamaLLM | AzureLLM,
+        graph_store: BaseGraphDB | Neo4jGraphDB | NebulaGraphDB | PolarDBGraphDB,
+        embedder: BaseEmbedder | OllamaEmbedder,
+        llm: BaseLLM | OpenAILLM | OllamaLLM | AzureLLM,
         memory_size: dict | None = None,
         threshold: float | None = 0.80,
         merged_threshold: float | None = 0.92,
diff --git a/tests/mem_reader/test_memory.py b/tests/mem_reader/test_memory.py
deleted file mode 100644
index a0091adea..000000000
--- a/tests/mem_reader/test_memory.py
+++ /dev/null
@@ -1,246 +0,0 @@
-from datetime import datetime
-
-from memos.mem_reader.memory import Memory
-
-
-def test_memory_initialization():
-    """Test initialization of Memory class."""
-    user_id = "user123"
-    session_id = "session456"
-    created_at = datetime.utcnow()
-
-    memory = Memory(user_id=user_id, session_id=session_id, created_at=created_at)
-
-    # Check initial empty structures
-    assert memory.objective_memory == {}
-    assert memory.subjective_memory == {}
-    assert "qa_pair" in memory.scene_memory
-    assert "document" in memory.scene_memory
-
-    # Check info fields are correctly initialized
-    assert memory.scene_memory["qa_pair"]["info"]["user_id"] == user_id
-    assert memory.scene_memory["qa_pair"]["info"]["session_id"] == session_id
-    assert memory.scene_memory["qa_pair"]["info"]["created_at"] == created_at
-    assert memory.scene_memory["document"]["info"]["user_id"] == user_id
-    assert memory.scene_memory["document"]["info"]["session_id"] == session_id
-    assert memory.scene_memory["document"]["info"]["created_at"] == created_at
-
-
-def test_to_dict():
-    """Test conversion of Memory to dictionary."""
-    memory = Memory(user_id="user123", session_id="session456", created_at=datetime.now())
-
-    memory_dict = memory.to_dict()
-
-    assert "objective_memory" in memory_dict
-    assert "subjective_memory" in memory_dict
-    assert "scene_memory" in memory_dict
-    assert "qa_pair" in memory_dict["scene_memory"]
-    assert "document" in memory_dict["scene_memory"]
-
-
-def test_add_qa_batch():
-    """Test adding a batch of Q&A pairs to scene memory."""
-    memory = Memory(user_id="user123", session_id="session456", created_at=datetime.now())
-
-    batch_summary = "Discussion about programming languages"
-    pair_summaries = [
-        {
-            "question": "What is Python?",
-            "summary": "Python is a high-level programming language.",
-            "prompt": "Question\n\nOriginal conversation: User asked about Python and its features",
-            "time": "2023-01-01",
-        },
-        {
-            "question": "What is Java?",
-            "summary": "Java is a class-based, object-oriented programming language.",
-            "prompt": "Question\n\nOriginal conversation: User inquired about Java programming",
-        },
-    ]
-    themes = ["programming", "languages"]
-    order = 1
-
-    memory.add_qa_batch(batch_summary, pair_summaries, themes, order)
-
-    # Check if the batch was added correctly
-    assert len(memory.scene_memory["qa_pair"]["section"]) == 1
-    added_section = memory.scene_memory["qa_pair"]["section"][0]
-
-    # Check section info
-    assert added_section["info"]["summary"] == batch_summary
-    assert added_section["info"]["label"] == themes
-    assert added_section["info"]["order"] == order
-
-    # Check subsections (QA pairs)
-    assert "What is Python?" in added_section["subsection"]
-    assert "What is Java?" in added_section["subsection"]
-
-    # Check specific QA pair content
-    python_qa = added_section["subsection"]["What is Python?"]
-    assert python_qa["summary"] == "Python is a high-level programming language."
-    assert "Original conversation: User asked about Python" in python_qa["sources"]
-    assert python_qa["time"] == "2023-01-01"
-
-    # Check that time field defaults to empty string when not provided
-    java_qa = added_section["subsection"]["What is Java?"]
-    assert java_qa["time"] == ""
-
-
-def test_add_document_chunk_group():
-    """Test adding a document chunk group to scene memory."""
-    memory = Memory(user_id="user123", session_id="session456", created_at=datetime.now())
-
-    summary = "Introduction to Machine Learning"
-    label = ["ML", "AI", "technology"]
-    order = 1
-    sub_chunks = [
-        {
-            "question": "What is supervised learning?",
-            "chunk_text": "Supervised learning is where the model learns from labeled training data.",
-            "prompt": "Extract key information\n\nOriginal text: Detailed explanation of supervised learning",
-        },
-        {
-            "question": "What is unsupervised learning?",
-            "chunk_text": "Unsupervised learning is where the model learns patterns from unlabeled data.",
-            "prompt": "Extract key information\n\nOriginal text: Comprehensive overview of unsupervised learning",
-        },
-    ]
-
-    memory.add_document_chunk_group(summary, label, order, sub_chunks)
-
-    # Check if the document chunk group was added correctly
-    assert len(memory.scene_memory["document"]["section"]) == 1
-    added_section = memory.scene_memory["document"]["section"][0]
-
-    # Check section info
-    assert added_section["info"]["summary"] == summary
-    assert added_section["info"]["label"] == label
-    assert added_section["info"]["order"] == order
-
-    # Check subsections (document chunks)
-    assert "What is supervised learning?" in added_section["subsection"]
-    assert "What is unsupervised learning?" in added_section["subsection"]
-
-    # Check specific document chunk content
-    supervised_chunk = added_section["subsection"]["What is supervised learning?"]
-    assert (
-        supervised_chunk["summary"]
-        == "Supervised learning is where the model learns from labeled training data."
-    )
-    assert (
-        "Original text: Detailed explanation of supervised learning" in supervised_chunk["sources"]
-    )
-
-
-def test_process_qa_pair_summaries_without_llm():
-    """Test processing QA pair summaries without an LLM."""
-    memory = Memory(user_id="user123", session_id="session456", created_at=datetime.now())
-
-    # Add two batches of QA pairs
-    memory.add_qa_batch(
-        "Programming languages discussion",
-        [{"question": "Python?", "summary": "About Python", "prompt": "Q"}],
-        ["programming"],
-        1,
-    )
-    memory.add_qa_batch(
-        "Database systems overview",
-        [{"question": "SQL?", "summary": "About SQL", "prompt": "Q"}],
-        ["database", "programming"],
-        2,
-    )
-
-    # Process summaries without LLM
-    memory.process_qa_pair_summaries()
-
-    # Check if the section summary was generated correctly
-    section_info = memory.scene_memory["qa_pair"]["info"]
-    assert section_info["summary"] == "Programming languages discussion Database systems overview"
-    assert set(section_info["label"]) == {"programming", "database"}
-
-
-def test_process_document_summaries_without_llm():
-    """Test processing document summaries without an LLM."""
-    memory = Memory(user_id="user123", session_id="session456", created_at=datetime.now())
-
-    # Add two document chunk groups
-    memory.add_document_chunk_group(
-        "Introduction to AI",
-        ["AI", "technology"],
-        1,
-        [{"question": "What is AI?", "chunk_text": "AI definition", "prompt": "Extract"}],
-    )
-    memory.add_document_chunk_group(
-        "Deep Learning Basics",
-        ["AI", "deep learning"],
-        2,
-        [{"question": "Neural Networks?", "chunk_text": "NN explanation", "prompt": "Extract"}],
-    )
-
-    # Process summaries without LLM
-    summary = memory.process_document_summaries()
-
-    # Check if the section summary was generated correctly
-    section_info = memory.scene_memory["document"]["info"]
-    assert section_info["summary"] == "Introduction to AI Deep Learning Basics"
-    assert summary == "Introduction to AI Deep Learning Basics"
-    assert set(section_info["label"]) == {"AI", "technology", "deep learning"}
-
-
-def test_process_qa_pair_summaries_with_llm():
-    """Test processing QA pair summaries with a mock LLM."""
-    memory = Memory(user_id="user123", session_id="session456", created_at=datetime.now())
-
-    # Add a batch of QA pairs
-    memory.add_qa_batch(
-        "Programming languages discussion",
-        [{"question": "Python?", "summary": "About Python", "prompt": "Q"}],
-        ["programming"],
-        1,
-    )
-
-    # Create a mock LLM
-    class MockLLM:
-        def generate(self, messages):
-            return "Summarized content about programming languages"
-
-    mock_llm = MockLLM()
-
-    # Process summaries with mock LLM
-    memory.process_qa_pair_summaries(llm=mock_llm)
-
-    # Check if the section summary was generated correctly using the LLM
-    assert (
-        memory.scene_memory["qa_pair"]["info"]["summary"]
-        == "Summarized content about programming languages"
-    )
-
-
-def test_process_document_summaries_with_llm():
-    """Test processing document summaries with a mock LLM."""
-    memory = Memory(user_id="user123", session_id="session456", created_at=datetime.now())
-
-    # Add a document chunk group
-    memory.add_document_chunk_group(
-        "Introduction to AI",
-        ["AI", "technology"],
-        1,
-        [{"question": "What is AI?", "chunk_text": "AI definition", "prompt": "Extract"}],
-    )
-
-    # Create a mock LLM
-    class MockLLM:
-        def generate(self, messages):
-            return "Summarized content about artificial intelligence"
-
-    mock_llm = MockLLM()
-
-    # Process summaries with mock LLM
-    summary = memory.process_document_summaries(llm=mock_llm)
-
-    # Check if the section summary was generated correctly using the LLM
-    assert (
-        memory.scene_memory["document"]["info"]["summary"]
-        == "Summarized content about artificial intelligence"
-    )
-    assert summary == "Summarized content about artificial intelligence"
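For reference, the refactored /product/search flow in server_router.py reduces to the fan-out pattern below. This is a minimal, self-contained sketch, not the module itself: the stdlib ThreadPoolExecutor stands in for memos' ContextThreadPoolExecutor, and SearchRequest plus the two stub searchers are hypothetical stand-ins for APISearchRequest, _search_text_mem, and _search_pref_mem.

# Minimal sketch of the refactored search fan-out (assumptions noted above).
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field


@dataclass
class SearchRequest:  # hypothetical stand-in for APISearchRequest
    query: str
    mem_cube_id: str
    include_preference: bool = False
    chat_history: list = field(default_factory=list)


def search_text_mem(req: SearchRequest) -> list[dict]:
    # Stand-in for naive_mem_cube.text_mem.search + _format_memory_item.
    return [{"memory": f"text hit for {req.query!r}"}]


def search_pref_mem(req: SearchRequest) -> list[dict]:
    # Stand-in for naive_mem_cube.pref_mem.search; empty when preferences are off.
    if not req.include_preference:
        return []
    return [{"memory": f"preference hit for {req.query!r}"}]


def search_memories(req: SearchRequest) -> dict:
    # Fan out both searches in parallel, as the refactored endpoint does.
    with ThreadPoolExecutor(max_workers=2) as executor:
        text_future = executor.submit(search_text_mem, req)
        pref_future = executor.submit(search_pref_mem, req)
        text_memories = text_future.result()
        pref_memories = pref_future.result()

    pref_mem = []
    if req.include_preference:
        # In the real endpoint, instruct_completion() additionally derives a
        # pref_string/pref_note pair that is attached to this entry.
        pref_mem = [{"cube_id": req.mem_cube_id, "memories": pref_memories}]

    return {
        "text_mem": [{"cube_id": req.mem_cube_id, "memories": text_memories}],
        "act_mem": [],
        "para_mem": [],
        "pref_mem": pref_mem,
    }


if __name__ == "__main__":
    req = SearchRequest(query="tell me about python", mem_cube_id="cube-1", include_preference=True)
    print(search_memories(req))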
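The manager.py change points the same way: MemoryManager's constructor parameters are annotated against the abstract bases (BaseGraphDB, BaseEmbedder, BaseLLM) so that any conforming backend type-checks, while the concrete unions are kept for readability. A minimal sketch of that pattern, using illustrative placeholder classes rather than the memos ones:

# Sketch of widening a constructor annotation to the abstract base. GraphDB,
# Neo4jBackend, PolarBackend, and Manager are hypothetical stand-ins.
from abc import ABC, abstractmethod


class GraphDB(ABC):  # stands in for memos.graph_dbs.base.BaseGraphDB
    @abstractmethod
    def add_node(self, node_id: str, payload: dict) -> None: ...


class Neo4jBackend(GraphDB):  # stands in for Neo4jGraphDB
    def add_node(self, node_id: str, payload: dict) -> None:
        print(f"neo4j: create node {node_id!r}")


class PolarBackend(GraphDB):  # stands in for PolarDBGraphDB
    def add_node(self, node_id: str, payload: dict) -> None:
        print(f"polardb: insert node {node_id!r}")


class Manager:
    # Annotating with the base type accepts every backend above, which is the
    # effect of the BaseGraphDB | Neo4jGraphDB | ... union in the diff.
    def __init__(self, graph_store: GraphDB):
        self.graph_store = graph_store

    def persist(self, node_id: str) -> None:
        self.graph_store.add_node(node_id, {})


if __name__ == "__main__":
    for backend in (Neo4jBackend(), PolarBackend()):
        Manager(backend).persist("memory-1")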