From f61a012dfa9114eea17cd857fdbe9b8086da30b2 Mon Sep 17 00:00:00 2001 From: Chris Mangum Date: Fri, 2 May 2025 16:11:02 -0700 Subject: [PATCH 1/3] Update skip_validation parameter to default to True in memory retrieval methods This commit modifies the `skip_validation` parameter in various memory retrieval methods across the `attribute.py`, `redis_im.py`, `redis_stm.py`, and `sqlite_ltm.py` files to default to `True`. This change aims to streamline memory access by ensuring validation is skipped by default, enhancing performance and usability. --- memory/search/strategies/attribute.py | 4 ++-- memory/storage/redis_im.py | 8 ++++---- memory/storage/redis_stm.py | 8 ++++---- memory/storage/sqlite_ltm.py | 14 +++++++------- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/memory/search/strategies/attribute.py b/memory/search/strategies/attribute.py index 04ff0af..9b8d854 100644 --- a/memory/search/strategies/attribute.py +++ b/memory/search/strategies/attribute.py @@ -33,7 +33,7 @@ def __init__( im_store: RedisIMStore, ltm_store: SQLiteLTMStore, scoring_method: str = "length_ratio", - skip_validation: bool = False, + skip_validation: bool = True, ): """Initialize the attribute search strategy. @@ -146,7 +146,7 @@ def search( case_sensitive: bool = False, use_regex: bool = False, scoring_method: Optional[str] = None, - skip_validation: Optional[bool] = None, + skip_validation: Optional[bool] = True, **kwargs, ) -> List[Dict[str, Any]]: """Search for memories based on content and metadata attributes. 
diff --git a/memory/storage/redis_im.py b/memory/storage/redis_im.py index f8a26fd..fe7f7d2 100644 --- a/memory/storage/redis_im.py +++ b/memory/storage/redis_im.py @@ -480,7 +480,7 @@ def _store_memory_entry(self, agent_id: str, memory_entry: MemoryEntry) -> bool: ) return False - def get(self, agent_id: str, memory_id: str, skip_validation: bool = False) -> Optional[MemoryEntry]: + def get(self, agent_id: str, memory_id: str, skip_validation: bool = True) -> Optional[MemoryEntry]: """Retrieve a memory entry by ID. Args: @@ -616,7 +616,7 @@ def _hash_to_memory_entry(self, hash_data: Dict[str, Any]) -> MemoryEntry: return memory_entry def get_by_timerange( - self, agent_id: str, start_time: float, end_time: float, limit: int = 100, skip_validation: bool = False + self, agent_id: str, start_time: float, end_time: float, limit: int = 100, skip_validation: bool = True ) -> List[MemoryEntry]: """Retrieve memories within a time range. @@ -713,7 +713,7 @@ def get_by_importance( min_importance: float = 0.0, max_importance: float = 1.0, limit: int = 100, - skip_validation: bool = False, + skip_validation: bool = True, ) -> List[MemoryEntry]: """Retrieve memories by importance score range. @@ -1341,7 +1341,7 @@ def get_size(self, agent_id: str) -> int: logger.exception("Error retrieving memory size for agent %s", agent_id) return 0 - def get_all(self, agent_id: str, limit: int = 1000, skip_validation: bool = False) -> List[MemoryEntry]: + def get_all(self, agent_id: str, limit: int = 1000, skip_validation: bool = True) -> List[MemoryEntry]: """Get all memories for an agent. 
Args: diff --git a/memory/storage/redis_stm.py b/memory/storage/redis_stm.py index 43aa4d6..fff35b3 100644 --- a/memory/storage/redis_stm.py +++ b/memory/storage/redis_stm.py @@ -273,7 +273,7 @@ def _store_memory_entry(self, agent_id: str, memory_entry: MemoryEntry) -> bool: ) return False - def get(self, agent_id: str, memory_id: str, skip_validation: bool = False) -> Optional[MemoryEntry]: + def get(self, agent_id: str, memory_id: str, skip_validation: bool = True) -> Optional[MemoryEntry]: """Retrieve a memory entry by ID. Args: @@ -378,7 +378,7 @@ def _update_access_metadata( ) def get_by_timerange( - self, agent_id: str, start_time: float, end_time: float, limit: int = 100, skip_validation: bool = False + self, agent_id: str, start_time: float, end_time: float, limit: int = 100, skip_validation: bool = True ) -> List[MemoryEntry]: """Retrieve memories within a time range. @@ -425,7 +425,7 @@ def get_by_importance( min_importance: float = 0.0, max_importance: float = 1.0, limit: int = 100, - skip_validation: bool = False, + skip_validation: bool = True, ) -> List[MemoryEntry]: """Retrieve memories by importance score. @@ -641,7 +641,7 @@ def get_size(self, agent_id: str) -> int: logger.error("Error calculating memory size: %s", e) return 0 - def get_all(self, agent_id: str, limit: int = 1000, skip_validation: bool = False) -> List[MemoryEntry]: + def get_all(self, agent_id: str, limit: int = 1000, skip_validation: bool = True) -> List[MemoryEntry]: """Get all memories for an agent. 
Args: diff --git a/memory/storage/sqlite_ltm.py b/memory/storage/sqlite_ltm.py index 81b588a..1f661fb 100644 --- a/memory/storage/sqlite_ltm.py +++ b/memory/storage/sqlite_ltm.py @@ -555,7 +555,7 @@ def store_batch(self, memory_entries: List[Dict[str, Any]]) -> bool: logger.error("Unexpected error storing batch of memories: %s", str(e)) return False - def get(self, memory_id: str, agent_id: str, skip_validation: bool = False) -> Optional[Dict[str, Any]]: + def get(self, memory_id: str, agent_id: str, skip_validation: bool = True) -> Optional[Dict[str, Any]]: """Retrieve a memory by ID. Args: @@ -711,7 +711,7 @@ def get_by_timerange( end_time: Union[float, int, str], agent_id: str = None, limit: int = 100, - skip_validation: bool = False, + skip_validation: bool = True, ) -> List[Dict[str, Any]]: """Retrieve memories within a time range. @@ -800,7 +800,7 @@ def get_by_importance( min_importance: float = 0.0, max_importance: float = 1.0, limit: int = 100, - skip_validation: bool = False, + skip_validation: bool = True, ) -> List[Dict[str, Any]]: """Retrieve memories by importance score. @@ -858,7 +858,7 @@ def get_most_similar( query_vector: List[float], top_k: int = 10, agent_id: str = None, - skip_validation: bool = False, + skip_validation: bool = True, ) -> List[Tuple[Dict[str, Any], float]]: """Retrieve memories most similar to the query vector. @@ -940,7 +940,7 @@ def search_similar( k: int = 5, memory_type: Optional[str] = None, agent_id: str = None, - skip_validation: bool = False, + skip_validation: bool = True, ) -> List[Dict[str, Any]]: """Search for memories with similar embeddings. 
@@ -1161,7 +1161,7 @@ def get_size(self) -> int: logger.error("Unexpected error calculating memory size: %s", str(e)) return 0 - def get_all(self, agent_id: str = None, limit: int = 1000, skip_validation: bool = False) -> List[Dict[str, Any]]: + def get_all(self, agent_id: str = None, limit: int = 1000, skip_validation: bool = True) -> List[Dict[str, Any]]: """Get all memories for the agent. Args: @@ -1342,7 +1342,7 @@ def search_by_step_range( start_step: int, end_step: int, memory_type: Optional[str] = None, - skip_validation: bool = False, + skip_validation: bool = True, ) -> List[Dict[str, Any]]: """Search for memories within a specific step range. From 7f6c3fed43127a2c07b60a15b3a992431874a2ac Mon Sep 17 00:00:00 2001 From: Chris Mangum Date: Fri, 2 May 2025 16:21:45 -0700 Subject: [PATCH 2/3] Add README and unit tests for memory API This commit introduces a new README.md file for the `memory/api` module, detailing the Agent Memory System's architecture, components, and usage examples. Additionally, it adds comprehensive unit tests for the data models and types used in the memory API, ensuring robust functionality and correctness. The tests cover initialization, method behavior, and structure compliance for various memory-related classes and types. 
--- memory/api/README.md | 185 ++++++++++++++++++ tests/api/test_models.py | 236 +++++++++++++++++++++++ tests/api/test_types.py | 397 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 818 insertions(+) create mode 100644 memory/api/README.md create mode 100644 tests/api/test_models.py create mode 100644 tests/api/test_types.py diff --git a/memory/api/README.md b/memory/api/README.md new file mode 100644 index 0000000..3414e89 --- /dev/null +++ b/memory/api/README.md @@ -0,0 +1,185 @@ +# Agent Memory API + +The `memory/api` module provides a comprehensive interface for integrating the Agent Memory System with AI agents, enabling efficient storage, retrieval, and utilization of memories to support context-aware agent reasoning and behavior. + +## Overview + +The Agent Memory System implements a tiered memory architecture inspired by human cognitive systems: + +- **Short-Term Memory (STM)**: Recent, high-fidelity memories with detailed information +- **Intermediate Memory (IM)**: Medium-term memories with moderate compression +- **Long-Term Memory (LTM)**: Persistent, compressed memories retaining core information + +This architecture allows agents to efficiently manage memories with different levels of detail and importance across varying time horizons. 
+ +## Module Components + +### Main API Class + +[`memory_api.py`](memory_api.py) - Provides the primary interface class `AgentMemoryAPI` for interacting with the memory system: +- Store agent states, actions, and interactions +- Retrieve memories by various criteria (ID, time range, attributes) +- Perform semantic search across memory tiers +- Manage memory lifecycle and maintenance + +### Agent Integration + +[`hooks.py`](hooks.py) - Offers decorators and utility functions for automatic memory integration: +- `install_memory_hooks`: Class decorator to add memory capabilities to agent classes +- `with_memory`: Instance decorator for adding memory to existing agent instances +- `BaseAgent`: Minimal interface with standard lifecycle methods for memory-aware agents + +### Data Models + +[`models.py`](models.py) - Defines structured representations of agent data: +- `AgentState`: Standardized representation of an agent's state +- `ActionData`: Record of an agent action with associated states and metrics +- `ActionResult`: Lightweight result of an action execution + +### Type Definitions + +[`types.py`](types.py) - Establishes core type definitions for the memory system: +- Memory entry structures (metadata, embeddings, content) +- Memory tiers and filtering types +- Statistics and query result types +- Protocol definitions for memory stores + +## Getting Started + +### Basic Usage + +```python +from memory.api import AgentMemoryAPI + +# Initialize the memory API +memory_api = AgentMemoryAPI() + +# Store an agent state +state_data = { + "agent_id": "agent-001", + "step_number": 42, + "content": { + "observation": "User asked about weather", + "thought": "I should check the forecast" + } +} +memory_id = memory_api.store_agent_state("agent-001", state_data, step_number=42) + +# Retrieve similar memories +query = "weather forecast" +similar_memories = memory_api.search_by_content("agent-001", query, k=3) + +# Use memories to inform agent's response +for memory in 
similar_memories: + print(f"Related memory: {memory['contents']}") +``` + +### Automatic Memory Integration + +```python +from memory.api import install_memory_hooks, BaseAgent + +@install_memory_hooks +class MyAgent(BaseAgent): + def __init__(self, config=None, agent_id=None): + super().__init__(config, agent_id) + # Agent-specific initialization + + def act(self, observation): + # Memory hooks automatically capture state before this method + self.step_number += 1 + action_result = self._process(observation) + # Memory hooks automatically capture state after this method + return action_result + + def get_state(self): + # Return current agent state + state = super().get_state() + state.extra_data["custom_field"] = self.some_internal_state + return state + +# Create an agent with memory enabled +agent = MyAgent(agent_id="agent-001") + +# Use the agent normally - memories are created automatically +result = agent.act({"user_input": "What's the weather today?"}) +``` + +## Advanced Features + +### Memory Maintenance + +```python +# Run memory maintenance to consolidate and optimize memories +memory_api.force_memory_maintenance("agent-001") + +# Get memory statistics +stats = memory_api.get_memory_statistics("agent-001") +print(f"Total memories: {stats['total_memories']}") +print(f"STM: {stats['stm_count']}, IM: {stats['im_count']}, LTM: {stats['ltm_count']}") +``` + +### Memory Search + +```python +# Search by content similarity +similar_memories = memory_api.search_by_content( + agent_id="agent-001", + content_query="user asked about calendar appointments", + k=5 +) + +# Retrieve memories by time range +recent_memories = memory_api.retrieve_by_time_range( + agent_id="agent-001", + start_step=100, + end_step=120 +) + +# Retrieve memories by attributes +filtered_memories = memory_api.retrieve_by_attributes( + agent_id="agent-001", + attributes={"action_type": "calendar_query"} +) +``` + +### Memory Configuration + +```python +from memory.api import AgentMemoryAPI +from 
memory.config import MemoryConfig + +# Custom configuration +config = MemoryConfig( + stm_config={"memory_limit": 1000}, + im_config={"memory_limit": 10000}, + ltm_config={"memory_limit": 100000} +) + +# Initialize API with custom configuration +memory_api = AgentMemoryAPI(config) + +# Update configuration +memory_api.configure_memory_system({ + "stm_config": {"memory_limit": 2000} +}) +``` + +## Error Handling + +```python +from memory.api import AgentMemoryAPI +from memory.api.memory_api import MemoryStoreException, MemoryRetrievalException, MemoryConfigException + +memory_api = AgentMemoryAPI() + +try: + memory = memory_api.retrieve_state_by_id("agent-001", "non_existent_id") +except MemoryRetrievalException as e: + print(f"Memory retrieval error: {e}") + +try: + memories = memory_api.search_by_content("agent-001", "query", k=-1) +except MemoryConfigException as e: + print(f"Configuration error: {e}") +``` \ No newline at end of file diff --git a/tests/api/test_models.py b/tests/api/test_models.py new file mode 100644 index 0000000..be8cbd4 --- /dev/null +++ b/tests/api/test_models.py @@ -0,0 +1,236 @@ +"""Unit tests for memory API models.""" + +import pytest +from memory.api.models import AgentState, ActionData, ActionResult + + +class TestAgentState: + """Test suite for the AgentState class.""" + + def test_initialization(self): + """Test basic initialization of AgentState.""" + state = AgentState( + agent_id="agent-1", + step_number=42, + health=0.8, + reward=10.5, + position_x=1.0, + position_y=2.0, + position_z=3.0, + resource_level=0.7, + extra_data={"inventory": ["sword", "shield"]} + ) + + assert state.agent_id == "agent-1" + assert state.step_number == 42 + assert state.health == 0.8 + assert state.reward == 10.5 + assert state.position_x == 1.0 + assert state.position_y == 2.0 + assert state.position_z == 3.0 + assert state.resource_level == 0.7 + assert state.extra_data == {"inventory": ["sword", "shield"]} + + def test_initialization_minimal(self): + """Test initialization 
with only required fields.""" + state = AgentState(agent_id="agent-1", step_number=10) + + assert state.agent_id == "agent-1" + assert state.step_number == 10 + assert state.health is None + assert state.reward is None + assert state.position_x is None + assert state.position_y is None + assert state.position_z is None + assert state.resource_level is None + assert state.extra_data == {} + + def test_as_dict_with_all_fields(self): + """Test as_dict method with all fields populated.""" + state = AgentState( + agent_id="agent-1", + step_number=42, + health=0.8, + reward=10.5, + position_x=1.0, + position_y=2.0, + position_z=3.0, + resource_level=0.7, + extra_data={"inventory": ["sword", "shield"]} + ) + + state_dict = state.as_dict() + + assert state_dict["agent_id"] == "agent-1" + assert state_dict["step_number"] == 42 + assert state_dict["health"] == 0.8 + assert state_dict["reward"] == 10.5 + assert state_dict["position_x"] == 1.0 + assert state_dict["position_y"] == 2.0 + assert state_dict["position_z"] == 3.0 + assert state_dict["resource_level"] == 0.7 + assert state_dict["extra_data"] == {"inventory": ["sword", "shield"]} + + def test_as_dict_with_none_values(self): + """Test as_dict method excludes None values.""" + state = AgentState( + agent_id="agent-1", + step_number=42, + health=None, + reward=None, + position_x=None, + position_y=None, + position_z=None, + resource_level=None + ) + + state_dict = state.as_dict() + + assert "agent_id" in state_dict + assert "step_number" in state_dict + assert "health" not in state_dict + assert "reward" not in state_dict + assert "position_x" not in state_dict + assert "position_y" not in state_dict + assert "position_z" not in state_dict + assert "resource_level" not in state_dict + assert "extra_data" not in state_dict + + def test_as_dict_with_empty_extra_data(self): + """Test as_dict method excludes empty extra_data.""" + state = AgentState( + agent_id="agent-1", + step_number=42, + health=0.8, + extra_data={} + ) + 
+ state_dict = state.as_dict() + + assert "agent_id" in state_dict + assert "step_number" in state_dict + assert "health" in state_dict + assert "extra_data" not in state_dict + + +class TestActionData: + """Test suite for the ActionData class.""" + + def test_initialization(self): + """Test basic initialization of ActionData.""" + action_data = ActionData( + action_type="move", + action_params={"direction": "north", "distance": 2}, + state_before={"position": [0, 0], "health": 1.0}, + state_after={"position": [0, 2], "health": 0.9}, + reward=5.0, + execution_time=0.25, + step_number=42 + ) + + assert action_data.action_type == "move" + assert action_data.action_params == {"direction": "north", "distance": 2} + assert action_data.state_before == {"position": [0, 0], "health": 1.0} + assert action_data.state_after == {"position": [0, 2], "health": 0.9} + assert action_data.reward == 5.0 + assert action_data.execution_time == 0.25 + assert action_data.step_number == 42 + + def test_initialization_default_values(self): + """Test initialization with default values.""" + action_data = ActionData( + action_type="move", + state_before={"position": [0, 0], "health": 1.0}, + state_after={"position": [0, 2], "health": 0.9}, + execution_time=0.25, + step_number=42 + ) + + assert action_data.action_type == "move" + assert action_data.action_params == {} + assert action_data.reward == 0.0 + + def test_get_state_difference_numeric(self): + """Test get_state_difference method with numeric values.""" + action_data = ActionData( + action_type="move", + state_before={"health": 1.0, "position_x": 10, "energy": 100}, + state_after={"health": 0.8, "position_x": 15, "energy": 90}, + execution_time=0.25, + step_number=42 + ) + + diff = action_data.get_state_difference() + + assert "health" in diff + assert "position_x" in diff + assert "energy" in diff + assert pytest.approx(diff["health"]) == -0.2 + assert diff["position_x"] == 5 + assert diff["energy"] == -10 + + def 
test_get_state_difference_non_numeric(self): + """Test get_state_difference method ignores non-numeric values.""" + action_data = ActionData( + action_type="move", + state_before={ + "health": 1.0, + "position": [0, 0], + "inventory": ["sword"] + }, + state_after={ + "health": 0.8, + "position": [0, 2], + "inventory": ["sword", "shield"] + }, + execution_time=0.25, + step_number=42 + ) + + diff = action_data.get_state_difference() + + assert "health" in diff + assert pytest.approx(diff["health"]) == -0.2 + assert "position" not in diff + assert "inventory" not in diff + + def test_get_state_difference_missing_keys(self): + """Test get_state_difference method with keys present in only one state.""" + action_data = ActionData( + action_type="move", + state_before={"health": 1.0, "energy": 100}, + state_after={"health": 0.8, "mana": 50}, + execution_time=0.25, + step_number=42 + ) + + diff = action_data.get_state_difference() + + assert "health" in diff + assert pytest.approx(diff["health"]) == -0.2 + assert "energy" not in diff + assert "mana" not in diff + + +class TestActionResult: + """Test suite for the ActionResult class.""" + + def test_initialization(self): + """Test basic initialization of ActionResult.""" + result = ActionResult( + action_type="attack", + params={"target": "enemy-1", "weapon": "sword"}, + reward=10.0 + ) + + assert result.action_type == "attack" + assert result.params == {"target": "enemy-1", "weapon": "sword"} + assert result.reward == 10.0 + + def test_initialization_default_values(self): + """Test initialization with default values.""" + result = ActionResult(action_type="observe") + + assert result.action_type == "observe" + assert result.params == {} + assert result.reward == 0.0 \ No newline at end of file diff --git a/tests/api/test_types.py b/tests/api/test_types.py new file mode 100644 index 0000000..bb4e3f2 --- /dev/null +++ b/tests/api/test_types.py @@ -0,0 +1,397 @@ +"""Unit tests for memory API types.""" + +import pytest +from 
typing import Dict, List, Any, Optional +from memory.api.types import ( + MemoryMetadata, + MemoryEmbeddings, + MemoryEntry, + MemoryChangeRecord, + MemoryTypeDistribution, + MemoryStatistics, + SimilaritySearchResult, + ConfigUpdate, + QueryResult, + MemoryStore +) + + +class TestMemoryTypesStructures: + """Test suite for memory type structures.""" + + def test_memory_metadata_structure(self): + """Test MemoryMetadata structure with all fields.""" + metadata: MemoryMetadata = { + "creation_time": 1649879872.123, + "last_access_time": 1649879900.456, + "compression_level": 0, + "importance_score": 0.8, + "retrieval_count": 5, + "memory_type": "state", + "current_tier": "stm", + "checksum": "abc123" + } + + assert metadata["creation_time"] == 1649879872.123 + assert metadata["last_access_time"] == 1649879900.456 + assert metadata["compression_level"] == 0 + assert metadata["importance_score"] == 0.8 + assert metadata["retrieval_count"] == 5 + assert metadata["memory_type"] == "state" + assert metadata["current_tier"] == "stm" + assert metadata["checksum"] == "abc123" + + def test_memory_metadata_partial(self): + """Test MemoryMetadata with partial fields (total=False).""" + # This should work because MemoryMetadata has total=False + metadata: MemoryMetadata = { + "creation_time": 1649879872.123, + "importance_score": 0.8, + "memory_type": "state", + } + + assert metadata["creation_time"] == 1649879872.123 + assert metadata["importance_score"] == 0.8 + assert metadata["memory_type"] == "state" + assert "compression_level" not in metadata + + def test_memory_embeddings_structure(self): + """Test MemoryEmbeddings structure with all fields.""" + embeddings: MemoryEmbeddings = { + "full_vector": [0.1, 0.2, 0.3, 0.4], + "compressed_vector": [0.15, 0.25, 0.35], + "abstract_vector": [0.2, 0.3] + } + + assert embeddings["full_vector"] == [0.1, 0.2, 0.3, 0.4] + assert embeddings["compressed_vector"] == [0.15, 0.25, 0.35] + assert embeddings["abstract_vector"] == [0.2, 0.3] + 
+ def test_memory_embeddings_partial(self): + """Test MemoryEmbeddings with partial fields (total=False).""" + # This should work because MemoryEmbeddings has total=False + embeddings: MemoryEmbeddings = { + "full_vector": [0.1, 0.2, 0.3, 0.4], + } + + assert embeddings["full_vector"] == [0.1, 0.2, 0.3, 0.4] + assert "compressed_vector" not in embeddings + assert "abstract_vector" not in embeddings + + def test_memory_entry_structure(self): + """Test MemoryEntry structure with all fields.""" + memory: MemoryEntry = { + "memory_id": "mem_12345", + "agent_id": "agent_001", + "step_number": 42, + "timestamp": 1649879872.123, + "contents": {"observation": "User asked about weather"}, + "metadata": { + "importance_score": 0.8, + "memory_type": "interaction", + "creation_time": 1649879872.123, + "retrieval_count": 0, + "compression_level": 0, + "current_tier": "stm", + "last_access_time": 1649879872.123, + "checksum": "abc123" + }, + "embeddings": { + "full_vector": [0.1, 0.2, 0.3, 0.4] + } + } + + assert memory["memory_id"] == "mem_12345" + assert memory["agent_id"] == "agent_001" + assert memory["step_number"] == 42 + assert memory["timestamp"] == 1649879872.123 + assert memory["contents"]["observation"] == "User asked about weather" + assert memory["metadata"]["importance_score"] == 0.8 + assert memory["embeddings"]["full_vector"] == [0.1, 0.2, 0.3, 0.4] + + def test_memory_entry_without_embeddings(self): + """Test MemoryEntry without embeddings (embeddings is Optional).""" + memory: MemoryEntry = { + "memory_id": "mem_12345", + "agent_id": "agent_001", + "step_number": 42, + "timestamp": 1649879872.123, + "contents": {"observation": "User asked about weather"}, + "metadata": { + "importance_score": 0.8, + "memory_type": "interaction", + "creation_time": 1649879872.123, + "retrieval_count": 0, + "compression_level": 0, + "current_tier": "stm", + "last_access_time": 1649879872.123, + "checksum": "abc123" + }, + "embeddings": None + } + + assert memory["memory_id"] == 
"mem_12345" + assert memory["embeddings"] is None + + def test_memory_change_record_structure(self): + """Test MemoryChangeRecord structure.""" + change_record: MemoryChangeRecord = { + "memory_id": "mem_12345", + "step_number": 42, + "timestamp": 1649879872.123, + "previous_value": 10, + "new_value": 20 + } + + assert change_record["memory_id"] == "mem_12345" + assert change_record["step_number"] == 42 + assert change_record["timestamp"] == 1649879872.123 + assert change_record["previous_value"] == 10 + assert change_record["new_value"] == 20 + + def test_memory_change_record_with_none_previous(self): + """Test MemoryChangeRecord with None previous value.""" + change_record: MemoryChangeRecord = { + "memory_id": "mem_12345", + "step_number": 42, + "timestamp": 1649879872.123, + "previous_value": None, + "new_value": 20 + } + + assert change_record["memory_id"] == "mem_12345" + assert change_record["previous_value"] is None + assert change_record["new_value"] == 20 + + def test_memory_type_distribution_structure(self): + """Test MemoryTypeDistribution structure with all fields.""" + distribution: MemoryTypeDistribution = { + "state": 10, + "action": 20, + "interaction": 30 + } + + assert distribution["state"] == 10 + assert distribution["action"] == 20 + assert distribution["interaction"] == 30 + + def test_memory_type_distribution_partial(self): + """Test MemoryTypeDistribution with partial fields (total=False).""" + # This should work because MemoryTypeDistribution has total=False + distribution: MemoryTypeDistribution = { + "state": 10, + "action": 20, + } + + assert distribution["state"] == 10 + assert distribution["action"] == 20 + assert "interaction" not in distribution + + def test_memory_statistics_structure(self): + """Test MemoryStatistics structure.""" + stats: MemoryStatistics = { + "total_memories": 100, + "stm_count": 20, + "im_count": 30, + "ltm_count": 50, + "memory_type_distribution": { + "state": 40, + "action": 30, + "interaction": 30 + }, + 
"last_maintenance_time": 1649879872.123, + "insert_count_since_maintenance": 5 + } + + assert stats["total_memories"] == 100 + assert stats["stm_count"] == 20 + assert stats["im_count"] == 30 + assert stats["ltm_count"] == 50 + assert stats["memory_type_distribution"]["state"] == 40 + assert stats["last_maintenance_time"] == 1649879872.123 + assert stats["insert_count_since_maintenance"] == 5 + + def test_memory_statistics_with_none_time(self): + """Test MemoryStatistics with None for last_maintenance_time.""" + stats: MemoryStatistics = { + "total_memories": 100, + "stm_count": 20, + "im_count": 30, + "ltm_count": 50, + "memory_type_distribution": { + "state": 40, + "action": 30, + "interaction": 30 + }, + "last_maintenance_time": None, + "insert_count_since_maintenance": 5 + } + + assert stats["total_memories"] == 100 + assert stats["last_maintenance_time"] is None + + def test_similarity_search_result_structure(self): + """Test SimilaritySearchResult structure.""" + search_result: SimilaritySearchResult = { + "memory_id": "mem_12345", + "agent_id": "agent_001", + "step_number": 42, + "timestamp": 1649879872.123, + "contents": {"observation": "User asked about weather"}, + "metadata": { + "importance_score": 0.8, + "memory_type": "interaction", + "creation_time": 1649879872.123, + "retrieval_count": 0, + "compression_level": 0, + "current_tier": "stm", + "last_access_time": 1649879872.123, + "checksum": "abc123" + }, + "embeddings": { + "full_vector": [0.1, 0.2, 0.3, 0.4] + }, + "_similarity_score": 0.95 + } + + assert search_result["memory_id"] == "mem_12345" + assert search_result["contents"]["observation"] == "User asked about weather" + assert search_result["_similarity_score"] == 0.95 + + def test_config_update_structure(self): + """Test ConfigUpdate structure (Dict[str, Any]).""" + config_update: ConfigUpdate = { + "stm_config": {"memory_limit": 1000}, + "im_config": {"memory_limit": 5000}, + "enable_memory_hooks": True + } + + assert 
config_update["stm_config"]["memory_limit"] == 1000 + assert config_update["im_config"]["memory_limit"] == 5000 + assert config_update["enable_memory_hooks"] is True + + def test_query_result_structure(self): + """Test QueryResult structure.""" + query_result: QueryResult = { + "memory_id": "mem_12345", + "agent_id": "agent_001", + "step_number": 42, + "timestamp": 1649879872.123, + "contents": {"observation": "User asked about weather"}, + "metadata": { + "importance_score": 0.8, + "memory_type": "interaction", + "creation_time": 1649879872.123, + "retrieval_count": 0, + "compression_level": 0, + "current_tier": "stm", + "last_access_time": 1649879872.123, + "checksum": "abc123" + } + } + + assert query_result["memory_id"] == "mem_12345" + assert query_result["agent_id"] == "agent_001" + assert query_result["step_number"] == 42 + assert query_result["contents"]["observation"] == "User asked about weather" + assert query_result["metadata"]["importance_score"] == 0.8 + + +class TestMemoryStoreProtocol: + """Test suite for MemoryStore protocol implementation.""" + + class MockMemoryStore: + """Mock implementation of MemoryStore protocol for testing.""" + + def __init__(self): + self.memories = {} + + def get(self, memory_id: str) -> Optional[MemoryEntry]: + """Get a memory by ID.""" + return self.memories.get(memory_id) + + def get_recent(self, count: int, memory_type: Optional[str] = None) -> List[MemoryEntry]: + """Get recent memories.""" + return [] + + def get_by_step_range( + self, start_step: int, end_step: int, memory_type: Optional[str] = None + ) -> List[MemoryEntry]: + """Get memories in a step range.""" + return [] + + def get_by_attributes( + self, attributes: Dict[str, Any], memory_type: Optional[str] = None + ) -> List[MemoryEntry]: + """Get memories matching attributes.""" + return [] + + def search_by_vector( + self, vector: List[float], k: int = 5, memory_type: Optional[str] = None + ) -> List[MemoryEntry]: + """Search memories by vector 
similarity.""" + return [] + + def search_by_content( + self, content_query: Dict[str, Any], k: int = 5 + ) -> List[MemoryEntry]: + """Search memories by content.""" + return [] + + def contains(self, memory_id: str) -> bool: + """Check if a memory exists.""" + return memory_id in self.memories + + def update(self, memory: MemoryEntry) -> bool: + """Update a memory.""" + self.memories[memory["memory_id"]] = memory + return True + + def count(self) -> int: + """Count memories.""" + return len(self.memories) + + def count_by_type(self) -> Dict[str, int]: + """Count memories by type.""" + return {"state": 0, "action": 0, "interaction": 0} + + def clear(self) -> bool: + """Clear all memories.""" + self.memories.clear() + return True + + def test_memory_store_protocol_compliance(self): + """Test that MockMemoryStore complies with MemoryStore protocol.""" + store = self.MockMemoryStore() + + # This will fail if the MockMemoryStore doesn't properly implement MemoryStore + assert isinstance(store, MemoryStore) + + # Test basic functionality + test_memory: MemoryEntry = { + "memory_id": "test-mem", + "agent_id": "test-agent", + "step_number": 1, + "timestamp": 1000.0, + "contents": {"data": "test content"}, + "metadata": { + "importance_score": 0.5, + "memory_type": "state", + "creation_time": 1000.0, + "compression_level": 0, + "current_tier": "stm", + "last_access_time": 1000.0, + "retrieval_count": 0, + "checksum": "test" + }, + "embeddings": None + } + + assert store.update(test_memory) is True + assert store.contains("test-mem") is True + assert store.count() == 1 + retrieved = store.get("test-mem") + assert retrieved is not None + assert retrieved["memory_id"] == "test-mem" + assert store.clear() is True + assert store.count() == 0 \ No newline at end of file From c65d1b2f98705c96ec635c9920adcd886e1c5974 Mon Sep 17 00:00:00 2001 From: Chris Mangum Date: Fri, 2 May 2025 16:44:17 -0700 Subject: [PATCH 3/3] Refactor attribute validation and performance testing scripts 
This commit removes the old `validate_attribute.py` script and introduces a new `validate_attribute.py` for improved attribute search validation. Additionally, a new `performance_test_attribute.py` script is added to conduct performance tests on the attribute search strategy, evaluating speed, resource usage, and scalability under various load conditions. These changes enhance the testing framework and ensure more robust validation and performance metrics for the attribute search functionality. --- .../performance_test_attribute.py | 2 +- .../search/attribute/validate_attribute.py | 822 ++++++++++++++++++ validation/search/validate_attribute.py | 725 --------------- 3 files changed, 823 insertions(+), 726 deletions(-) rename validation/search/{ => attribute}/performance_test_attribute.py (99%) create mode 100644 validation/search/attribute/validate_attribute.py delete mode 100644 validation/search/validate_attribute.py diff --git a/validation/search/performance_test_attribute.py b/validation/search/attribute/performance_test_attribute.py similarity index 99% rename from validation/search/performance_test_attribute.py rename to validation/search/attribute/performance_test_attribute.py index cccec15..70c4cbc 100644 --- a/validation/search/performance_test_attribute.py +++ b/validation/search/attribute/performance_test_attribute.py @@ -21,7 +21,7 @@ import logging # Add project root to path -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) # Suppress checksum validation warnings - important to set this before imports logging.getLogger('memory.storage.redis_im').setLevel(logging.ERROR) diff --git a/validation/search/attribute/validate_attribute.py b/validation/search/attribute/validate_attribute.py new file mode 100644 index 0000000..4aa41e0 --- /dev/null +++ b/validation/search/attribute/validate_attribute.py @@ -0,0 +1,822 @@ +""" +Validation 
script for the Attribute Search Strategy. + +This script loads a predefined memory system and tests various scenarios +of attribute-based searching to verify the strategy works correctly. +""" + +import os +import sys +from typing import Any, Dict, List, Set + +# Add the project root to the path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) + +from validation.demo_utils import ( + create_memory_system, + log_print, + pretty_print_memories, + setup_logging, +) +from memory.search.strategies.attribute import AttributeSearchStrategy + +# Constants +AGENT_ID = "test-agent-attribute-search" +MEMORY_SAMPLE = os.path.join("memory_samples", "attribute_validation_memory.json") + +# Dictionary mapping memory IDs to their checksums for easier reference +MEMORY_CHECKSUMS = { + "meeting-123456-1": "0eb0f81d07276f08e05351a604d3c994564fedee3a93329e318186da517a3c56", + "meeting-123456-3": "f6ab36930459e74a52fdf21fb96a84241ccae3f6987365a21f9a17d84c5dae1e", + "meeting-123456-6": "ffa0ee60ebaec5574358a02d1857823e948519244e366757235bf755c888a87f", + "meeting-123456-9": "9214ebc2d11877665b32771bd3c080414d9519b435ec3f6c19cc5f337bb0ba90", + "meeting-123456-11": "ad2e7c963751beb1ebc1c9b84ecb09ec3ccdef14f276cd14bbebad12d0f9b0df", + "task-123456-2": "e0f7deb6929a17f65f56e5b03e16067c8bb65649fd2745f842aca7af701c9cac", + "task-123456-7": "1d23b6683acd8c3863cb2f2010fe3df2c3e69a2d94c7c4757a291d4872066cfd", + "task-123456-10": "f3c73b06d6399ed30ea9d9ad7c711a86dd58154809cc05497f8955425ec6dc67", + "note-123456-4": "1e9e265e75c2ef678dfd0de0ab5c801f845daa48a90a48bb02ee85148ccc3470", + "note-123456-8": "169c452e368fd62e3c0cf5ce7731769ed46ab6ae73e5048e0c3a7caaa66fba46", + "contact-123456-5": "496d09718bbc8ae669dffdd782ed5b849fdbb1a57e3f7d07e61807b10e650092", +} + + +def get_checksums_for_memory_ids(memory_ids: List[str]) -> Set[str]: + """Helper function to get checksums from memory IDs.""" + return { + MEMORY_CHECKSUMS[memory_id] + for memory_id in 
memory_ids + if memory_id in MEMORY_CHECKSUMS + } + + +def run_test( + search_strategy: AttributeSearchStrategy, + test_name: str, + query: Any, + agent_id: str, + limit: int = 10, + metadata_filter: Dict[str, Any] = None, + tier: str = None, + content_fields: List[str] = None, + metadata_fields: List[str] = None, + match_all: bool = False, + case_sensitive: bool = False, + use_regex: bool = False, + scoring_method: str = None, + expected_checksums: Set[str] = None, + expected_memory_ids: List[str] = None, +) -> Dict[str, Any]: + """Run a test case and return the results.""" + log_print(logger, f"\n=== Test: {test_name} ===") + + if isinstance(query, dict): + log_print(logger, f"Query (dict): {query}") + else: + log_print(logger, f"Query: '{query}'") + + log_print( + logger, + f"Match All: {match_all}, Case Sensitive: {case_sensitive}, Use Regex: {use_regex}", + ) + + if metadata_filter: + log_print(logger, f"Metadata Filter: {metadata_filter}") + + if tier: + log_print(logger, f"Tier: {tier}") + + if content_fields: + log_print(logger, f"Content Fields: {content_fields}") + + if metadata_fields: + log_print(logger, f"Metadata Fields: {metadata_fields}") + + if scoring_method: + log_print(logger, f"Scoring Method: {scoring_method}") + + # If expected_memory_ids is provided, convert to checksums + if expected_memory_ids and not expected_checksums: + expected_checksums = get_checksums_for_memory_ids(expected_memory_ids) + log_print( + logger, + f"Expecting {len(expected_checksums)} memories from specified memory IDs", + ) + + results = search_strategy.search( + query=query, + agent_id=agent_id, + limit=limit, + metadata_filter=metadata_filter, + tier=tier, + content_fields=content_fields, + metadata_fields=metadata_fields, + match_all=match_all, + case_sensitive=case_sensitive, + use_regex=use_regex, + scoring_method=scoring_method, + ) + + log_print(logger, f"Found {len(results)} results") + pretty_print_memories(results, f"Results for {test_name}", logger) + + # If 
we have scoring method, print the scores for comparison + if scoring_method and results: + log_print(logger, f"\nScores using {scoring_method} scoring method:") + for idx, result in enumerate(results[:5]): # Show scores for top 5 results + score = result.get("metadata", {}).get("attribute_score", 0) + memory_id = result.get("memory_id", result.get("id", f"Result {idx+1}")) + log_print(logger, f" {memory_id}: {score:.4f}") + + # Track test status + test_passed = True + + # Validate against expected checksums if provided + if expected_checksums: + result_checksums = { + result.get("metadata", {}).get("checksum", "") for result in results + } + missing_checksums = expected_checksums - result_checksums + unexpected_checksums = result_checksums - expected_checksums + + log_print(logger, f"\nValidation Results:") + if not missing_checksums and not unexpected_checksums: + log_print(logger, "All expected memories found. No unexpected memories.") + else: + if missing_checksums: + log_print(logger, f"Missing expected memories: {missing_checksums}") + test_passed = False + if unexpected_checksums: + log_print(logger, f"Found unexpected memories: {unexpected_checksums}") + test_passed = False + + log_print( + logger, + f"Expected: {len(expected_checksums)}, Found: {len(result_checksums)}, " + f"Missing: {len(missing_checksums)}, Unexpected: {len(unexpected_checksums)}", + ) + + return { + "results": results, + "test_name": test_name, + "passed": test_passed, + "has_validation": expected_checksums is not None + or expected_memory_ids is not None, + } + + +def validate_attribute_search(): + """Run validation tests for the attribute search strategy.""" + # Setup memory system + memory_system = create_memory_system( + logging_level="INFO", + memory_file=MEMORY_SAMPLE, + use_mock_redis=True, + ) + + # If memory system failed to load, exit + if not memory_system: + log_print(logger, "Failed to load memory system") + return + + # Setup search strategy + agent = 
memory_system.get_memory_agent(AGENT_ID) + search_strategy = AttributeSearchStrategy( + agent.stm_store, agent.im_store, agent.ltm_store + ) + + # Print strategy info + log_print(logger, f"Testing search strategy: {search_strategy.name()}") + log_print(logger, f"Description: {search_strategy.description()}") + + # Track test results + test_results = [] + + # Test 1: Basic content search + test_results.append( + run_test( + search_strategy, + "Basic Content Search", + "meeting", + AGENT_ID, + content_fields=["content.content"], + metadata_filter={"content.metadata.type": "meeting"}, + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-3", + "meeting-123456-6", + "meeting-123456-9", + "meeting-123456-11", + ], + ) + ) + + # Test 2: Case sensitive search + test_results.append( + run_test( + search_strategy, + "Case Sensitive Search", + "Meeting", + AGENT_ID, + case_sensitive=True, + content_fields=["content.content"], # Only search in content + metadata_filter={ + "content.metadata.type": "meeting" + }, # Only get meeting-type memories + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-3", + "meeting-123456-6", + ], + ) + ) + + # Test 3: Search by metadata type + test_results.append( + run_test( + search_strategy, + "Search by Metadata Type", + {"metadata": {"type": "meeting"}}, + AGENT_ID, + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-3", + "meeting-123456-6", + "meeting-123456-9", + "meeting-123456-11", + ], + ) + ) + + # Test 4: Search with match_all + test_results.append( + run_test( + search_strategy, + "Search with Match All", + { + "content": "meeting", + "metadata": {"type": "meeting", "importance": "high"}, + }, + AGENT_ID, + match_all=True, + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-6", + "meeting-123456-9", + "meeting-123456-11", + ], + ) + ) + + # Test 5: Search specific memory tier + test_results.append( + run_test( + search_strategy, + "Search in STM Tier Only", + "meeting", + AGENT_ID, + 
tier="stm", + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-3", + ], + ) + ) + + # Test 6: Search with regex + test_results.append( + run_test( + search_strategy, + "Regex Search", + "secur.*patch", + AGENT_ID, + use_regex=True, + expected_memory_ids=[ + "note-123456-4", + ], + ) + ) + + # Test 7: Search with metadata filter + test_results.append( + run_test( + search_strategy, + "Search with Metadata Filter", + "meeting", + AGENT_ID, + metadata_filter={"content.metadata.importance": "high"}, + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-6", + "meeting-123456-9", + "meeting-123456-11", + ], + ) + ) + + # Test 8: Search in specific content fields + test_results.append( + run_test( + search_strategy, + "Search in Specific Content Fields", + "project", + AGENT_ID, + content_fields=["content.content"], + expected_memory_ids=[ + "meeting-123456-1", + "contact-123456-5", + ], + ) + ) + + # Test 9: Search in specific metadata fields + test_results.append( + run_test( + search_strategy, + "Search in Specific Metadata Fields", + "project", + AGENT_ID, + metadata_fields=["content.metadata.tags"], + expected_memory_ids=[ + "meeting-123456-1", + "contact-123456-5", + ], + ) + ) + + # Test 10: Search with complex query and filters + test_results.append( + run_test( + search_strategy, + "Complex Search", + {"content": "security", "metadata": {"importance": "high"}}, + AGENT_ID, + metadata_filter={"content.metadata.source": "email"}, + match_all=True, + expected_memory_ids=[ + "note-123456-4", + ], + ) + ) + + # Test 11: Empty query handling - string + test_results.append( + run_test( + search_strategy, + "Empty String Query", + "", + AGENT_ID, + expected_memory_ids=[], + ) + ) + + # Test 12: Empty query handling - dict + test_results.append( + run_test( + search_strategy, + "Empty Dict Query", + {}, + AGENT_ID, + expected_memory_ids=[], + ) + ) + + # Test 13: Numeric value search + test_results.append( + run_test( + search_strategy, + "Numeric 
Value Search", + 42, + AGENT_ID, + expected_memory_ids=[], + ) + ) + + # Test 14: Boolean value search + test_results.append( + run_test( + search_strategy, + "Boolean Value Search", + {"metadata": {"completed": True}}, + AGENT_ID, + expected_memory_ids=[], + ) + ) + + # Test 15: Type conversion - searching string with numeric + test_results.append( + run_test( + search_strategy, + "Type Conversion - String Field with Numeric", + 123, + AGENT_ID, + content_fields=["content.content"], + expected_memory_ids=[], + ) + ) + + # Test 16: Invalid regex pattern handling + test_results.append( + run_test( + search_strategy, + "Invalid Regex Pattern", + "[unclosed-bracket", + AGENT_ID, + use_regex=True, + expected_memory_ids=[], + ) + ) + + # Test 17: Array field partial matching + test_results.append( + run_test( + search_strategy, + "Array Field Partial Matching", + "dev", + AGENT_ID, + metadata_fields=["content.metadata.tags"], + expected_memory_ids=[ + "meeting-123456-3", + "task-123456-10", + ], + ) + ) + + # Test 18: Special characters in search + test_results.append( + run_test( + search_strategy, + "Special Characters in Search", + "meeting+notes", + AGENT_ID, + expected_memory_ids=[], + ) + ) + + # Test 19: Multi-tier search + test_results.append( + run_test( + search_strategy, + "Multi-Tier Search", + "important", + AGENT_ID, + expected_memory_ids=[], + ) + ) + + # Test 20: Large result set limiting + test_results.append( + run_test( + search_strategy, + "Large Result Set Limiting", + "a", # Common letter to match many memories + AGENT_ID, + limit=3, # Only show top 3 results + expected_memory_ids=[ + "meeting-123456-1", # Contains "about" and "allocation" - shorter content with multiple 'a's + "meeting-123456-6", # Contains "about" and "roadmap" - shorter content with multiple 'a's + "meeting-123456-3", # Contains "authentication" and "team" - longer content with fewer 'a's + ], + ) + ) + + # ===== New tests for scoring methods ===== + log_print(logger, "\n=== 
SCORING METHOD COMPARISON TESTS ===") + + # Test 21: Comparing scoring methods on the same query + test_query = "meeting" + log_print(logger, f"\nComparing scoring methods for query: '{test_query}'") + + # Try each scoring method and collect results + test_results.append( + run_test( + search_strategy, + "Default Length Ratio Scoring", + test_query, + AGENT_ID, + limit=5, + scoring_method="length_ratio", + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-3", + "meeting-123456-6", + "meeting-123456-9", + "meeting-123456-11", + ], + ) + ) + + test_results.append( + run_test( + search_strategy, + "Term Frequency Scoring", + test_query, + AGENT_ID, + limit=5, + scoring_method="term_frequency", + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-3", + "meeting-123456-6", + "meeting-123456-9", + "meeting-123456-11", + ], + ) + ) + + test_results.append( + run_test( + search_strategy, + "BM25 Scoring", + test_query, + AGENT_ID, + limit=5, + scoring_method="bm25", + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-3", + "meeting-123456-6", + "meeting-123456-9", + "meeting-123456-11", + ], + ) + ) + + test_results.append( + run_test( + search_strategy, + "Binary Scoring", + test_query, + AGENT_ID, + limit=5, + scoring_method="binary", + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-3", + "meeting-123456-6", + "meeting-123456-9", + "meeting-123456-11", + ], + ) + ) + + # Test 22: Testing scoring on a document with repeated terms + test_with_repetition_query = "security" # Look for security-related memories + log_print( + logger, + f"\nComparing scoring methods for query with potential term repetition: '{test_with_repetition_query}'", + ) + + # Test with default length ratio scoring + test_results.append( + run_test( + search_strategy, + "Default Scoring with Term Repetition", + test_with_repetition_query, + AGENT_ID, + limit=5, + expected_memory_ids=[ + "note-123456-4", + "note-123456-8", + "meeting-123456-11", + ], + ) + 
) + + # Test with term frequency scoring - should favor documents with more occurrences + test_results.append( + run_test( + search_strategy, + "Term Frequency with Term Repetition", + test_with_repetition_query, + AGENT_ID, + limit=5, + scoring_method="term_frequency", + expected_memory_ids=[ + "note-123456-4", + "note-123456-8", + "meeting-123456-11", + ], + ) + ) + + # Test with BM25 scoring - balances term frequency and document length + test_results.append( + run_test( + search_strategy, + "BM25 with Term Repetition", + test_with_repetition_query, + AGENT_ID, + limit=5, + scoring_method="bm25", + expected_memory_ids=[ + "note-123456-4", + "note-123456-8", + "meeting-123456-11", + ], + ) + ) + + # Test 23: Testing with a specialized search strategy for each method + log_print( + logger, + "\nTesting with dedicated search strategy instances for each scoring method", + ) + + # Create specialized strategy instances + term_freq_strategy = AttributeSearchStrategy( + agent.stm_store, + agent.im_store, + agent.ltm_store, + scoring_method="term_frequency", + ) + + bm25_strategy = AttributeSearchStrategy( + agent.stm_store, + agent.im_store, + agent.ltm_store, + scoring_method="bm25", + ) + + # Run test with specialized strategies + test_results.append( + run_test( + term_freq_strategy, + "Using Term Frequency Strategy Instance", + "project", + AGENT_ID, + limit=5, + expected_memory_ids=[ + "meeting-123456-1", + "contact-123456-5", + ], + ) + ) + + test_results.append( + run_test( + bm25_strategy, + "Using BM25 Strategy Instance", + "project", + AGENT_ID, + limit=5, + expected_memory_ids=[ + "meeting-123456-1", + "contact-123456-5", + ], + ) + ) + + # Test 24: Testing with a long document vs short document comparison + # Change from "detailed" (no matches) to "authentication system" (appears in memories of different lengths) + long_doc_query = "authentication system" + log_print( + logger, + f"\nComparing scoring methods for long vs short document query: 
'{long_doc_query}'", + ) + + # Compare each scoring method + test_results.append( + run_test( + search_strategy, + "Length Ratio for Long Documents", + long_doc_query, + AGENT_ID, + limit=5, + scoring_method="length_ratio", + expected_memory_ids=[ + "meeting-123456-3", + "task-123456-7", + "task-123456-10", + ], + ) + ) + + test_results.append( + run_test( + search_strategy, + "Term Frequency for Long Documents", + long_doc_query, + AGENT_ID, + limit=5, + scoring_method="term_frequency", + expected_memory_ids=[ + "meeting-123456-3", + "task-123456-7", + "task-123456-10", + ], + ) + ) + + test_results.append( + run_test( + search_strategy, + "BM25 for Long Documents", + long_doc_query, + AGENT_ID, + limit=5, + scoring_method="bm25", + expected_memory_ids=[ + "meeting-123456-3", + "task-123456-7", + "task-123456-10", + ], + ) + ) + + # Test 25: Testing with a query that matches varying document length and context + varying_length_query = "documentation" + log_print( + logger, + f"\nComparing scoring methods for documents of varying lengths: '{varying_length_query}'", + ) + + test_results.append( + run_test( + search_strategy, + "Length Ratio for Documentation Query", + varying_length_query, + AGENT_ID, + limit=5, + scoring_method="length_ratio", + expected_memory_ids=[ + "task-123456-2", + "task-123456-7", + ], + ) + ) + + test_results.append( + run_test( + search_strategy, + "Term Frequency for Documentation Query", + varying_length_query, + AGENT_ID, + limit=5, + scoring_method="term_frequency", + expected_memory_ids=[ + "task-123456-2", + "task-123456-7", + ], + ) + ) + + test_results.append( + run_test( + search_strategy, + "BM25 for Documentation Query", + varying_length_query, + AGENT_ID, + limit=5, + scoring_method="bm25", + expected_memory_ids=[ + "task-123456-2", + "task-123456-7", + ], + ) + ) + + # Display validation summary + log_print(logger, "\n\n=== VALIDATION SUMMARY ===") + log_print(logger, "-" * 80) + log_print( + logger, + "| {:<40} | {:<20} | 
{:<20} |".format( + "Test Name", "Status", "Validation Status" + ), + ) + log_print(logger, "-" * 80) + + for result in test_results: + status = "PASS" if result["passed"] else "FAIL" + validation_status = status if result["has_validation"] else "N/A" + log_print( + logger, + "| {:<40} | {:<20} | {:<20} |".format( + result["test_name"][:40], status, validation_status + ), + ) + + log_print(logger, "-" * 80) + + # Calculate overall statistics + validated_tests = [t for t in test_results if t["has_validation"]] + passed_tests = [t for t in validated_tests if t["passed"]] + + if validated_tests: + success_rate = len(passed_tests) / len(validated_tests) * 100 + log_print(logger, f"\nValidated Tests: {len(validated_tests)}") + log_print(logger, f"Passed Tests: {len(passed_tests)}") + log_print(logger, f"Failed Tests: {len(validated_tests) - len(passed_tests)}") + log_print(logger, f"Success Rate: {success_rate:.2f}%") + else: + log_print(logger, "\nNo tests with validation criteria were run.") + + +if __name__ == "__main__": + # Setup logging + logger = setup_logging("validate_attribute_search") + log_print(logger, "Starting Attribute Search Strategy Validation") + + validate_attribute_search() + + log_print(logger, "Validation Complete") diff --git a/validation/search/validate_attribute.py b/validation/search/validate_attribute.py deleted file mode 100644 index f6bf69d..0000000 --- a/validation/search/validate_attribute.py +++ /dev/null @@ -1,725 +0,0 @@ -""" -Validation script for the Attribute Search Strategy. - -This script loads a predefined memory system and tests various scenarios -of attribute-based searching to verify the strategy works correctly. 
-""" - -import os -import sys -from typing import Any, Dict, List, Set - -# Add the project root to the path -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) - -from validation.demo_utils import ( - create_memory_system, - log_print, - pretty_print_memories, - setup_logging, -) -from memory.search.strategies.attribute import AttributeSearchStrategy - -# Constants -AGENT_ID = "test-agent-attribute-search" -MEMORY_SAMPLE = "attribute_validation_memory.json" - -# Dictionary mapping memory IDs to their checksums for easier reference -MEMORY_CHECKSUMS = { - "meeting-123456-1": "0eb0f81d07276f08e05351a604d3c994564fedee3a93329e318186da517a3c56", - "meeting-123456-3": "f6ab36930459e74a52fdf21fb96a84241ccae3f6987365a21f9a17d84c5dae1e", - "meeting-123456-6": "ffa0ee60ebaec5574358a02d1857823e948519244e366757235bf755c888a87f", - "meeting-123456-9": "9214ebc2d11877665b32771bd3c080414d9519b435ec3f6c19cc5f337bb0ba90", - "meeting-123456-11": "ad2e7c963751beb1ebc1c9b84ecb09ec3ccdef14f276cd14bbebad12d0f9b0df", - "task-123456-2": "e0f7deb6929a17f65f56e5b03e16067c8bb65649fd2745f842aca7af701c9cac", - "task-123456-7": "1d23b6683acd8c3863cb2f2010fe3df2c3e69a2d94c7c4757a291d4872066cfd", - "task-123456-10": "f3c73b06d6399ed30ea9d9ad7c711a86dd58154809cc05497f8955425ec6dc67", - "note-123456-4": "1e9e265e75c2ef678dfd0de0ab5c801f845daa48a90a48bb02ee85148ccc3470", - "note-123456-8": "169c452e368fd62e3c0cf5ce7731769ed46ab6ae73e5048e0c3a7caaa66fba46", - "contact-123456-5": "496d09718bbc8ae669dffdd782ed5b849fdbb1a57e3f7d07e61807b10e650092", -} - - -def get_checksums_for_memory_ids(memory_ids: List[str]) -> Set[str]: - """Helper function to get checksums from memory IDs.""" - return { - MEMORY_CHECKSUMS[memory_id] - for memory_id in memory_ids - if memory_id in MEMORY_CHECKSUMS - } - - -def run_test( - search_strategy: AttributeSearchStrategy, - test_name: str, - query: Any, - agent_id: str, - limit: int = 10, - metadata_filter: Dict[str, Any] = None, - tier: 
str = None, - content_fields: List[str] = None, - metadata_fields: List[str] = None, - match_all: bool = False, - case_sensitive: bool = False, - use_regex: bool = False, - scoring_method: str = None, - expected_checksums: Set[str] = None, - expected_memory_ids: List[str] = None, -) -> Dict[str, Any]: - """Run a test case and return the results.""" - log_print(logger, f"\n=== Test: {test_name} ===") - - if isinstance(query, dict): - log_print(logger, f"Query (dict): {query}") - else: - log_print(logger, f"Query: '{query}'") - - log_print( - logger, - f"Match All: {match_all}, Case Sensitive: {case_sensitive}, Use Regex: {use_regex}", - ) - - if metadata_filter: - log_print(logger, f"Metadata Filter: {metadata_filter}") - - if tier: - log_print(logger, f"Tier: {tier}") - - if content_fields: - log_print(logger, f"Content Fields: {content_fields}") - - if metadata_fields: - log_print(logger, f"Metadata Fields: {metadata_fields}") - - if scoring_method: - log_print(logger, f"Scoring Method: {scoring_method}") - - # If expected_memory_ids is provided, convert to checksums - if expected_memory_ids and not expected_checksums: - expected_checksums = get_checksums_for_memory_ids(expected_memory_ids) - log_print( - logger, - f"Expecting {len(expected_checksums)} memories from specified memory IDs", - ) - - results = search_strategy.search( - query=query, - agent_id=agent_id, - limit=limit, - metadata_filter=metadata_filter, - tier=tier, - content_fields=content_fields, - metadata_fields=metadata_fields, - match_all=match_all, - case_sensitive=case_sensitive, - use_regex=use_regex, - scoring_method=scoring_method, - ) - - log_print(logger, f"Found {len(results)} results") - pretty_print_memories(results, f"Results for {test_name}", logger) - - # If we have scoring method, print the scores for comparison - if scoring_method and results: - log_print(logger, f"\nScores using {scoring_method} scoring method:") - for idx, result in enumerate(results[:5]): # Show scores for top 5 
results - score = result.get("metadata", {}).get("attribute_score", 0) - memory_id = result.get("memory_id", result.get("id", f"Result {idx+1}")) - log_print(logger, f" {memory_id}: {score:.4f}") - - # Track test status - test_passed = True - - # Validate against expected checksums if provided - if expected_checksums: - result_checksums = { - result.get("metadata", {}).get("checksum", "") for result in results - } - missing_checksums = expected_checksums - result_checksums - unexpected_checksums = result_checksums - expected_checksums - - log_print(logger, f"\nValidation Results:") - if not missing_checksums and not unexpected_checksums: - log_print(logger, "All expected memories found. No unexpected memories.") - else: - if missing_checksums: - log_print(logger, f"Missing expected memories: {missing_checksums}") - test_passed = False - if unexpected_checksums: - log_print(logger, f"Found unexpected memories: {unexpected_checksums}") - test_passed = False - - log_print( - logger, - f"Expected: {len(expected_checksums)}, Found: {len(result_checksums)}, " - f"Missing: {len(missing_checksums)}, Unexpected: {len(unexpected_checksums)}", - ) - - return { - "results": results, - "test_name": test_name, - "passed": test_passed, - "has_validation": expected_checksums is not None - } - - -def validate_attribute_search(): - """Run validation tests for the attribute search strategy.""" - # Setup memory system - memory_system = create_memory_system( - logging_level="INFO", - memory_file=MEMORY_SAMPLE, - use_mock_redis=True, - ) - - # If memory system failed to load, exit - if not memory_system: - log_print(logger, "Failed to load memory system") - return - - # Setup search strategy - agent = memory_system.get_memory_agent(AGENT_ID) - search_strategy = AttributeSearchStrategy( - agent.stm_store, agent.im_store, agent.ltm_store - ) - - # Print strategy info - log_print(logger, f"Testing search strategy: {search_strategy.name()}") - log_print(logger, f"Description: 
{search_strategy.description()}") - - # Track test results - test_results = [] - - # Test 1: Basic content search - test_results.append(run_test( - search_strategy, - "Basic Content Search", - "meeting", - AGENT_ID, - expected_memory_ids=[ - "meeting-123456-1", - "meeting-123456-3", - "meeting-123456-6", - "meeting-123456-9", - "meeting-123456-11", - ], - )) - - # Test 2: Case sensitive search - test_results.append(run_test( - search_strategy, - "Case Sensitive Search", - "Meeting", - AGENT_ID, - case_sensitive=True, - )) - - # Test 3: Search by metadata type - test_results.append(run_test( - search_strategy, - "Search by Metadata Type", - {"metadata": {"type": "meeting"}}, - AGENT_ID, - expected_memory_ids=[ - "meeting-123456-1", - "meeting-123456-3", - "meeting-123456-6", - "meeting-123456-9", - "meeting-123456-11", - ], - )) - - # Test 4: Search with match_all - test_results.append(run_test( - search_strategy, - "Search with Match All", - {"content": "meeting", "metadata": {"type": "meeting", "importance": "high"}}, - AGENT_ID, - match_all=True, - expected_memory_ids=[ - "meeting-123456-1", - "meeting-123456-6", - "meeting-123456-9", - "meeting-123456-11", - ], - )) - - # Test 5: Search specific memory tier - test_results.append(run_test( - search_strategy, - "Search in STM Tier Only", - "meeting", - AGENT_ID, - tier="stm", - expected_memory_ids=["meeting-123456-1", "meeting-123456-3"], - )) - - # Test 6: Search with regex - test_results.append(run_test( - search_strategy, - "Regex Search", - "secur.*patch", - AGENT_ID, - use_regex=True, - expected_memory_ids=["note-123456-4"], - )) - - # Test 7: Search with metadata filter - test_results.append(run_test( - search_strategy, - "Search with Metadata Filter", - "meeting", - AGENT_ID, - metadata_filter={"content.metadata.importance": "high"}, - expected_memory_ids=[ - "meeting-123456-1", - "meeting-123456-6", - "meeting-123456-9", - "meeting-123456-11", - ], - )) - - # Test 8: Search in specific content fields - 
test_results.append(run_test( - search_strategy, - "Search in Specific Content Fields", - "project", - AGENT_ID, - content_fields=["content.content"], - expected_memory_ids=[ - "meeting-123456-1", - "contact-123456-5", - ], - )) - - # Test 9: Search in specific metadata fields - test_results.append(run_test( - search_strategy, - "Search in Specific Metadata Fields", - "project", - AGENT_ID, - metadata_fields=["content.metadata.tags"], - expected_memory_ids=[ - "meeting-123456-1", - "contact-123456-5", - ], - )) - - # Test 10: Search with complex query and filters - test_results.append(run_test( - search_strategy, - "Complex Search", - {"content": "security", "metadata": {"importance": "high"}}, - AGENT_ID, - metadata_filter={"content.metadata.source": "email"}, - match_all=True, - expected_memory_ids=["note-123456-4"], - )) - - # Test 11: Empty query handling - string - test_results.append(run_test( - search_strategy, - "Empty String Query", - "", - AGENT_ID, - expected_memory_ids=[], - )) - - # Test 12: Empty query handling - dict - test_results.append(run_test( - search_strategy, - "Empty Dict Query", - {}, - AGENT_ID, - expected_memory_ids=[], - )) - - # Test 13: Numeric value search - test_results.append(run_test( - search_strategy, - "Numeric Value Search", - 42, - AGENT_ID, - expected_memory_ids=[], - )) - - # Test 14: Boolean value search - test_results.append(run_test( - search_strategy, - "Boolean Value Search", - {"metadata": {"completed": True}}, - AGENT_ID, - expected_memory_ids=[], - )) - - # Test 15: Type conversion - searching string with numeric - test_results.append(run_test( - search_strategy, - "Type Conversion - String Field with Numeric", - 123, - AGENT_ID, - content_fields=["content.content"], - expected_memory_ids=[], - )) - - # Test 16: Invalid regex pattern handling - test_results.append(run_test( - search_strategy, - "Invalid Regex Pattern", - "[unclosed-bracket", - AGENT_ID, - use_regex=True, - expected_memory_ids=[], - )) - - # Test 17: 
Array field partial matching - test_results.append(run_test( - search_strategy, - "Array Field Partial Matching", - "dev", - AGENT_ID, - metadata_fields=["content.metadata.tags"], - expected_memory_ids=[ - "meeting-123456-3", - "task-123456-10", - ], - )) - - # Test 18: Special characters in search - test_results.append(run_test( - search_strategy, - "Special Characters in Search", - "meeting+notes", - AGENT_ID, - expected_memory_ids=[], - )) - - # Test 19: Multi-tier search - test_results.append(run_test( - search_strategy, - "Multi-Tier Search", - "important", - AGENT_ID, - # No tier specified means searching all tiers - expected_memory_ids=[], - )) - - # Test 20: Large result set limiting - test_results.append(run_test( - search_strategy, - "Large Result Set Limiting", - "a", # Common letter to match many memories - AGENT_ID, - limit=3, # Only show top 3 results - expected_memory_ids=[ - "meeting-123456-1", - "task-123456-2", - "meeting-123456-3", - ], - )) - - # ===== New tests for scoring methods ===== - log_print(logger, "\n=== SCORING METHOD COMPARISON TESTS ===") - - # Test 21: Comparing scoring methods on the same query - test_query = "meeting" - log_print(logger, f"\nComparing scoring methods for query: '{test_query}'") - - # Try each scoring method and collect results - test_results.append(run_test( - search_strategy, - "Default Length Ratio Scoring", - test_query, - AGENT_ID, - limit=5, - scoring_method="length_ratio", - expected_memory_ids=[ - "meeting-123456-1", - "meeting-123456-3", - "meeting-123456-6", - "meeting-123456-9", - "meeting-123456-11", - ], - )) - - test_results.append(run_test( - search_strategy, - "Term Frequency Scoring", - test_query, - AGENT_ID, - limit=5, - scoring_method="term_frequency", - expected_memory_ids=[ - "meeting-123456-1", - "meeting-123456-3", - "meeting-123456-6", - "meeting-123456-9", - "meeting-123456-11", - ], - )) - - test_results.append(run_test( - search_strategy, - "BM25 Scoring", - test_query, - AGENT_ID, - 
limit=5, - scoring_method="bm25", - expected_memory_ids=[ - "meeting-123456-1", - "meeting-123456-3", - "meeting-123456-6", - "meeting-123456-9", - "meeting-123456-11", - ], - )) - - test_results.append(run_test( - search_strategy, - "Binary Scoring", - test_query, - AGENT_ID, - limit=5, - scoring_method="binary", - expected_memory_ids=[ - "meeting-123456-1", - "meeting-123456-3", - "meeting-123456-6", - "meeting-123456-9", - "meeting-123456-11", - ], - )) - - # Test 22: Testing scoring on a document with repeated terms - test_with_repetition_query = "security" # Look for security-related memories - log_print( - logger, - f"\nComparing scoring methods for query with potential term repetition: '{test_with_repetition_query}'", - ) - - # Test with default length ratio scoring - test_results.append(run_test( - search_strategy, - "Default Scoring with Term Repetition", - test_with_repetition_query, - AGENT_ID, - limit=5, - expected_memory_ids=[ - "note-123456-4", - "note-123456-8", - "meeting-123456-11", - ], - )) - - # Test with term frequency scoring - should favor documents with more occurrences - test_results.append(run_test( - search_strategy, - "Term Frequency with Term Repetition", - test_with_repetition_query, - AGENT_ID, - limit=5, - scoring_method="term_frequency", - expected_memory_ids=[ - "note-123456-4", - "note-123456-8", - "meeting-123456-11", - ], - )) - - # Test with BM25 scoring - balances term frequency and document length - test_results.append(run_test( - search_strategy, - "BM25 with Term Repetition", - test_with_repetition_query, - AGENT_ID, - limit=5, - scoring_method="bm25", - expected_memory_ids=[ - "note-123456-4", - "note-123456-8", - "meeting-123456-11", - ], - )) - - # Test 23: Testing with a specialized search strategy for each method - log_print( - logger, - "\nTesting with dedicated search strategy instances for each scoring method", - ) - - # Create specialized strategy instances - term_freq_strategy = AttributeSearchStrategy( - 
agent.stm_store, - agent.im_store, - agent.ltm_store, - scoring_method="term_frequency", - ) - - bm25_strategy = AttributeSearchStrategy( - agent.stm_store, - agent.im_store, - agent.ltm_store, - scoring_method="bm25", - ) - - # Run test with specialized strategies - test_results.append(run_test( - term_freq_strategy, - "Using Term Frequency Strategy Instance", - "project", - AGENT_ID, - limit=5, - expected_memory_ids=[ - "meeting-123456-1", - "contact-123456-5", - ], - )) - - test_results.append(run_test( - bm25_strategy, - "Using BM25 Strategy Instance", - "project", - AGENT_ID, - limit=5, - expected_memory_ids=[ - "meeting-123456-1", - "contact-123456-5", - ], - )) - - # Test 24: Testing with a long document vs short document comparison - # Change from "detailed" (no matches) to "authentication system" (appears in memories of different lengths) - long_doc_query = "authentication system" - log_print( - logger, - f"\nComparing scoring methods for long vs short document query: '{long_doc_query}'", - ) - - # Compare each scoring method - test_results.append(run_test( - search_strategy, - "Length Ratio for Long Documents", - long_doc_query, - AGENT_ID, - limit=5, - scoring_method="length_ratio", - expected_memory_ids=[ - "meeting-123456-3", - "task-123456-7", - "task-123456-10", - ], - )) - - test_results.append(run_test( - search_strategy, - "Term Frequency for Long Documents", - long_doc_query, - AGENT_ID, - limit=5, - scoring_method="term_frequency", - expected_memory_ids=[ - "meeting-123456-3", - "task-123456-7", - "task-123456-10", - ], - )) - - test_results.append(run_test( - search_strategy, - "BM25 for Long Documents", - long_doc_query, - AGENT_ID, - limit=5, - scoring_method="bm25", - expected_memory_ids=[ - "meeting-123456-3", - "task-123456-7", - "task-123456-10", - ], - )) - - # Test 25: Testing with a query that matches varying document length and context - varying_length_query = "documentation" - log_print( - logger, - f"\nComparing scoring methods for 
documents of varying lengths: '{varying_length_query}'", - ) - - test_results.append(run_test( - search_strategy, - "Length Ratio for Documentation Query", - varying_length_query, - AGENT_ID, - limit=5, - scoring_method="length_ratio", - expected_memory_ids=[ - "task-123456-2", - "task-123456-7", - ], - )) - - test_results.append(run_test( - search_strategy, - "Term Frequency for Documentation Query", - varying_length_query, - AGENT_ID, - limit=5, - scoring_method="term_frequency", - expected_memory_ids=[ - "task-123456-2", - "task-123456-7", - ], - )) - - test_results.append(run_test( - search_strategy, - "BM25 for Documentation Query", - varying_length_query, - AGENT_ID, - limit=5, - scoring_method="bm25", - expected_memory_ids=[ - "task-123456-2", - "task-123456-7", - ], - )) - - # Display validation summary - log_print(logger, "\n\n=== VALIDATION SUMMARY ===") - log_print(logger, "-" * 80) - log_print(logger, "| {:<40} | {:<20} | {:<20} |".format("Test Name", "Status", "Validation Status")) - log_print(logger, "-" * 80) - - for result in test_results: - status = "PASS" if result["passed"] else "FAIL" - validation_status = status if result["has_validation"] else "N/A" - log_print(logger, "| {:<40} | {:<20} | {:<20} |".format( - result["test_name"][:40], - status, - validation_status - )) - - log_print(logger, "-" * 80) - - # Calculate overall statistics - validated_tests = [t for t in test_results if t["has_validation"]] - passed_tests = [t for t in validated_tests if t["passed"]] - - if validated_tests: - success_rate = len(passed_tests) / len(validated_tests) * 100 - log_print(logger, f"\nValidated Tests: {len(validated_tests)}") - log_print(logger, f"Passed Tests: {len(passed_tests)}") - log_print(logger, f"Failed Tests: {len(validated_tests) - len(passed_tests)}") - log_print(logger, f"Success Rate: {success_rate:.2f}%") - else: - log_print(logger, "\nNo tests with validation criteria were run.") - - -if __name__ == "__main__": - # Setup logging - logger = 
setup_logging("validate_attribute_search") - log_print(logger, "Starting Attribute Search Strategy Validation") - - validate_attribute_search() - - log_print(logger, "Validation Complete")