Skip to content

Commit 8ec90ed

Browse files
authored
Merge pull request #106 from Dooders/dev
Dev
2 parents 8a4574f + 964b521 commit 8ec90ed

File tree

2 files changed

+104
-24
lines changed

2 files changed

+104
-24
lines changed

memory/search/strategies/step.py

Lines changed: 84 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,73 @@ def _get_memory_step(self, memory: Dict[str, Any]) -> Optional[int]:
314314

315315
return None
316316

317+
def _get_nested_value(self, obj: Dict[str, Any], path: str) -> Any:
318+
"""Get value from a nested dictionary using a dot-separated path.
319+
320+
Args:
321+
obj: Dictionary to extract value from
322+
path: Dot-separated path to the value (e.g., "content.metadata.importance")
323+
324+
Returns:
325+
Value found at the path, or None if not found
326+
"""
327+
if not path or not isinstance(obj, dict):
328+
return None
329+
330+
parts = path.split(".")
331+
current = obj
332+
333+
for part in parts:
334+
if not isinstance(current, dict) or part not in current:
335+
return None
336+
current = current[part]
337+
338+
return current
339+
340+
def _metadata_matches(
341+
self, memory: Dict[str, Any], metadata_filter: Dict[str, Any]
342+
) -> bool:
343+
"""Check if a memory matches the metadata filter.
344+
345+
Args:
346+
memory: Memory object to check
347+
metadata_filter: Metadata filter to apply
348+
349+
Returns:
350+
True if memory matches filter, False otherwise
351+
"""
352+
for key, filter_value in metadata_filter.items():
353+
memory_value = self._get_nested_value(memory, key)
354+
if memory_value is None:
355+
# Try to get from non-nested metadata
356+
memory_value = memory.get("metadata", {}).get(key)
357+
358+
# No matching value found
359+
if memory_value is None:
360+
logger.debug(f"No value found for key {key} in memory")
361+
return False
362+
363+
# Handle list/array values - check if filter_value is a subset of memory_value
364+
if isinstance(filter_value, list):
365+
if not isinstance(memory_value, list):
366+
# Convert to list if memory_value is a single value
367+
memory_value = [memory_value]
368+
369+
# Check if all items in filter_value are in memory_value
370+
if not all(item in memory_value for item in filter_value):
371+
logger.debug(
372+
f"List match failed for {key}: filter={filter_value}, memory={memory_value}"
373+
)
374+
return False
375+
# For other types, do a direct comparison
376+
elif memory_value != filter_value:
377+
logger.debug(
378+
f"Value mismatch for {key}: filter={filter_value}, memory={memory_value}"
379+
)
380+
return False
381+
382+
return True
383+
317384
def _filter_memories(
318385
self,
319386
memories: List[Dict[str, Any]],
@@ -364,11 +431,8 @@ def _filter_memories(
364431
continue
365432

366433
# Apply metadata filter
367-
if metadata_filter:
368-
memory_metadata = memory.get("metadata", {})
369-
if not all(
370-
memory_metadata.get(k) == v for k, v in metadata_filter.items()
371-
):
434+
if metadata_filter and len(metadata_filter) > 0:
435+
if not self._metadata_matches(memory, metadata_filter):
372436
continue
373437

374438
filtered.append(memory)
@@ -412,13 +476,24 @@ def _score_memories(
412476
# Normalize step distance (closer = higher score)
413477
# Adjust max_distance based on your simulation scale
414478
max_distance = step_params.get("step_range", 100) * 2
415-
normalized_distance = min(step_distance / max_distance, 1.0)
479+
# Avoid division by zero
480+
if max_distance > 0:
481+
normalized_distance = min(step_distance / max_distance, 1.0)
482+
else:
483+
normalized_distance = 1.0 if step_distance > 0 else 0.0
416484

417-
# Higher score for closer steps
485+
# Higher score for closer steps (1.0 for exact match, decreasing as distance increases)
418486
step_score = 1.0 - normalized_distance
419487

420-
# Apply step weight
421-
step_score = step_score * step_weight
488+
# Apply step weight (higher weight emphasizes proximity more)
489+
step_score = pow(step_score, 1.0 / max(step_weight, 0.001))
490+
491+
# Log the scoring calculations for debugging
492+
logger.debug(
493+
f"Memory {memory.get('id', 'unknown')}: step={memory_step}, "
494+
f"ref={reference_step}, distance={step_distance}, "
495+
f"normalized_distance={normalized_distance}, score={step_score}"
496+
)
422497

423498
# Create a copy of the memory to avoid modifying the original
424499
memory_copy = memory.copy()

validation/search/step/step_test_suite.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self, logger=None):
3939
"test-agent-step-search-im-1": "c3d4e5f6g7h8i9j0k1l2",
4040
"test-agent-step-search-im-2": "d4e5f6g7h8i9j0k1l2m3",
4141
"test-agent-step-search-ltm-1": "e5f6g7h8i9j0k1l2m3n4",
42-
"test-agent-step-search-ltm-2": "f6g7h8i9j0k1l2m3n4o5"
42+
"test-agent-step-search-ltm-2": "f6g7h8i9j0k1l2m3n4o5",
4343
}
4444

4545
# Initialize base class
@@ -60,8 +60,9 @@ def run_basic_tests(self) -> None:
6060
{"start_step": 100, "end_step": 200},
6161
expected_memory_ids=[
6262
"test-agent-step-search-stm-1",
63-
"test-agent-step-search-stm-2"
63+
"test-agent-step-search-stm-2",
6464
],
65+
tier="stm",
6566
memory_checksum_map=self.memory_checksum_map,
6667
)
6768

@@ -72,7 +73,7 @@ def run_basic_tests(self) -> None:
7273
expected_memory_ids=[
7374
"test-agent-step-search-im-1",
7475
"test-agent-step-search-stm-2",
75-
"test-agent-step-search-im-2"
76+
"test-agent-step-search-im-2",
7677
],
7778
step_range=50,
7879
memory_checksum_map=self.memory_checksum_map,
@@ -84,7 +85,7 @@ def run_basic_tests(self) -> None:
8485
{"start_step": 200, "end_step": 300},
8586
expected_memory_ids=[
8687
"test-agent-step-search-im-1",
87-
"test-agent-step-search-im-2"
88+
"test-agent-step-search-im-2",
8889
],
8990
tier="im",
9091
memory_checksum_map=self.memory_checksum_map,
@@ -98,7 +99,7 @@ def run_basic_tests(self) -> None:
9899
"test-agent-step-search-stm-1",
99100
"test-agent-step-search-stm-2",
100101
"test-agent-step-search-ltm-1",
101-
"test-agent-step-search-ltm-2"
102+
"test-agent-step-search-ltm-2",
102103
],
103104
metadata_filter={"content.metadata.importance": "high"},
104105
memory_checksum_map=self.memory_checksum_map,
@@ -111,7 +112,7 @@ def run_basic_tests(self) -> None:
111112
expected_memory_ids=[
112113
"test-agent-step-search-im-1",
113114
"test-agent-step-search-stm-2",
114-
"test-agent-step-search-im-2"
115+
"test-agent-step-search-im-2",
115116
],
116117
step_range=50,
117118
step_weight=2.0,
@@ -127,7 +128,7 @@ def run_advanced_tests(self) -> None:
127128
expected_memory_ids=[
128129
"test-agent-step-search-stm-2",
129130
"test-agent-step-search-im-1",
130-
"test-agent-step-search-im-2"
131+
"test-agent-step-search-im-2",
131132
],
132133
memory_checksum_map=self.memory_checksum_map,
133134
)
@@ -139,7 +140,7 @@ def run_advanced_tests(self) -> None:
139140
expected_memory_ids=[
140141
"test-agent-step-search-im-2",
141142
"test-agent-step-search-im-1",
142-
"test-agent-step-search-stm-2"
143+
"test-agent-step-search-stm-2",
143144
],
144145
step_range=50,
145146
memory_checksum_map=self.memory_checksum_map,
@@ -151,9 +152,12 @@ def run_advanced_tests(self) -> None:
151152
{"start_step": 200, "end_step": 350},
152153
expected_memory_ids=[
153154
"test-agent-step-search-ltm-1",
154-
"test-agent-step-search-ltm-2"
155+
"test-agent-step-search-ltm-2",
155156
],
156-
metadata_filter={"content.metadata.type": "system", "content.metadata.importance": "high"},
157+
metadata_filter={
158+
"content.metadata.type": "state",
159+
"content.metadata.importance": "high",
160+
},
157161
memory_checksum_map=self.memory_checksum_map,
158162
)
159163

@@ -164,7 +168,7 @@ def run_advanced_tests(self) -> None:
164168
expected_memory_ids=[
165169
"test-agent-step-search-stm-1",
166170
"test-agent-step-search-stm-2",
167-
"test-agent-step-search-ltm-1"
171+
"test-agent-step-search-ltm-1",
168172
],
169173
metadata_filter={"content.metadata.importance": "high"},
170174
memory_checksum_map=self.memory_checksum_map,
@@ -176,7 +180,7 @@ def run_advanced_tests(self) -> None:
176180
{"start_step": 200, "end_step": 300},
177181
expected_memory_ids=[
178182
"test-agent-step-search-im-1",
179-
"test-agent-step-search-im-2"
183+
"test-agent-step-search-im-2",
180184
],
181185
metadata_filter={"content.metadata.tags": ["database", "api"]},
182186
memory_checksum_map=self.memory_checksum_map,
@@ -218,7 +222,7 @@ def run_edge_case_tests(self) -> None:
218222
"test-agent-step-search-im-1",
219223
"test-agent-step-search-im-2",
220224
"test-agent-step-search-ltm-1",
221-
"test-agent-step-search-ltm-2"
225+
"test-agent-step-search-ltm-2",
222226
],
223227
memory_checksum_map=self.memory_checksum_map,
224228
)
@@ -264,7 +268,8 @@ def run_edge_case_tests(self) -> None:
264268
{"start_step": 100, "end_step": 200},
265269
expected_memory_ids=[
266270
"test-agent-step-search-stm-1",
267-
"test-agent-step-search-stm-2"
271+
"test-agent-step-search-stm-2",
272+
"test-agent-step-search-im-1",
268273
],
269274
metadata_filter={},
270275
memory_checksum_map=self.memory_checksum_map,
@@ -286,4 +291,4 @@ def main():
286291

287292

288293
if __name__ == "__main__":
289-
main()
294+
main()

0 commit comments

Comments
 (0)