diff --git a/memory/api/README.md b/memory/api/README.md new file mode 100644 index 0000000..3414e89 --- /dev/null +++ b/memory/api/README.md @@ -0,0 +1,185 @@ +# Agent Memory API + +The `memory/api` module provides a comprehensive interface for integrating the Agent Memory System with AI agents, enabling efficient storage, retrieval, and utilization of memories to support context-aware agent reasoning and behavior. + +## Overview + +The Agent Memory System implements a tiered memory architecture inspired by human cognitive systems: + +- **Short-Term Memory (STM)**: Recent, high-fidelity memories with detailed information +- **Intermediate Memory (IM)**: Medium-term memories with moderate compression +- **Long-Term Memory (LTM)**: Persistent, compressed memories retaining core information + +This architecture allows agents to efficiently manage memories with different levels of detail and importance across varying time horizons. + +## Module Components + +### Main API Class + +[`memory_api.py`](memory_api.py) - Provides the primary interface class `AgentMemoryAPI` for interacting with the memory system: +- Store agent states, actions, and interactions +- Retrieve memories by various criteria (ID, time range, attributes) +- Perform semantic search across memory tiers +- Manage memory lifecycle and maintenance + +### Agent Integration + +[`hooks.py`](hooks.py) - Offers decorators and utility functions for automatic memory integration: +- `install_memory_hooks`: Class decorator to add memory capabilities to agent classes +- `with_memory`: Instance decorator for adding memory to existing agent instances +- `BaseAgent`: Minimal interface with standard lifecycle methods for memory-aware agents + +### Data Models + +[`models.py`](models.py) - Defines structured representations of agent data: +- `AgentState`: Standardized representation of an agent's state +- `ActionData`: Record of an agent action with associated states and metrics +- `ActionResult`: Lightweight result of an action execution + +### Type Definitions + +[`types.py`](types.py) - Establishes core type definitions for the memory system: +- Memory entry structures (metadata, embeddings, content) +- Memory tiers and filtering types +- Statistics and query result types +- Protocol definitions for memory stores + +## Getting Started + +### Basic Usage + +```python +from memory.api import AgentMemoryAPI + +# Initialize the memory API +memory_api = AgentMemoryAPI() + +# Store an agent state +state_data = { + "agent_id": "agent-001", + "step_number": 42, + "content": { + "observation": "User asked about weather", + "thought": "I should check the forecast" + } +} +memory_id = memory_api.store_agent_state("agent-001", state_data, step_number=42) + +# Retrieve similar memories +query = "weather forecast" +similar_memories = memory_api.search_by_content("agent-001", query, k=3) + +# Use memories to inform agent's response +for memory in similar_memories: + print(f"Related memory: {memory['contents']}") +``` + +### Automatic Memory Integration + +```python +from memory.api import install_memory_hooks, BaseAgent + +@install_memory_hooks +class MyAgent(BaseAgent): + def __init__(self, config=None, agent_id=None): + super().__init__(config, agent_id) + # Agent-specific initialization + + def act(self, observation): + # Memory hooks automatically capture state before this method + self.step_number += 1 + action_result = self._process(observation) + # Memory hooks automatically capture state after this method + return action_result + + def 
get_state(self):
+        # Return current agent state
+        state = super().get_state()
+        state.extra_data["custom_field"] = self.some_internal_state
+        return state
+
+# Create an agent with memory enabled
+agent = MyAgent(agent_id="agent-001")
+
+# Use the agent normally - memories are created automatically
+result = agent.act({"user_input": "What's the weather today?"})
+```
+
+## Advanced Features
+
+### Memory Maintenance
+
+```python
+# Run memory maintenance to consolidate and optimize memories
+memory_api.force_memory_maintenance("agent-001")
+
+# Get memory statistics
+stats = memory_api.get_memory_statistics("agent-001")
+print(f"Total memories: {stats['total_memories']}")
+print(f"STM: {stats['stm_count']}, IM: {stats['im_count']}, LTM: {stats['ltm_count']}")
+```
+
+### Memory Search
+
+```python
+# Search by content similarity
+similar_memories = memory_api.search_by_content(
+    agent_id="agent-001",
+    content_query="user asked about calendar appointments",
+    k=5
+)
+
+# Retrieve memories by time range
+recent_memories = memory_api.retrieve_by_time_range(
+    agent_id="agent-001",
+    start_step=100,
+    end_step=120
+)
+
+# Retrieve memories by attributes
+filtered_memories = memory_api.retrieve_by_attributes(
+    agent_id="agent-001",
+    attributes={"action_type": "calendar_query"}
+)
+```
+
+### Memory Configuration
+
+```python
+from memory.api import AgentMemoryAPI
+from memory.config import MemoryConfig
+
+# Custom configuration
+config = MemoryConfig(
+    stm_config={"memory_limit": 1000},
+    im_config={"memory_limit": 10000},
+    ltm_config={"memory_limit": 100000}
+)
+
+# Initialize API with custom configuration
+memory_api = AgentMemoryAPI(config)
+
+# Update configuration
+memory_api.configure_memory_system({
+    "stm_config": {"memory_limit": 2000}
+})
+```
+
+## Error Handling
+
+```python
+from memory.api import AgentMemoryAPI
+from memory.api.memory_api import MemoryStoreException, MemoryRetrievalException, MemoryConfigException
+
+memory_api = AgentMemoryAPI()
+
+try:
+    memory = memory_api.retrieve_state_by_id("agent-001", "non_existent_id")
+except MemoryRetrievalException as e:
+    print(f"Memory retrieval error: {e}")
+
+try:
+    memories = memory_api.search_by_content("agent-001", "query", k=-1)
+except MemoryConfigException as e:
+    print(f"Configuration error: {e}")
+```
\ No newline at end of file
diff --git a/memory/search/strategies/attribute.py b/memory/search/strategies/attribute.py
index 04ff0af..9b8d854 100644
--- a/memory/search/strategies/attribute.py
+++ b/memory/search/strategies/attribute.py
@@ -33,7 +33,7 @@ def __init__(
         im_store: RedisIMStore,
         ltm_store: SQLiteLTMStore,
         scoring_method: str = "length_ratio",
-        skip_validation: bool = False,
+        skip_validation: bool = True,
     ):
         """Initialize the attribute search strategy.
 
@@ -146,7 +146,7 @@ def search(
         case_sensitive: bool = False,
         use_regex: bool = False,
         scoring_method: Optional[str] = None,
-        skip_validation: Optional[bool] = None,
+        skip_validation: Optional[bool] = True,
         **kwargs,
     ) -> List[Dict[str, Any]]:
         """Search for memories based on content and metadata attributes.
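Reviewer note on the `attribute.py` hunk above: flipping `skip_validation` to default `True` inverts the safety default — checksum validation becomes opt-in here and, in the storage hunks below, across every tier. Call sites that relied on validation running implicitly must now request it explicitly. A minimal sketch of the new opt-in pattern (the `strict_search` helper is hypothetical; `AttributeSearchStrategy.search` and its `skip_validation` keyword are real and shown in the hunk above):

```python
from typing import Any, Dict, List

from memory.search.strategies.attribute import AttributeSearchStrategy


def strict_search(
    strategy: AttributeSearchStrategy, agent_id: str, query: str
) -> List[Dict[str, Any]]:
    """Run an attribute search with checksum validation explicitly enabled.

    Under this PR's defaults, omitting skip_validation means validation is
    skipped; passing False restores the previous strict behavior.
    """
    return strategy.search(query=query, agent_id=agent_id, skip_validation=False)
```

The trade-off is speed over integrity checking by default, which suits the validation scripts later in this diff; callers that cannot tolerate stale or corrupted entries should pass `skip_validation=False` themselves.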
diff --git a/memory/storage/redis_im.py b/memory/storage/redis_im.py index f8a26fd..fe7f7d2 100644 --- a/memory/storage/redis_im.py +++ b/memory/storage/redis_im.py @@ -480,7 +480,7 @@ def _store_memory_entry(self, agent_id: str, memory_entry: MemoryEntry) -> bool: ) return False - def get(self, agent_id: str, memory_id: str, skip_validation: bool = False) -> Optional[MemoryEntry]: + def get(self, agent_id: str, memory_id: str, skip_validation: bool = True) -> Optional[MemoryEntry]: """Retrieve a memory entry by ID. Args: @@ -616,7 +616,7 @@ def _hash_to_memory_entry(self, hash_data: Dict[str, Any]) -> MemoryEntry: return memory_entry def get_by_timerange( - self, agent_id: str, start_time: float, end_time: float, limit: int = 100, skip_validation: bool = False + self, agent_id: str, start_time: float, end_time: float, limit: int = 100, skip_validation: bool = True ) -> List[MemoryEntry]: """Retrieve memories within a time range. @@ -713,7 +713,7 @@ def get_by_importance( min_importance: float = 0.0, max_importance: float = 1.0, limit: int = 100, - skip_validation: bool = False, + skip_validation: bool = True, ) -> List[MemoryEntry]: """Retrieve memories by importance score range. @@ -1341,7 +1341,7 @@ def get_size(self, agent_id: str) -> int: logger.exception("Error retrieving memory size for agent %s", agent_id) return 0 - def get_all(self, agent_id: str, limit: int = 1000, skip_validation: bool = False) -> List[MemoryEntry]: + def get_all(self, agent_id: str, limit: int = 1000, skip_validation: bool = True) -> List[MemoryEntry]: """Get all memories for an agent. Args: diff --git a/memory/storage/redis_stm.py b/memory/storage/redis_stm.py index 43aa4d6..fff35b3 100644 --- a/memory/storage/redis_stm.py +++ b/memory/storage/redis_stm.py @@ -273,7 +273,7 @@ def _store_memory_entry(self, agent_id: str, memory_entry: MemoryEntry) -> bool: ) return False - def get(self, agent_id: str, memory_id: str, skip_validation: bool = False) -> Optional[MemoryEntry]: + def get(self, agent_id: str, memory_id: str, skip_validation: bool = True) -> Optional[MemoryEntry]: """Retrieve a memory entry by ID. Args: @@ -378,7 +378,7 @@ def _update_access_metadata( ) def get_by_timerange( - self, agent_id: str, start_time: float, end_time: float, limit: int = 100, skip_validation: bool = False + self, agent_id: str, start_time: float, end_time: float, limit: int = 100, skip_validation: bool = True ) -> List[MemoryEntry]: """Retrieve memories within a time range. @@ -425,7 +425,7 @@ def get_by_importance( min_importance: float = 0.0, max_importance: float = 1.0, limit: int = 100, - skip_validation: bool = False, + skip_validation: bool = True, ) -> List[MemoryEntry]: """Retrieve memories by importance score. @@ -641,7 +641,7 @@ def get_size(self, agent_id: str) -> int: logger.error("Error calculating memory size: %s", e) return 0 - def get_all(self, agent_id: str, limit: int = 1000, skip_validation: bool = False) -> List[MemoryEntry]: + def get_all(self, agent_id: str, limit: int = 1000, skip_validation: bool = True) -> List[MemoryEntry]: """Get all memories for an agent. 
Args: diff --git a/memory/storage/sqlite_ltm.py b/memory/storage/sqlite_ltm.py index 81b588a..1f661fb 100644 --- a/memory/storage/sqlite_ltm.py +++ b/memory/storage/sqlite_ltm.py @@ -555,7 +555,7 @@ def store_batch(self, memory_entries: List[Dict[str, Any]]) -> bool: logger.error("Unexpected error storing batch of memories: %s", str(e)) return False - def get(self, memory_id: str, agent_id: str, skip_validation: bool = False) -> Optional[Dict[str, Any]]: + def get(self, memory_id: str, agent_id: str, skip_validation: bool = True) -> Optional[Dict[str, Any]]: """Retrieve a memory by ID. Args: @@ -711,7 +711,7 @@ def get_by_timerange( end_time: Union[float, int, str], agent_id: str = None, limit: int = 100, - skip_validation: bool = False, + skip_validation: bool = True, ) -> List[Dict[str, Any]]: """Retrieve memories within a time range. @@ -800,7 +800,7 @@ def get_by_importance( min_importance: float = 0.0, max_importance: float = 1.0, limit: int = 100, - skip_validation: bool = False, + skip_validation: bool = True, ) -> List[Dict[str, Any]]: """Retrieve memories by importance score. @@ -858,7 +858,7 @@ def get_most_similar( query_vector: List[float], top_k: int = 10, agent_id: str = None, - skip_validation: bool = False, + skip_validation: bool = True, ) -> List[Tuple[Dict[str, Any], float]]: """Retrieve memories most similar to the query vector. @@ -940,7 +940,7 @@ def search_similar( k: int = 5, memory_type: Optional[str] = None, agent_id: str = None, - skip_validation: bool = False, + skip_validation: bool = True, ) -> List[Dict[str, Any]]: """Search for memories with similar embeddings. @@ -1161,7 +1161,7 @@ def get_size(self) -> int: logger.error("Unexpected error calculating memory size: %s", str(e)) return 0 - def get_all(self, agent_id: str = None, limit: int = 1000, skip_validation: bool = False) -> List[Dict[str, Any]]: + def get_all(self, agent_id: str = None, limit: int = 1000, skip_validation: bool = True) -> List[Dict[str, Any]]: """Get all memories for the agent. Args: @@ -1342,7 +1342,7 @@ def search_by_step_range( start_step: int, end_step: int, memory_type: Optional[str] = None, - skip_validation: bool = False, + skip_validation: bool = True, ) -> List[Dict[str, Any]]: """Search for memories within a specific step range. 
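While these signatures change in lockstep, the diff makes a pre-existing inconsistency visible: the Redis stores expose `get(agent_id, memory_id, ...)` while `SQLiteLTMStore` exposes `get(memory_id, agent_id, ...)`, and since both arguments are plain strings, swapping them will simply miss the entry rather than raise a type error. A hedged sketch of a normalizing wrapper (the `get_memory` helper is illustrative, not part of this PR; the store signatures are taken from the hunks above):

```python
from typing import Any, Dict, Optional


def get_memory(
    store: Any, agent_id: str, memory_id: str, skip_validation: bool = True
) -> Optional[Dict[str, Any]]:
    """Fetch one memory entry from any tier, normalizing argument order.

    SQLiteLTMStore.get takes (memory_id, agent_id, ...); the Redis STM/IM
    stores take (agent_id, memory_id, ...). skip_validation defaults to True
    here to mirror the defaults introduced in this PR.
    """
    if type(store).__name__ == "SQLiteLTMStore":
        return store.get(memory_id, agent_id, skip_validation=skip_validation)
    return store.get(agent_id, memory_id, skip_validation=skip_validation)
```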
diff --git a/tests/api/test_models.py b/tests/api/test_models.py new file mode 100644 index 0000000..be8cbd4 --- /dev/null +++ b/tests/api/test_models.py @@ -0,0 +1,236 @@ +"""Unit tests for memory API models.""" + +import pytest +from memory.api.models import AgentState, ActionData, ActionResult + + +class TestAgentState: + """Test suite for the AgentState class.""" + + def test_initialization(self): + """Test basic initialization of AgentState.""" + state = AgentState( + agent_id="agent-1", + step_number=42, + health=0.8, + reward=10.5, + position_x=1.0, + position_y=2.0, + position_z=3.0, + resource_level=0.7, + extra_data={"inventory": ["sword", "shield"]} + ) + + assert state.agent_id == "agent-1" + assert state.step_number == 42 + assert state.health == 0.8 + assert state.reward == 10.5 + assert state.position_x == 1.0 + assert state.position_y == 2.0 + assert state.position_z == 3.0 + assert state.resource_level == 0.7 + assert state.extra_data == {"inventory": ["sword", "shield"]} + + def test_initialization_minimal(self): + """Test initialization with only required fields.""" + state = AgentState(agent_id="agent-1", step_number=10) + + assert state.agent_id == "agent-1" + assert state.step_number == 10 + assert state.health is None + assert state.reward is None + assert state.position_x is None + assert state.position_y is None + assert state.position_z is None + assert state.resource_level is None + assert state.extra_data == {} + + def test_as_dict_with_all_fields(self): + """Test as_dict method with all fields populated.""" + state = AgentState( + agent_id="agent-1", + step_number=42, + health=0.8, + reward=10.5, + position_x=1.0, + position_y=2.0, + position_z=3.0, + resource_level=0.7, + extra_data={"inventory": ["sword", "shield"]} + ) + + state_dict = state.as_dict() + + assert state_dict["agent_id"] == "agent-1" + assert state_dict["step_number"] == 42 + assert state_dict["health"] == 0.8 + assert state_dict["reward"] == 10.5 + assert state_dict["position_x"] == 1.0 + assert state_dict["position_y"] == 2.0 + assert state_dict["position_z"] == 3.0 + assert state_dict["resource_level"] == 0.7 + assert state_dict["extra_data"] == {"inventory": ["sword", "shield"]} + + def test_as_dict_with_none_values(self): + """Test as_dict method excludes None values.""" + state = AgentState( + agent_id="agent-1", + step_number=42, + health=None, + reward=None, + position_x=None, + position_y=None, + position_z=None, + resource_level=None + ) + + state_dict = state.as_dict() + + assert "agent_id" in state_dict + assert "step_number" in state_dict + assert "health" not in state_dict + assert "reward" not in state_dict + assert "position_x" not in state_dict + assert "position_y" not in state_dict + assert "position_z" not in state_dict + assert "resource_level" not in state_dict + assert "extra_data" not in state_dict + + def test_as_dict_with_empty_extra_data(self): + """Test as_dict method excludes empty extra_data.""" + state = AgentState( + agent_id="agent-1", + step_number=42, + health=0.8, + extra_data={} + ) + + state_dict = state.as_dict() + + assert "agent_id" in state_dict + assert "step_number" in state_dict + assert "health" in state_dict + assert "extra_data" not in state_dict + + +class TestActionData: + """Test suite for the ActionData class.""" + + def test_initialization(self): + """Test basic initialization of ActionData.""" + action_data = ActionData( + action_type="move", + action_params={"direction": "north", "distance": 2}, + state_before={"position": [0, 0], "health": 
1.0}, + state_after={"position": [0, 2], "health": 0.9}, + reward=5.0, + execution_time=0.25, + step_number=42 + ) + + assert action_data.action_type == "move" + assert action_data.action_params == {"direction": "north", "distance": 2} + assert action_data.state_before == {"position": [0, 0], "health": 1.0} + assert action_data.state_after == {"position": [0, 2], "health": 0.9} + assert action_data.reward == 5.0 + assert action_data.execution_time == 0.25 + assert action_data.step_number == 42 + + def test_initialization_default_values(self): + """Test initialization with default values.""" + action_data = ActionData( + action_type="move", + state_before={"position": [0, 0], "health": 1.0}, + state_after={"position": [0, 2], "health": 0.9}, + execution_time=0.25, + step_number=42 + ) + + assert action_data.action_type == "move" + assert action_data.action_params == {} + assert action_data.reward == 0.0 + + def test_get_state_difference_numeric(self): + """Test get_state_difference method with numeric values.""" + action_data = ActionData( + action_type="move", + state_before={"health": 1.0, "position_x": 10, "energy": 100}, + state_after={"health": 0.8, "position_x": 15, "energy": 90}, + execution_time=0.25, + step_number=42 + ) + + diff = action_data.get_state_difference() + + assert "health" in diff + assert "position_x" in diff + assert "energy" in diff + assert pytest.approx(diff["health"]) == -0.2 + assert diff["position_x"] == 5 + assert diff["energy"] == -10 + + def test_get_state_difference_non_numeric(self): + """Test get_state_difference method ignores non-numeric values.""" + action_data = ActionData( + action_type="move", + state_before={ + "health": 1.0, + "position": [0, 0], + "inventory": ["sword"] + }, + state_after={ + "health": 0.8, + "position": [0, 2], + "inventory": ["sword", "shield"] + }, + execution_time=0.25, + step_number=42 + ) + + diff = action_data.get_state_difference() + + assert "health" in diff + assert pytest.approx(diff["health"]) == -0.2 + assert "position" not in diff + assert "inventory" not in diff + + def test_get_state_difference_missing_keys(self): + """Test get_state_difference method with keys present in only one state.""" + action_data = ActionData( + action_type="move", + state_before={"health": 1.0, "energy": 100}, + state_after={"health": 0.8, "mana": 50}, + execution_time=0.25, + step_number=42 + ) + + diff = action_data.get_state_difference() + + assert "health" in diff + assert pytest.approx(diff["health"]) == -0.2 + assert "energy" not in diff + assert "mana" not in diff + + +class TestActionResult: + """Test suite for the ActionResult class.""" + + def test_initialization(self): + """Test basic initialization of ActionResult.""" + result = ActionResult( + action_type="attack", + params={"target": "enemy-1", "weapon": "sword"}, + reward=10.0 + ) + + assert result.action_type == "attack" + assert result.params == {"target": "enemy-1", "weapon": "sword"} + assert result.reward == 10.0 + + def test_initialization_default_values(self): + """Test initialization with default values.""" + result = ActionResult(action_type="observe") + + assert result.action_type == "observe" + assert result.params == {} + assert result.reward == 0.0 \ No newline at end of file diff --git a/tests/api/test_types.py b/tests/api/test_types.py new file mode 100644 index 0000000..bb4e3f2 --- /dev/null +++ b/tests/api/test_types.py @@ -0,0 +1,397 @@ +"""Unit tests for memory API types.""" + +import pytest +from typing import Dict, List, Any, Optional +from 
memory.api.types import ( + MemoryMetadata, + MemoryEmbeddings, + MemoryEntry, + MemoryChangeRecord, + MemoryTypeDistribution, + MemoryStatistics, + SimilaritySearchResult, + ConfigUpdate, + QueryResult, + MemoryStore +) + + +class TestMemoryTypesStructures: + """Test suite for memory type structures.""" + + def test_memory_metadata_structure(self): + """Test MemoryMetadata structure with all fields.""" + metadata: MemoryMetadata = { + "creation_time": 1649879872.123, + "last_access_time": 1649879900.456, + "compression_level": 0, + "importance_score": 0.8, + "retrieval_count": 5, + "memory_type": "state", + "current_tier": "stm", + "checksum": "abc123" + } + + assert metadata["creation_time"] == 1649879872.123 + assert metadata["last_access_time"] == 1649879900.456 + assert metadata["compression_level"] == 0 + assert metadata["importance_score"] == 0.8 + assert metadata["retrieval_count"] == 5 + assert metadata["memory_type"] == "state" + assert metadata["current_tier"] == "stm" + assert metadata["checksum"] == "abc123" + + def test_memory_metadata_partial(self): + """Test MemoryMetadata with partial fields (total=False).""" + # This should work because MemoryMetadata has total=False + metadata: MemoryMetadata = { + "creation_time": 1649879872.123, + "importance_score": 0.8, + "memory_type": "state", + } + + assert metadata["creation_time"] == 1649879872.123 + assert metadata["importance_score"] == 0.8 + assert metadata["memory_type"] == "state" + assert "compression_level" not in metadata + + def test_memory_embeddings_structure(self): + """Test MemoryEmbeddings structure with all fields.""" + embeddings: MemoryEmbeddings = { + "full_vector": [0.1, 0.2, 0.3, 0.4], + "compressed_vector": [0.15, 0.25, 0.35], + "abstract_vector": [0.2, 0.3] + } + + assert embeddings["full_vector"] == [0.1, 0.2, 0.3, 0.4] + assert embeddings["compressed_vector"] == [0.15, 0.25, 0.35] + assert embeddings["abstract_vector"] == [0.2, 0.3] + + def test_memory_embeddings_partial(self): + """Test MemoryEmbeddings with partial fields (total=False).""" + # This should work because MemoryEmbeddings has total=False + embeddings: MemoryEmbeddings = { + "full_vector": [0.1, 0.2, 0.3, 0.4], + } + + assert embeddings["full_vector"] == [0.1, 0.2, 0.3, 0.4] + assert "compressed_vector" not in embeddings + assert "abstract_vector" not in embeddings + + def test_memory_entry_structure(self): + """Test MemoryEntry structure with all fields.""" + memory: MemoryEntry = { + "memory_id": "mem_12345", + "agent_id": "agent_001", + "step_number": 42, + "timestamp": 1649879872.123, + "contents": {"observation": "User asked about weather"}, + "metadata": { + "importance_score": 0.8, + "memory_type": "interaction", + "creation_time": 1649879872.123, + "retrieval_count": 0, + "compression_level": 0, + "current_tier": "stm", + "last_access_time": 1649879872.123, + "checksum": "abc123" + }, + "embeddings": { + "full_vector": [0.1, 0.2, 0.3, 0.4] + } + } + + assert memory["memory_id"] == "mem_12345" + assert memory["agent_id"] == "agent_001" + assert memory["step_number"] == 42 + assert memory["timestamp"] == 1649879872.123 + assert memory["contents"]["observation"] == "User asked about weather" + assert memory["metadata"]["importance_score"] == 0.8 + assert memory["embeddings"]["full_vector"] == [0.1, 0.2, 0.3, 0.4] + + def test_memory_entry_without_embeddings(self): + """Test MemoryEntry without embeddings (embeddings is Optional).""" + memory: MemoryEntry = { + "memory_id": "mem_12345", + "agent_id": "agent_001", + "step_number": 42, + 
"timestamp": 1649879872.123, + "contents": {"observation": "User asked about weather"}, + "metadata": { + "importance_score": 0.8, + "memory_type": "interaction", + "creation_time": 1649879872.123, + "retrieval_count": 0, + "compression_level": 0, + "current_tier": "stm", + "last_access_time": 1649879872.123, + "checksum": "abc123" + }, + "embeddings": None + } + + assert memory["memory_id"] == "mem_12345" + assert memory["embeddings"] is None + + def test_memory_change_record_structure(self): + """Test MemoryChangeRecord structure.""" + change_record: MemoryChangeRecord = { + "memory_id": "mem_12345", + "step_number": 42, + "timestamp": 1649879872.123, + "previous_value": 10, + "new_value": 20 + } + + assert change_record["memory_id"] == "mem_12345" + assert change_record["step_number"] == 42 + assert change_record["timestamp"] == 1649879872.123 + assert change_record["previous_value"] == 10 + assert change_record["new_value"] == 20 + + def test_memory_change_record_with_none_previous(self): + """Test MemoryChangeRecord with None previous value.""" + change_record: MemoryChangeRecord = { + "memory_id": "mem_12345", + "step_number": 42, + "timestamp": 1649879872.123, + "previous_value": None, + "new_value": 20 + } + + assert change_record["memory_id"] == "mem_12345" + assert change_record["previous_value"] is None + assert change_record["new_value"] == 20 + + def test_memory_type_distribution_structure(self): + """Test MemoryTypeDistribution structure with all fields.""" + distribution: MemoryTypeDistribution = { + "state": 10, + "action": 20, + "interaction": 30 + } + + assert distribution["state"] == 10 + assert distribution["action"] == 20 + assert distribution["interaction"] == 30 + + def test_memory_type_distribution_partial(self): + """Test MemoryTypeDistribution with partial fields (total=False).""" + # This should work because MemoryTypeDistribution has total=False + distribution: MemoryTypeDistribution = { + "state": 10, + "action": 20, + } + + assert distribution["state"] == 10 + assert distribution["action"] == 20 + assert "interaction" not in distribution + + def test_memory_statistics_structure(self): + """Test MemoryStatistics structure.""" + stats: MemoryStatistics = { + "total_memories": 100, + "stm_count": 20, + "im_count": 30, + "ltm_count": 50, + "memory_type_distribution": { + "state": 40, + "action": 30, + "interaction": 30 + }, + "last_maintenance_time": 1649879872.123, + "insert_count_since_maintenance": 5 + } + + assert stats["total_memories"] == 100 + assert stats["stm_count"] == 20 + assert stats["im_count"] == 30 + assert stats["ltm_count"] == 50 + assert stats["memory_type_distribution"]["state"] == 40 + assert stats["last_maintenance_time"] == 1649879872.123 + assert stats["insert_count_since_maintenance"] == 5 + + def test_memory_statistics_with_none_time(self): + """Test MemoryStatistics with None for last_maintenance_time.""" + stats: MemoryStatistics = { + "total_memories": 100, + "stm_count": 20, + "im_count": 30, + "ltm_count": 50, + "memory_type_distribution": { + "state": 40, + "action": 30, + "interaction": 30 + }, + "last_maintenance_time": None, + "insert_count_since_maintenance": 5 + } + + assert stats["total_memories"] == 100 + assert stats["last_maintenance_time"] is None + + def test_similarity_search_result_structure(self): + """Test SimilaritySearchResult structure.""" + search_result: SimilaritySearchResult = { + "memory_id": "mem_12345", + "agent_id": "agent_001", + "step_number": 42, + "timestamp": 1649879872.123, + "contents": 
{"observation": "User asked about weather"}, + "metadata": { + "importance_score": 0.8, + "memory_type": "interaction", + "creation_time": 1649879872.123, + "retrieval_count": 0, + "compression_level": 0, + "current_tier": "stm", + "last_access_time": 1649879872.123, + "checksum": "abc123" + }, + "embeddings": { + "full_vector": [0.1, 0.2, 0.3, 0.4] + }, + "_similarity_score": 0.95 + } + + assert search_result["memory_id"] == "mem_12345" + assert search_result["contents"]["observation"] == "User asked about weather" + assert search_result["_similarity_score"] == 0.95 + + def test_config_update_structure(self): + """Test ConfigUpdate structure (Dict[str, Any]).""" + config_update: ConfigUpdate = { + "stm_config": {"memory_limit": 1000}, + "im_config": {"memory_limit": 5000}, + "enable_memory_hooks": True + } + + assert config_update["stm_config"]["memory_limit"] == 1000 + assert config_update["im_config"]["memory_limit"] == 5000 + assert config_update["enable_memory_hooks"] is True + + def test_query_result_structure(self): + """Test QueryResult structure.""" + query_result: QueryResult = { + "memory_id": "mem_12345", + "agent_id": "agent_001", + "step_number": 42, + "timestamp": 1649879872.123, + "contents": {"observation": "User asked about weather"}, + "metadata": { + "importance_score": 0.8, + "memory_type": "interaction", + "creation_time": 1649879872.123, + "retrieval_count": 0, + "compression_level": 0, + "current_tier": "stm", + "last_access_time": 1649879872.123, + "checksum": "abc123" + } + } + + assert query_result["memory_id"] == "mem_12345" + assert query_result["agent_id"] == "agent_001" + assert query_result["step_number"] == 42 + assert query_result["contents"]["observation"] == "User asked about weather" + assert query_result["metadata"]["importance_score"] == 0.8 + + +class TestMemoryStoreProtocol: + """Test suite for MemoryStore protocol implementation.""" + + class MockMemoryStore: + """Mock implementation of MemoryStore protocol for testing.""" + + def __init__(self): + self.memories = {} + + def get(self, memory_id: str) -> Optional[MemoryEntry]: + """Get a memory by ID.""" + return self.memories.get(memory_id) + + def get_recent(self, count: int, memory_type: Optional[str] = None) -> List[MemoryEntry]: + """Get recent memories.""" + return [] + + def get_by_step_range( + self, start_step: int, end_step: int, memory_type: Optional[str] = None + ) -> List[MemoryEntry]: + """Get memories in a step range.""" + return [] + + def get_by_attributes( + self, attributes: Dict[str, Any], memory_type: Optional[str] = None + ) -> List[MemoryEntry]: + """Get memories matching attributes.""" + return [] + + def search_by_vector( + self, vector: List[float], k: int = 5, memory_type: Optional[str] = None + ) -> List[MemoryEntry]: + """Search memories by vector similarity.""" + return [] + + def search_by_content( + self, content_query: Dict[str, Any], k: int = 5 + ) -> List[MemoryEntry]: + """Search memories by content.""" + return [] + + def contains(self, memory_id: str) -> bool: + """Check if a memory exists.""" + return memory_id in self.memories + + def update(self, memory: MemoryEntry) -> bool: + """Update a memory.""" + self.memories[memory["memory_id"]] = memory + return True + + def count(self) -> int: + """Count memories.""" + return len(self.memories) + + def count_by_type(self) -> Dict[str, int]: + """Count memories by type.""" + return {"state": 0, "action": 0, "interaction": 0} + + def clear(self) -> bool: + """Clear all memories.""" + self.memories.clear() + return 
True + + def test_memory_store_protocol_compliance(self): + """Test that MockMemoryStore complies with MemoryStore protocol.""" + store = self.MockMemoryStore() + + # This will fail if the MockMemoryStore doesn't properly implement MemoryStore + assert isinstance(store, MemoryStore) + + # Test basic functionality + test_memory: MemoryEntry = { + "memory_id": "test-mem", + "agent_id": "test-agent", + "step_number": 1, + "timestamp": 1000.0, + "contents": {"data": "test content"}, + "metadata": { + "importance_score": 0.5, + "memory_type": "state", + "creation_time": 1000.0, + "compression_level": 0, + "current_tier": "stm", + "last_access_time": 1000.0, + "retrieval_count": 0, + "checksum": "test" + }, + "embeddings": None + } + + assert store.update(test_memory) is True + assert store.contains("test-mem") is True + assert store.count() == 1 + retrieved = store.get("test-mem") + assert retrieved is not None + assert retrieved["memory_id"] == "test-mem" + assert store.clear() is True + assert store.count() == 0 \ No newline at end of file diff --git a/validation/search/performance_test_attribute.py b/validation/search/attribute/performance_test_attribute.py similarity index 99% rename from validation/search/performance_test_attribute.py rename to validation/search/attribute/performance_test_attribute.py index cccec15..70c4cbc 100644 --- a/validation/search/performance_test_attribute.py +++ b/validation/search/attribute/performance_test_attribute.py @@ -21,7 +21,7 @@ import logging # Add project root to path -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) # Suppress checksum validation warnings - important to set this before imports logging.getLogger('memory.storage.redis_im').setLevel(logging.ERROR) diff --git a/validation/search/attribute/validate_attribute.py b/validation/search/attribute/validate_attribute.py new file mode 100644 index 0000000..4aa41e0 --- /dev/null +++ b/validation/search/attribute/validate_attribute.py @@ -0,0 +1,822 @@ +""" +Validation script for the Attribute Search Strategy. + +This script loads a predefined memory system and tests various scenarios +of attribute-based searching to verify the strategy works correctly. 
+""" + +import os +import sys +from typing import Any, Dict, List, Set + +# Add the project root to the path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) + +from validation.demo_utils import ( + create_memory_system, + log_print, + pretty_print_memories, + setup_logging, +) +from memory.search.strategies.attribute import AttributeSearchStrategy + +# Constants +AGENT_ID = "test-agent-attribute-search" +MEMORY_SAMPLE = os.path.join("memory_samples", "attribute_validation_memory.json") + +# Dictionary mapping memory IDs to their checksums for easier reference +MEMORY_CHECKSUMS = { + "meeting-123456-1": "0eb0f81d07276f08e05351a604d3c994564fedee3a93329e318186da517a3c56", + "meeting-123456-3": "f6ab36930459e74a52fdf21fb96a84241ccae3f6987365a21f9a17d84c5dae1e", + "meeting-123456-6": "ffa0ee60ebaec5574358a02d1857823e948519244e366757235bf755c888a87f", + "meeting-123456-9": "9214ebc2d11877665b32771bd3c080414d9519b435ec3f6c19cc5f337bb0ba90", + "meeting-123456-11": "ad2e7c963751beb1ebc1c9b84ecb09ec3ccdef14f276cd14bbebad12d0f9b0df", + "task-123456-2": "e0f7deb6929a17f65f56e5b03e16067c8bb65649fd2745f842aca7af701c9cac", + "task-123456-7": "1d23b6683acd8c3863cb2f2010fe3df2c3e69a2d94c7c4757a291d4872066cfd", + "task-123456-10": "f3c73b06d6399ed30ea9d9ad7c711a86dd58154809cc05497f8955425ec6dc67", + "note-123456-4": "1e9e265e75c2ef678dfd0de0ab5c801f845daa48a90a48bb02ee85148ccc3470", + "note-123456-8": "169c452e368fd62e3c0cf5ce7731769ed46ab6ae73e5048e0c3a7caaa66fba46", + "contact-123456-5": "496d09718bbc8ae669dffdd782ed5b849fdbb1a57e3f7d07e61807b10e650092", +} + + +def get_checksums_for_memory_ids(memory_ids: List[str]) -> Set[str]: + """Helper function to get checksums from memory IDs.""" + return { + MEMORY_CHECKSUMS[memory_id] + for memory_id in memory_ids + if memory_id in MEMORY_CHECKSUMS + } + + +def run_test( + search_strategy: AttributeSearchStrategy, + test_name: str, + query: Any, + agent_id: str, + limit: int = 10, + metadata_filter: Dict[str, Any] = None, + tier: str = None, + content_fields: List[str] = None, + metadata_fields: List[str] = None, + match_all: bool = False, + case_sensitive: bool = False, + use_regex: bool = False, + scoring_method: str = None, + expected_checksums: Set[str] = None, + expected_memory_ids: List[str] = None, +) -> Dict[str, Any]: + """Run a test case and return the results.""" + log_print(logger, f"\n=== Test: {test_name} ===") + + if isinstance(query, dict): + log_print(logger, f"Query (dict): {query}") + else: + log_print(logger, f"Query: '{query}'") + + log_print( + logger, + f"Match All: {match_all}, Case Sensitive: {case_sensitive}, Use Regex: {use_regex}", + ) + + if metadata_filter: + log_print(logger, f"Metadata Filter: {metadata_filter}") + + if tier: + log_print(logger, f"Tier: {tier}") + + if content_fields: + log_print(logger, f"Content Fields: {content_fields}") + + if metadata_fields: + log_print(logger, f"Metadata Fields: {metadata_fields}") + + if scoring_method: + log_print(logger, f"Scoring Method: {scoring_method}") + + # If expected_memory_ids is provided, convert to checksums + if expected_memory_ids and not expected_checksums: + expected_checksums = get_checksums_for_memory_ids(expected_memory_ids) + log_print( + logger, + f"Expecting {len(expected_checksums)} memories from specified memory IDs", + ) + + results = search_strategy.search( + query=query, + agent_id=agent_id, + limit=limit, + metadata_filter=metadata_filter, + tier=tier, + content_fields=content_fields, + metadata_fields=metadata_fields, + 
match_all=match_all, + case_sensitive=case_sensitive, + use_regex=use_regex, + scoring_method=scoring_method, + ) + + log_print(logger, f"Found {len(results)} results") + pretty_print_memories(results, f"Results for {test_name}", logger) + + # If we have scoring method, print the scores for comparison + if scoring_method and results: + log_print(logger, f"\nScores using {scoring_method} scoring method:") + for idx, result in enumerate(results[:5]): # Show scores for top 5 results + score = result.get("metadata", {}).get("attribute_score", 0) + memory_id = result.get("memory_id", result.get("id", f"Result {idx+1}")) + log_print(logger, f" {memory_id}: {score:.4f}") + + # Track test status + test_passed = True + + # Validate against expected checksums if provided + if expected_checksums: + result_checksums = { + result.get("metadata", {}).get("checksum", "") for result in results + } + missing_checksums = expected_checksums - result_checksums + unexpected_checksums = result_checksums - expected_checksums + + log_print(logger, f"\nValidation Results:") + if not missing_checksums and not unexpected_checksums: + log_print(logger, "All expected memories found. No unexpected memories.") + else: + if missing_checksums: + log_print(logger, f"Missing expected memories: {missing_checksums}") + test_passed = False + if unexpected_checksums: + log_print(logger, f"Found unexpected memories: {unexpected_checksums}") + test_passed = False + + log_print( + logger, + f"Expected: {len(expected_checksums)}, Found: {len(result_checksums)}, " + f"Missing: {len(missing_checksums)}, Unexpected: {len(unexpected_checksums)}", + ) + + return { + "results": results, + "test_name": test_name, + "passed": test_passed, + "has_validation": expected_checksums is not None + or expected_memory_ids is not None, + } + + +def validate_attribute_search(): + """Run validation tests for the attribute search strategy.""" + # Setup memory system + memory_system = create_memory_system( + logging_level="INFO", + memory_file=MEMORY_SAMPLE, + use_mock_redis=True, + ) + + # If memory system failed to load, exit + if not memory_system: + log_print(logger, "Failed to load memory system") + return + + # Setup search strategy + agent = memory_system.get_memory_agent(AGENT_ID) + search_strategy = AttributeSearchStrategy( + agent.stm_store, agent.im_store, agent.ltm_store + ) + + # Print strategy info + log_print(logger, f"Testing search strategy: {search_strategy.name()}") + log_print(logger, f"Description: {search_strategy.description()}") + + # Track test results + test_results = [] + + # Test 1: Basic content search + test_results.append( + run_test( + search_strategy, + "Basic Content Search", + "meeting", + AGENT_ID, + content_fields=["content.content"], + metadata_filter={"content.metadata.type": "meeting"}, + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-3", + "meeting-123456-6", + "meeting-123456-9", + "meeting-123456-11", + ], + ) + ) + + # Test 2: Case sensitive search + test_results.append( + run_test( + search_strategy, + "Case Sensitive Search", + "Meeting", + AGENT_ID, + case_sensitive=True, + content_fields=["content.content"], # Only search in content + metadata_filter={ + "content.metadata.type": "meeting" + }, # Only get meeting-type memories + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-3", + "meeting-123456-6", + ], + ) + ) + + # Test 3: Search by metadata type + test_results.append( + run_test( + search_strategy, + "Search by Metadata Type", + {"metadata": {"type": "meeting"}}, + 
AGENT_ID, + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-3", + "meeting-123456-6", + "meeting-123456-9", + "meeting-123456-11", + ], + ) + ) + + # Test 4: Search with match_all + test_results.append( + run_test( + search_strategy, + "Search with Match All", + { + "content": "meeting", + "metadata": {"type": "meeting", "importance": "high"}, + }, + AGENT_ID, + match_all=True, + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-6", + "meeting-123456-9", + "meeting-123456-11", + ], + ) + ) + + # Test 5: Search specific memory tier + test_results.append( + run_test( + search_strategy, + "Search in STM Tier Only", + "meeting", + AGENT_ID, + tier="stm", + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-3", + ], + ) + ) + + # Test 6: Search with regex + test_results.append( + run_test( + search_strategy, + "Regex Search", + "secur.*patch", + AGENT_ID, + use_regex=True, + expected_memory_ids=[ + "note-123456-4", + ], + ) + ) + + # Test 7: Search with metadata filter + test_results.append( + run_test( + search_strategy, + "Search with Metadata Filter", + "meeting", + AGENT_ID, + metadata_filter={"content.metadata.importance": "high"}, + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-6", + "meeting-123456-9", + "meeting-123456-11", + ], + ) + ) + + # Test 8: Search in specific content fields + test_results.append( + run_test( + search_strategy, + "Search in Specific Content Fields", + "project", + AGENT_ID, + content_fields=["content.content"], + expected_memory_ids=[ + "meeting-123456-1", + "contact-123456-5", + ], + ) + ) + + # Test 9: Search in specific metadata fields + test_results.append( + run_test( + search_strategy, + "Search in Specific Metadata Fields", + "project", + AGENT_ID, + metadata_fields=["content.metadata.tags"], + expected_memory_ids=[ + "meeting-123456-1", + "contact-123456-5", + ], + ) + ) + + # Test 10: Search with complex query and filters + test_results.append( + run_test( + search_strategy, + "Complex Search", + {"content": "security", "metadata": {"importance": "high"}}, + AGENT_ID, + metadata_filter={"content.metadata.source": "email"}, + match_all=True, + expected_memory_ids=[ + "note-123456-4", + ], + ) + ) + + # Test 11: Empty query handling - string + test_results.append( + run_test( + search_strategy, + "Empty String Query", + "", + AGENT_ID, + expected_memory_ids=[], + ) + ) + + # Test 12: Empty query handling - dict + test_results.append( + run_test( + search_strategy, + "Empty Dict Query", + {}, + AGENT_ID, + expected_memory_ids=[], + ) + ) + + # Test 13: Numeric value search + test_results.append( + run_test( + search_strategy, + "Numeric Value Search", + 42, + AGENT_ID, + expected_memory_ids=[], + ) + ) + + # Test 14: Boolean value search + test_results.append( + run_test( + search_strategy, + "Boolean Value Search", + {"metadata": {"completed": True}}, + AGENT_ID, + expected_memory_ids=[], + ) + ) + + # Test 15: Type conversion - searching string with numeric + test_results.append( + run_test( + search_strategy, + "Type Conversion - String Field with Numeric", + 123, + AGENT_ID, + content_fields=["content.content"], + expected_memory_ids=[], + ) + ) + + # Test 16: Invalid regex pattern handling + test_results.append( + run_test( + search_strategy, + "Invalid Regex Pattern", + "[unclosed-bracket", + AGENT_ID, + use_regex=True, + expected_memory_ids=[], + ) + ) + + # Test 17: Array field partial matching + test_results.append( + run_test( + search_strategy, + "Array Field Partial Matching", + "dev", + 
AGENT_ID, + metadata_fields=["content.metadata.tags"], + expected_memory_ids=[ + "meeting-123456-3", + "task-123456-10", + ], + ) + ) + + # Test 18: Special characters in search + test_results.append( + run_test( + search_strategy, + "Special Characters in Search", + "meeting+notes", + AGENT_ID, + expected_memory_ids=[], + ) + ) + + # Test 19: Multi-tier search + test_results.append( + run_test( + search_strategy, + "Multi-Tier Search", + "important", + AGENT_ID, + expected_memory_ids=[], + ) + ) + + # Test 20: Large result set limiting + test_results.append( + run_test( + search_strategy, + "Large Result Set Limiting", + "a", # Common letter to match many memories + AGENT_ID, + limit=3, # Only show top 3 results + expected_memory_ids=[ + "meeting-123456-1", # Contains "about" and "allocation" - shorter content with multiple 'a's + "meeting-123456-6", # Contains "about" and "roadmap" - shorter content with multiple 'a's + "meeting-123456-3", # Contains "authentication" and "team" - longer content with fewer 'a's + ], + ) + ) + + # ===== New tests for scoring methods ===== + log_print(logger, "\n=== SCORING METHOD COMPARISON TESTS ===") + + # Test 21: Comparing scoring methods on the same query + test_query = "meeting" + log_print(logger, f"\nComparing scoring methods for query: '{test_query}'") + + # Try each scoring method and collect results + test_results.append( + run_test( + search_strategy, + "Default Length Ratio Scoring", + test_query, + AGENT_ID, + limit=5, + scoring_method="length_ratio", + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-3", + "meeting-123456-6", + "meeting-123456-9", + "meeting-123456-11", + ], + ) + ) + + test_results.append( + run_test( + search_strategy, + "Term Frequency Scoring", + test_query, + AGENT_ID, + limit=5, + scoring_method="term_frequency", + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-3", + "meeting-123456-6", + "meeting-123456-9", + "meeting-123456-11", + ], + ) + ) + + test_results.append( + run_test( + search_strategy, + "BM25 Scoring", + test_query, + AGENT_ID, + limit=5, + scoring_method="bm25", + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-3", + "meeting-123456-6", + "meeting-123456-9", + "meeting-123456-11", + ], + ) + ) + + test_results.append( + run_test( + search_strategy, + "Binary Scoring", + test_query, + AGENT_ID, + limit=5, + scoring_method="binary", + expected_memory_ids=[ + "meeting-123456-1", + "meeting-123456-3", + "meeting-123456-6", + "meeting-123456-9", + "meeting-123456-11", + ], + ) + ) + + # Test 22: Testing scoring on a document with repeated terms + test_with_repetition_query = "security" # Look for security-related memories + log_print( + logger, + f"\nComparing scoring methods for query with potential term repetition: '{test_with_repetition_query}'", + ) + + # Test with default length ratio scoring + test_results.append( + run_test( + search_strategy, + "Default Scoring with Term Repetition", + test_with_repetition_query, + AGENT_ID, + limit=5, + expected_memory_ids=[ + "note-123456-4", + "note-123456-8", + "meeting-123456-11", + ], + ) + ) + + # Test with term frequency scoring - should favor documents with more occurrences + test_results.append( + run_test( + search_strategy, + "Term Frequency with Term Repetition", + test_with_repetition_query, + AGENT_ID, + limit=5, + scoring_method="term_frequency", + expected_memory_ids=[ + "note-123456-4", + "note-123456-8", + "meeting-123456-11", + ], + ) + ) + + # Test with BM25 scoring - balances term frequency and document 
length + test_results.append( + run_test( + search_strategy, + "BM25 with Term Repetition", + test_with_repetition_query, + AGENT_ID, + limit=5, + scoring_method="bm25", + expected_memory_ids=[ + "note-123456-4", + "note-123456-8", + "meeting-123456-11", + ], + ) + ) + + # Test 23: Testing with a specialized search strategy for each method + log_print( + logger, + "\nTesting with dedicated search strategy instances for each scoring method", + ) + + # Create specialized strategy instances + term_freq_strategy = AttributeSearchStrategy( + agent.stm_store, + agent.im_store, + agent.ltm_store, + scoring_method="term_frequency", + ) + + bm25_strategy = AttributeSearchStrategy( + agent.stm_store, + agent.im_store, + agent.ltm_store, + scoring_method="bm25", + ) + + # Run test with specialized strategies + test_results.append( + run_test( + term_freq_strategy, + "Using Term Frequency Strategy Instance", + "project", + AGENT_ID, + limit=5, + expected_memory_ids=[ + "meeting-123456-1", + "contact-123456-5", + ], + ) + ) + + test_results.append( + run_test( + bm25_strategy, + "Using BM25 Strategy Instance", + "project", + AGENT_ID, + limit=5, + expected_memory_ids=[ + "meeting-123456-1", + "contact-123456-5", + ], + ) + ) + + # Test 24: Testing with a long document vs short document comparison + # Change from "detailed" (no matches) to "authentication system" (appears in memories of different lengths) + long_doc_query = "authentication system" + log_print( + logger, + f"\nComparing scoring methods for long vs short document query: '{long_doc_query}'", + ) + + # Compare each scoring method + test_results.append( + run_test( + search_strategy, + "Length Ratio for Long Documents", + long_doc_query, + AGENT_ID, + limit=5, + scoring_method="length_ratio", + expected_memory_ids=[ + "meeting-123456-3", + "task-123456-7", + "task-123456-10", + ], + ) + ) + + test_results.append( + run_test( + search_strategy, + "Term Frequency for Long Documents", + long_doc_query, + AGENT_ID, + limit=5, + scoring_method="term_frequency", + expected_memory_ids=[ + "meeting-123456-3", + "task-123456-7", + "task-123456-10", + ], + ) + ) + + test_results.append( + run_test( + search_strategy, + "BM25 for Long Documents", + long_doc_query, + AGENT_ID, + limit=5, + scoring_method="bm25", + expected_memory_ids=[ + "meeting-123456-3", + "task-123456-7", + "task-123456-10", + ], + ) + ) + + # Test 25: Testing with a query that matches varying document length and context + varying_length_query = "documentation" + log_print( + logger, + f"\nComparing scoring methods for documents of varying lengths: '{varying_length_query}'", + ) + + test_results.append( + run_test( + search_strategy, + "Length Ratio for Documentation Query", + varying_length_query, + AGENT_ID, + limit=5, + scoring_method="length_ratio", + expected_memory_ids=[ + "task-123456-2", + "task-123456-7", + ], + ) + ) + + test_results.append( + run_test( + search_strategy, + "Term Frequency for Documentation Query", + varying_length_query, + AGENT_ID, + limit=5, + scoring_method="term_frequency", + expected_memory_ids=[ + "task-123456-2", + "task-123456-7", + ], + ) + ) + + test_results.append( + run_test( + search_strategy, + "BM25 for Documentation Query", + varying_length_query, + AGENT_ID, + limit=5, + scoring_method="bm25", + expected_memory_ids=[ + "task-123456-2", + "task-123456-7", + ], + ) + ) + + # Display validation summary + log_print(logger, "\n\n=== VALIDATION SUMMARY ===") + log_print(logger, "-" * 80) + log_print( + logger, + "| {:<40} | {:<20} | {:<20} 
|".format( + "Test Name", "Status", "Validation Status" + ), + ) + log_print(logger, "-" * 80) + + for result in test_results: + status = "PASS" if result["passed"] else "FAIL" + validation_status = status if result["has_validation"] else "N/A" + log_print( + logger, + "| {:<40} | {:<20} | {:<20} |".format( + result["test_name"][:40], status, validation_status + ), + ) + + log_print(logger, "-" * 80) + + # Calculate overall statistics + validated_tests = [t for t in test_results if t["has_validation"]] + passed_tests = [t for t in validated_tests if t["passed"]] + + if validated_tests: + success_rate = len(passed_tests) / len(validated_tests) * 100 + log_print(logger, f"\nValidated Tests: {len(validated_tests)}") + log_print(logger, f"Passed Tests: {len(passed_tests)}") + log_print(logger, f"Failed Tests: {len(validated_tests) - len(passed_tests)}") + log_print(logger, f"Success Rate: {success_rate:.2f}%") + else: + log_print(logger, "\nNo tests with validation criteria were run.") + + +if __name__ == "__main__": + # Setup logging + logger = setup_logging("validate_attribute_search") + log_print(logger, "Starting Attribute Search Strategy Validation") + + validate_attribute_search() + + log_print(logger, "Validation Complete") diff --git a/validation/search/validate_attribute.py b/validation/search/validate_attribute.py deleted file mode 100644 index f6bf69d..0000000 --- a/validation/search/validate_attribute.py +++ /dev/null @@ -1,725 +0,0 @@ -""" -Validation script for the Attribute Search Strategy. - -This script loads a predefined memory system and tests various scenarios -of attribute-based searching to verify the strategy works correctly. -""" - -import os -import sys -from typing import Any, Dict, List, Set - -# Add the project root to the path -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) - -from validation.demo_utils import ( - create_memory_system, - log_print, - pretty_print_memories, - setup_logging, -) -from memory.search.strategies.attribute import AttributeSearchStrategy - -# Constants -AGENT_ID = "test-agent-attribute-search" -MEMORY_SAMPLE = "attribute_validation_memory.json" - -# Dictionary mapping memory IDs to their checksums for easier reference -MEMORY_CHECKSUMS = { - "meeting-123456-1": "0eb0f81d07276f08e05351a604d3c994564fedee3a93329e318186da517a3c56", - "meeting-123456-3": "f6ab36930459e74a52fdf21fb96a84241ccae3f6987365a21f9a17d84c5dae1e", - "meeting-123456-6": "ffa0ee60ebaec5574358a02d1857823e948519244e366757235bf755c888a87f", - "meeting-123456-9": "9214ebc2d11877665b32771bd3c080414d9519b435ec3f6c19cc5f337bb0ba90", - "meeting-123456-11": "ad2e7c963751beb1ebc1c9b84ecb09ec3ccdef14f276cd14bbebad12d0f9b0df", - "task-123456-2": "e0f7deb6929a17f65f56e5b03e16067c8bb65649fd2745f842aca7af701c9cac", - "task-123456-7": "1d23b6683acd8c3863cb2f2010fe3df2c3e69a2d94c7c4757a291d4872066cfd", - "task-123456-10": "f3c73b06d6399ed30ea9d9ad7c711a86dd58154809cc05497f8955425ec6dc67", - "note-123456-4": "1e9e265e75c2ef678dfd0de0ab5c801f845daa48a90a48bb02ee85148ccc3470", - "note-123456-8": "169c452e368fd62e3c0cf5ce7731769ed46ab6ae73e5048e0c3a7caaa66fba46", - "contact-123456-5": "496d09718bbc8ae669dffdd782ed5b849fdbb1a57e3f7d07e61807b10e650092", -} - - -def get_checksums_for_memory_ids(memory_ids: List[str]) -> Set[str]: - """Helper function to get checksums from memory IDs.""" - return { - MEMORY_CHECKSUMS[memory_id] - for memory_id in memory_ids - if memory_id in MEMORY_CHECKSUMS - } - - -def run_test( - search_strategy: AttributeSearchStrategy, - 
test_name: str, - query: Any, - agent_id: str, - limit: int = 10, - metadata_filter: Dict[str, Any] = None, - tier: str = None, - content_fields: List[str] = None, - metadata_fields: List[str] = None, - match_all: bool = False, - case_sensitive: bool = False, - use_regex: bool = False, - scoring_method: str = None, - expected_checksums: Set[str] = None, - expected_memory_ids: List[str] = None, -) -> Dict[str, Any]: - """Run a test case and return the results.""" - log_print(logger, f"\n=== Test: {test_name} ===") - - if isinstance(query, dict): - log_print(logger, f"Query (dict): {query}") - else: - log_print(logger, f"Query: '{query}'") - - log_print( - logger, - f"Match All: {match_all}, Case Sensitive: {case_sensitive}, Use Regex: {use_regex}", - ) - - if metadata_filter: - log_print(logger, f"Metadata Filter: {metadata_filter}") - - if tier: - log_print(logger, f"Tier: {tier}") - - if content_fields: - log_print(logger, f"Content Fields: {content_fields}") - - if metadata_fields: - log_print(logger, f"Metadata Fields: {metadata_fields}") - - if scoring_method: - log_print(logger, f"Scoring Method: {scoring_method}") - - # If expected_memory_ids is provided, convert to checksums - if expected_memory_ids and not expected_checksums: - expected_checksums = get_checksums_for_memory_ids(expected_memory_ids) - log_print( - logger, - f"Expecting {len(expected_checksums)} memories from specified memory IDs", - ) - - results = search_strategy.search( - query=query, - agent_id=agent_id, - limit=limit, - metadata_filter=metadata_filter, - tier=tier, - content_fields=content_fields, - metadata_fields=metadata_fields, - match_all=match_all, - case_sensitive=case_sensitive, - use_regex=use_regex, - scoring_method=scoring_method, - ) - - log_print(logger, f"Found {len(results)} results") - pretty_print_memories(results, f"Results for {test_name}", logger) - - # If we have scoring method, print the scores for comparison - if scoring_method and results: - log_print(logger, f"\nScores using {scoring_method} scoring method:") - for idx, result in enumerate(results[:5]): # Show scores for top 5 results - score = result.get("metadata", {}).get("attribute_score", 0) - memory_id = result.get("memory_id", result.get("id", f"Result {idx+1}")) - log_print(logger, f" {memory_id}: {score:.4f}") - - # Track test status - test_passed = True - - # Validate against expected checksums if provided - if expected_checksums: - result_checksums = { - result.get("metadata", {}).get("checksum", "") for result in results - } - missing_checksums = expected_checksums - result_checksums - unexpected_checksums = result_checksums - expected_checksums - - log_print(logger, f"\nValidation Results:") - if not missing_checksums and not unexpected_checksums: - log_print(logger, "All expected memories found. 
No unexpected memories.") - else: - if missing_checksums: - log_print(logger, f"Missing expected memories: {missing_checksums}") - test_passed = False - if unexpected_checksums: - log_print(logger, f"Found unexpected memories: {unexpected_checksums}") - test_passed = False - - log_print( - logger, - f"Expected: {len(expected_checksums)}, Found: {len(result_checksums)}, " - f"Missing: {len(missing_checksums)}, Unexpected: {len(unexpected_checksums)}", - ) - - return { - "results": results, - "test_name": test_name, - "passed": test_passed, - "has_validation": expected_checksums is not None - } - - -def validate_attribute_search(): - """Run validation tests for the attribute search strategy.""" - # Setup memory system - memory_system = create_memory_system( - logging_level="INFO", - memory_file=MEMORY_SAMPLE, - use_mock_redis=True, - ) - - # If memory system failed to load, exit - if not memory_system: - log_print(logger, "Failed to load memory system") - return - - # Setup search strategy - agent = memory_system.get_memory_agent(AGENT_ID) - search_strategy = AttributeSearchStrategy( - agent.stm_store, agent.im_store, agent.ltm_store - ) - - # Print strategy info - log_print(logger, f"Testing search strategy: {search_strategy.name()}") - log_print(logger, f"Description: {search_strategy.description()}") - - # Track test results - test_results = [] - - # Test 1: Basic content search - test_results.append(run_test( - search_strategy, - "Basic Content Search", - "meeting", - AGENT_ID, - expected_memory_ids=[ - "meeting-123456-1", - "meeting-123456-3", - "meeting-123456-6", - "meeting-123456-9", - "meeting-123456-11", - ], - )) - - # Test 2: Case sensitive search - test_results.append(run_test( - search_strategy, - "Case Sensitive Search", - "Meeting", - AGENT_ID, - case_sensitive=True, - )) - - # Test 3: Search by metadata type - test_results.append(run_test( - search_strategy, - "Search by Metadata Type", - {"metadata": {"type": "meeting"}}, - AGENT_ID, - expected_memory_ids=[ - "meeting-123456-1", - "meeting-123456-3", - "meeting-123456-6", - "meeting-123456-9", - "meeting-123456-11", - ], - )) - - # Test 4: Search with match_all - test_results.append(run_test( - search_strategy, - "Search with Match All", - {"content": "meeting", "metadata": {"type": "meeting", "importance": "high"}}, - AGENT_ID, - match_all=True, - expected_memory_ids=[ - "meeting-123456-1", - "meeting-123456-6", - "meeting-123456-9", - "meeting-123456-11", - ], - )) - - # Test 5: Search specific memory tier - test_results.append(run_test( - search_strategy, - "Search in STM Tier Only", - "meeting", - AGENT_ID, - tier="stm", - expected_memory_ids=["meeting-123456-1", "meeting-123456-3"], - )) - - # Test 6: Search with regex - test_results.append(run_test( - search_strategy, - "Regex Search", - "secur.*patch", - AGENT_ID, - use_regex=True, - expected_memory_ids=["note-123456-4"], - )) - - # Test 7: Search with metadata filter - test_results.append(run_test( - search_strategy, - "Search with Metadata Filter", - "meeting", - AGENT_ID, - metadata_filter={"content.metadata.importance": "high"}, - expected_memory_ids=[ - "meeting-123456-1", - "meeting-123456-6", - "meeting-123456-9", - "meeting-123456-11", - ], - )) - - # Test 8: Search in specific content fields - test_results.append(run_test( - search_strategy, - "Search in Specific Content Fields", - "project", - AGENT_ID, - content_fields=["content.content"], - expected_memory_ids=[ - "meeting-123456-1", - "contact-123456-5", - ], - )) - - # Test 9: Search in specific 
-    # Test 21: Comparing scoring methods on the same query
-    test_query = "meeting"
-    log_print(logger, f"\nComparing scoring methods for query: '{test_query}'")
-
-    # Try each scoring method and collect results
-    test_results.append(run_test(
-        search_strategy,
-        "Default Length Ratio Scoring",
-        test_query,
-        AGENT_ID,
-        limit=5,
-        scoring_method="length_ratio",
-        expected_memory_ids=[
-            "meeting-123456-1",
-            "meeting-123456-3",
-            "meeting-123456-6",
-            "meeting-123456-9",
-            "meeting-123456-11",
-        ],
-    ))
-
-    test_results.append(run_test(
-        search_strategy,
-        "Term Frequency Scoring",
-        test_query,
-        AGENT_ID,
-        limit=5,
-        scoring_method="term_frequency",
-        expected_memory_ids=[
-            "meeting-123456-1",
-            "meeting-123456-3",
-            "meeting-123456-6",
-            "meeting-123456-9",
-            "meeting-123456-11",
-        ],
-    ))
-
-    test_results.append(run_test(
-        search_strategy,
-        "BM25 Scoring",
-        test_query,
-        AGENT_ID,
-        limit=5,
-        scoring_method="bm25",
-        expected_memory_ids=[
-            "meeting-123456-1",
-            "meeting-123456-3",
-            "meeting-123456-6",
-            "meeting-123456-9",
-            "meeting-123456-11",
-        ],
-    ))
-
-    test_results.append(run_test(
-        search_strategy,
-        "Binary Scoring",
-        test_query,
-        AGENT_ID,
-        limit=5,
-        scoring_method="binary",
-        expected_memory_ids=[
-            "meeting-123456-1",
-            "meeting-123456-3",
-            "meeting-123456-6",
-            "meeting-123456-9",
-            "meeting-123456-11",
-        ],
-    ))
-
-    # Test 22: Testing scoring on a document with repeated terms
-    test_with_repetition_query = "security"  # Look for security-related memories
-    log_print(
-        logger,
-        f"\nComparing scoring methods for query with potential term repetition: '{test_with_repetition_query}'",
-    )
-
-    # Test with default length ratio scoring
-    test_results.append(run_test(
-        search_strategy,
-        "Default Scoring with Term Repetition",
-        test_with_repetition_query,
-        AGENT_ID,
-        limit=5,
-        expected_memory_ids=[
-            "note-123456-4",
-            "note-123456-8",
-            "meeting-123456-11",
-        ],
-    ))
-
-    # Test with term frequency scoring - should favor documents with more occurrences
-    test_results.append(run_test(
-        search_strategy,
-        "Term Frequency with Term Repetition",
-        test_with_repetition_query,
-        AGENT_ID,
-        limit=5,
-        scoring_method="term_frequency",
-        expected_memory_ids=[
-            "note-123456-4",
-            "note-123456-8",
-            "meeting-123456-11",
-        ],
-    ))
-
-    # Test with BM25 scoring - balances term frequency and document length
-    test_results.append(run_test(
-        search_strategy,
-        "BM25 with Term Repetition",
-        test_with_repetition_query,
-        AGENT_ID,
-        limit=5,
-        scoring_method="bm25",
-        expected_memory_ids=[
-            "note-123456-4",
-            "note-123456-8",
-            "meeting-123456-11",
-        ],
-    ))
-
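-    # scoring_method can be supplied per call (as in the tests above) or fixed
-    # when the strategy is constructed; when omitted from the call, the
-    # instance-level default presumably applies, which the next test exercises
-    # with dedicated strategy instances.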
"meeting-123456-3", - "meeting-123456-6", - "meeting-123456-9", - "meeting-123456-11", - ], - )) - - test_results.append(run_test( - search_strategy, - "BM25 Scoring", - test_query, - AGENT_ID, - limit=5, - scoring_method="bm25", - expected_memory_ids=[ - "meeting-123456-1", - "meeting-123456-3", - "meeting-123456-6", - "meeting-123456-9", - "meeting-123456-11", - ], - )) - - test_results.append(run_test( - search_strategy, - "Binary Scoring", - test_query, - AGENT_ID, - limit=5, - scoring_method="binary", - expected_memory_ids=[ - "meeting-123456-1", - "meeting-123456-3", - "meeting-123456-6", - "meeting-123456-9", - "meeting-123456-11", - ], - )) - - # Test 22: Testing scoring on a document with repeated terms - test_with_repetition_query = "security" # Look for security-related memories - log_print( - logger, - f"\nComparing scoring methods for query with potential term repetition: '{test_with_repetition_query}'", - ) - - # Test with default length ratio scoring - test_results.append(run_test( - search_strategy, - "Default Scoring with Term Repetition", - test_with_repetition_query, - AGENT_ID, - limit=5, - expected_memory_ids=[ - "note-123456-4", - "note-123456-8", - "meeting-123456-11", - ], - )) - - # Test with term frequency scoring - should favor documents with more occurrences - test_results.append(run_test( - search_strategy, - "Term Frequency with Term Repetition", - test_with_repetition_query, - AGENT_ID, - limit=5, - scoring_method="term_frequency", - expected_memory_ids=[ - "note-123456-4", - "note-123456-8", - "meeting-123456-11", - ], - )) - - # Test with BM25 scoring - balances term frequency and document length - test_results.append(run_test( - search_strategy, - "BM25 with Term Repetition", - test_with_repetition_query, - AGENT_ID, - limit=5, - scoring_method="bm25", - expected_memory_ids=[ - "note-123456-4", - "note-123456-8", - "meeting-123456-11", - ], - )) - - # Test 23: Testing with a specialized search strategy for each method - log_print( - logger, - "\nTesting with dedicated search strategy instances for each scoring method", - ) - - # Create specialized strategy instances - term_freq_strategy = AttributeSearchStrategy( - agent.stm_store, - agent.im_store, - agent.ltm_store, - scoring_method="term_frequency", - ) - - bm25_strategy = AttributeSearchStrategy( - agent.stm_store, - agent.im_store, - agent.ltm_store, - scoring_method="bm25", - ) - - # Run test with specialized strategies - test_results.append(run_test( - term_freq_strategy, - "Using Term Frequency Strategy Instance", - "project", - AGENT_ID, - limit=5, - expected_memory_ids=[ - "meeting-123456-1", - "contact-123456-5", - ], - )) - - test_results.append(run_test( - bm25_strategy, - "Using BM25 Strategy Instance", - "project", - AGENT_ID, - limit=5, - expected_memory_ids=[ - "meeting-123456-1", - "contact-123456-5", - ], - )) - - # Test 24: Testing with a long document vs short document comparison - # Change from "detailed" (no matches) to "authentication system" (appears in memories of different lengths) - long_doc_query = "authentication system" - log_print( - logger, - f"\nComparing scoring methods for long vs short document query: '{long_doc_query}'", - ) - - # Compare each scoring method - test_results.append(run_test( - search_strategy, - "Length Ratio for Long Documents", - long_doc_query, - AGENT_ID, - limit=5, - scoring_method="length_ratio", - expected_memory_ids=[ - "meeting-123456-3", - "task-123456-7", - "task-123456-10", - ], - )) - - test_results.append(run_test( - search_strategy, - 
"Term Frequency for Long Documents", - long_doc_query, - AGENT_ID, - limit=5, - scoring_method="term_frequency", - expected_memory_ids=[ - "meeting-123456-3", - "task-123456-7", - "task-123456-10", - ], - )) - - test_results.append(run_test( - search_strategy, - "BM25 for Long Documents", - long_doc_query, - AGENT_ID, - limit=5, - scoring_method="bm25", - expected_memory_ids=[ - "meeting-123456-3", - "task-123456-7", - "task-123456-10", - ], - )) - - # Test 25: Testing with a query that matches varying document length and context - varying_length_query = "documentation" - log_print( - logger, - f"\nComparing scoring methods for documents of varying lengths: '{varying_length_query}'", - ) - - test_results.append(run_test( - search_strategy, - "Length Ratio for Documentation Query", - varying_length_query, - AGENT_ID, - limit=5, - scoring_method="length_ratio", - expected_memory_ids=[ - "task-123456-2", - "task-123456-7", - ], - )) - - test_results.append(run_test( - search_strategy, - "Term Frequency for Documentation Query", - varying_length_query, - AGENT_ID, - limit=5, - scoring_method="term_frequency", - expected_memory_ids=[ - "task-123456-2", - "task-123456-7", - ], - )) - - test_results.append(run_test( - search_strategy, - "BM25 for Documentation Query", - varying_length_query, - AGENT_ID, - limit=5, - scoring_method="bm25", - expected_memory_ids=[ - "task-123456-2", - "task-123456-7", - ], - )) - - # Display validation summary - log_print(logger, "\n\n=== VALIDATION SUMMARY ===") - log_print(logger, "-" * 80) - log_print(logger, "| {:<40} | {:<20} | {:<20} |".format("Test Name", "Status", "Validation Status")) - log_print(logger, "-" * 80) - - for result in test_results: - status = "PASS" if result["passed"] else "FAIL" - validation_status = status if result["has_validation"] else "N/A" - log_print(logger, "| {:<40} | {:<20} | {:<20} |".format( - result["test_name"][:40], - status, - validation_status - )) - - log_print(logger, "-" * 80) - - # Calculate overall statistics - validated_tests = [t for t in test_results if t["has_validation"]] - passed_tests = [t for t in validated_tests if t["passed"]] - - if validated_tests: - success_rate = len(passed_tests) / len(validated_tests) * 100 - log_print(logger, f"\nValidated Tests: {len(validated_tests)}") - log_print(logger, f"Passed Tests: {len(passed_tests)}") - log_print(logger, f"Failed Tests: {len(validated_tests) - len(passed_tests)}") - log_print(logger, f"Success Rate: {success_rate:.2f}%") - else: - log_print(logger, "\nNo tests with validation criteria were run.") - - -if __name__ == "__main__": - # Setup logging - logger = setup_logging("validate_attribute_search") - log_print(logger, "Starting Attribute Search Strategy Validation") - - validate_attribute_search() - - log_print(logger, "Validation Complete")