Skip to content

Commit 964b521

Browse files
committed
Refactor StepBasedSearchStrategy and update test suite for improved readability and consistency
This commit refactors the StepBasedSearchStrategy class by enhancing the formatting of method signatures and comments for better readability. It also updates the test suite to ensure consistent trailing commas in dictionary entries, improving code style and maintainability. These changes aim to streamline the codebase and enhance the clarity of both the strategy implementation and its associated tests.
1 parent c58a937 commit 964b521

File tree

2 files changed

+36
-30
lines changed

2 files changed

+36
-30
lines changed

memory/search/strategies/step.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -316,34 +316,36 @@ def _get_memory_step(self, memory: Dict[str, Any]) -> Optional[int]:
316316

317317
def _get_nested_value(self, obj: Dict[str, Any], path: str) -> Any:
318318
"""Get value from a nested dictionary using a dot-separated path.
319-
319+
320320
Args:
321321
obj: Dictionary to extract value from
322322
path: Dot-separated path to the value (e.g., "content.metadata.importance")
323-
323+
324324
Returns:
325325
Value found at the path, or None if not found
326326
"""
327327
if not path or not isinstance(obj, dict):
328328
return None
329-
330-
parts = path.split('.')
329+
330+
parts = path.split(".")
331331
current = obj
332-
332+
333333
for part in parts:
334334
if not isinstance(current, dict) or part not in current:
335335
return None
336336
current = current[part]
337-
337+
338338
return current
339339

340-
def _metadata_matches(self, memory: Dict[str, Any], metadata_filter: Dict[str, Any]) -> bool:
340+
def _metadata_matches(
341+
self, memory: Dict[str, Any], metadata_filter: Dict[str, Any]
342+
) -> bool:
341343
"""Check if a memory matches the metadata filter.
342-
344+
343345
Args:
344346
memory: Memory object to check
345347
metadata_filter: Metadata filter to apply
346-
348+
347349
Returns:
348350
True if memory matches filter, False otherwise
349351
"""
@@ -352,27 +354,31 @@ def _metadata_matches(self, memory: Dict[str, Any], metadata_filter: Dict[str, A
352354
if memory_value is None:
353355
# Try to get from non-nested metadata
354356
memory_value = memory.get("metadata", {}).get(key)
355-
357+
356358
# No matching value found
357359
if memory_value is None:
358360
logger.debug(f"No value found for key {key} in memory")
359361
return False
360-
362+
361363
# Handle list/array values - check if filter_value is a subset of memory_value
362364
if isinstance(filter_value, list):
363365
if not isinstance(memory_value, list):
364366
# Convert to list if memory_value is a single value
365367
memory_value = [memory_value]
366-
368+
367369
# Check if all items in filter_value are in memory_value
368370
if not all(item in memory_value for item in filter_value):
369-
logger.debug(f"List match failed for {key}: filter={filter_value}, memory={memory_value}")
371+
logger.debug(
372+
f"List match failed for {key}: filter={filter_value}, memory={memory_value}"
373+
)
370374
return False
371375
# For other types, do a direct comparison
372376
elif memory_value != filter_value:
373-
logger.debug(f"Value mismatch for {key}: filter={filter_value}, memory={memory_value}")
377+
logger.debug(
378+
f"Value mismatch for {key}: filter={filter_value}, memory={memory_value}"
379+
)
374380
return False
375-
381+
376382
return True
377383

378384
def _filter_memories(

validation/search/step/step_test_suite.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self, logger=None):
3939
"test-agent-step-search-im-1": "c3d4e5f6g7h8i9j0k1l2",
4040
"test-agent-step-search-im-2": "d4e5f6g7h8i9j0k1l2m3",
4141
"test-agent-step-search-ltm-1": "e5f6g7h8i9j0k1l2m3n4",
42-
"test-agent-step-search-ltm-2": "f6g7h8i9j0k1l2m3n4o5"
42+
"test-agent-step-search-ltm-2": "f6g7h8i9j0k1l2m3n4o5",
4343
}
4444

4545
# Initialize base class
@@ -60,7 +60,7 @@ def run_basic_tests(self) -> None:
6060
{"start_step": 100, "end_step": 200},
6161
expected_memory_ids=[
6262
"test-agent-step-search-stm-1",
63-
"test-agent-step-search-stm-2"
63+
"test-agent-step-search-stm-2",
6464
],
6565
tier="stm",
6666
memory_checksum_map=self.memory_checksum_map,
@@ -73,7 +73,7 @@ def run_basic_tests(self) -> None:
7373
expected_memory_ids=[
7474
"test-agent-step-search-im-1",
7575
"test-agent-step-search-stm-2",
76-
"test-agent-step-search-im-2"
76+
"test-agent-step-search-im-2",
7777
],
7878
step_range=50,
7979
memory_checksum_map=self.memory_checksum_map,
@@ -85,7 +85,7 @@ def run_basic_tests(self) -> None:
8585
{"start_step": 200, "end_step": 300},
8686
expected_memory_ids=[
8787
"test-agent-step-search-im-1",
88-
"test-agent-step-search-im-2"
88+
"test-agent-step-search-im-2",
8989
],
9090
tier="im",
9191
memory_checksum_map=self.memory_checksum_map,
@@ -99,7 +99,7 @@ def run_basic_tests(self) -> None:
9999
"test-agent-step-search-stm-1",
100100
"test-agent-step-search-stm-2",
101101
"test-agent-step-search-ltm-1",
102-
"test-agent-step-search-ltm-2"
102+
"test-agent-step-search-ltm-2",
103103
],
104104
metadata_filter={"content.metadata.importance": "high"},
105105
memory_checksum_map=self.memory_checksum_map,
@@ -112,7 +112,7 @@ def run_basic_tests(self) -> None:
112112
expected_memory_ids=[
113113
"test-agent-step-search-im-1",
114114
"test-agent-step-search-stm-2",
115-
"test-agent-step-search-im-2"
115+
"test-agent-step-search-im-2",
116116
],
117117
step_range=50,
118118
step_weight=2.0,
@@ -128,7 +128,7 @@ def run_advanced_tests(self) -> None:
128128
expected_memory_ids=[
129129
"test-agent-step-search-stm-2",
130130
"test-agent-step-search-im-1",
131-
"test-agent-step-search-im-2"
131+
"test-agent-step-search-im-2",
132132
],
133133
memory_checksum_map=self.memory_checksum_map,
134134
)
@@ -140,7 +140,7 @@ def run_advanced_tests(self) -> None:
140140
expected_memory_ids=[
141141
"test-agent-step-search-im-2",
142142
"test-agent-step-search-im-1",
143-
"test-agent-step-search-stm-2"
143+
"test-agent-step-search-stm-2",
144144
],
145145
step_range=50,
146146
memory_checksum_map=self.memory_checksum_map,
@@ -152,11 +152,11 @@ def run_advanced_tests(self) -> None:
152152
{"start_step": 200, "end_step": 350},
153153
expected_memory_ids=[
154154
"test-agent-step-search-ltm-1",
155-
"test-agent-step-search-ltm-2"
155+
"test-agent-step-search-ltm-2",
156156
],
157157
metadata_filter={
158158
"content.metadata.type": "state",
159-
"content.metadata.importance": "high"
159+
"content.metadata.importance": "high",
160160
},
161161
memory_checksum_map=self.memory_checksum_map,
162162
)
@@ -168,7 +168,7 @@ def run_advanced_tests(self) -> None:
168168
expected_memory_ids=[
169169
"test-agent-step-search-stm-1",
170170
"test-agent-step-search-stm-2",
171-
"test-agent-step-search-ltm-1"
171+
"test-agent-step-search-ltm-1",
172172
],
173173
metadata_filter={"content.metadata.importance": "high"},
174174
memory_checksum_map=self.memory_checksum_map,
@@ -180,7 +180,7 @@ def run_advanced_tests(self) -> None:
180180
{"start_step": 200, "end_step": 300},
181181
expected_memory_ids=[
182182
"test-agent-step-search-im-1",
183-
"test-agent-step-search-im-2"
183+
"test-agent-step-search-im-2",
184184
],
185185
metadata_filter={"content.metadata.tags": ["database", "api"]},
186186
memory_checksum_map=self.memory_checksum_map,
@@ -222,7 +222,7 @@ def run_edge_case_tests(self) -> None:
222222
"test-agent-step-search-im-1",
223223
"test-agent-step-search-im-2",
224224
"test-agent-step-search-ltm-1",
225-
"test-agent-step-search-ltm-2"
225+
"test-agent-step-search-ltm-2",
226226
],
227227
memory_checksum_map=self.memory_checksum_map,
228228
)
@@ -269,7 +269,7 @@ def run_edge_case_tests(self) -> None:
269269
expected_memory_ids=[
270270
"test-agent-step-search-stm-1",
271271
"test-agent-step-search-stm-2",
272-
"test-agent-step-search-im-1"
272+
"test-agent-step-search-im-1",
273273
],
274274
metadata_filter={},
275275
memory_checksum_map=self.memory_checksum_map,
@@ -291,4 +291,4 @@ def main():
291291

292292

293293
if __name__ == "__main__":
294-
main()
294+
main()

0 commit comments

Comments
 (0)