Skip to content

Commit 68356cd

Browse files
committed
ruff format
1 parent 9d6ebc5 commit 68356cd

File tree

3 files changed

+357
-266
lines changed

3 files changed

+357
-266
lines changed

src/memos/graph_dbs/neo4j.py

Lines changed: 59 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,16 @@ def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
4848
def _flatten_info_fields(metadata: dict[str, Any]) -> dict[str, Any]:
4949
"""
5050
Flatten the 'info' field in metadata to the top level.
51-
51+
5252
If metadata contains an 'info' field that is a dictionary, all its key-value pairs
5353
will be moved to the top level of metadata, and the 'info' field will be removed.
54-
54+
5555
Args:
5656
metadata: Dictionary that may contain an 'info' field
57-
57+
5858
Returns:
5959
Dictionary with 'info' fields flattened to top level
60-
60+
6161
Example:
6262
Input: {"user_id": "xxx", "info": {"A": "value1", "B": "value2"}}
6363
Output: {"user_id": "xxx", "A": "value1", "B": "value2"}
@@ -195,7 +195,7 @@ def remove_oldest_memory(
195195
session.run(query)
196196

197197
def add_node(
198-
self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None
198+
self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None
199199
) -> None:
200200
logger.info(f"[add_node] metadata: {metadata},info: {metadata.get('info')}")
201201
print(f"[add_node] metadata: {metadata},info: {metadata.get('info')}")
@@ -206,7 +206,7 @@ def add_node(
206206

207207
# Safely process metadata
208208
metadata = _prepare_node_metadata(metadata)
209-
209+
210210
# Flatten info fields to top level (for Neo4j flat structure)
211211
metadata = _flatten_info_fields(metadata)
212212

@@ -226,7 +226,6 @@ def add_node(
226226
if metadata["sources"]:
227227
for idx in range(len(metadata["sources"])):
228228
metadata["sources"][idx] = json.dumps(metadata["sources"][idx])
229-
print("111add_node id:",id)
230229

231230
with self.driver.session(database=self.db_name) as session:
232231
session.run(
@@ -593,7 +592,7 @@ def get_children_with_embeddings(
593592
]
594593

595594
def get_path(
596-
self, source_id: str, target_id: str, max_depth: int = 3, user_name: str | None = None
595+
self, source_id: str, target_id: str, max_depth: int = 3, user_name: str | None = None
597596
) -> list[str]:
598597
"""
599598
Get the path of nodes from source to target within a limited depth.
@@ -687,17 +686,17 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
687686

688687
# Search / recall operations
689688
def search_by_embedding(
690-
self,
691-
vector: list[float],
692-
top_k: int = 5,
693-
scope: str | None = None,
694-
status: str | None = None,
695-
threshold: float | None = None,
696-
search_filter: dict | None = None,
697-
user_name: str | None = None,
698-
filter: dict | None = None,
699-
knowledgebase_ids: list[str] | None = None,
700-
**kwargs,
689+
self,
690+
vector: list[float],
691+
top_k: int = 5,
692+
scope: str | None = None,
693+
status: str | None = None,
694+
threshold: float | None = None,
695+
search_filter: dict | None = None,
696+
user_name: str | None = None,
697+
filter: dict | None = None,
698+
knowledgebase_ids: list[str] | None = None,
699+
**kwargs,
701700
) -> list[dict]:
702701
"""
703702
Retrieve node IDs based on vector similarity.
@@ -808,7 +807,11 @@ def search_by_embedding(
808807
return records
809808

810809
def get_by_metadata(
811-
self, filters: list[dict[str, Any]], user_name: str | None = None, filter: dict | None = None, knowledgebase_ids: list[str] | None = None
810+
self,
811+
filters: list[dict[str, Any]],
812+
user_name: str | None = None,
813+
filter: dict | None = None,
814+
knowledgebase_ids: list[str] | None = None,
812815
) -> list[str]:
813816
"""
814817
TODO:
@@ -834,8 +837,12 @@ def get_by_metadata(
834837
- Supports structured querying such as tag/category/importance/time filtering.
835838
- Can be used for faceted recall or prefiltering before embedding rerank.
836839
"""
837-
logger.info(f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids}")
838-
print(f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids}")
840+
logger.info(
841+
f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids}"
842+
)
843+
print(
844+
f"[get_by_metadata] filters: {filters},user_name: {user_name},filter: {filter},knowledgebase_ids: {knowledgebase_ids}"
845+
)
839846
user_name = user_name if user_name else self.config.user_name
840847
where_clauses = []
841848
params = {}
@@ -1096,7 +1103,13 @@ def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> No
10961103
target_id=edge["target"],
10971104
)
10981105

1099-
def get_all_memory_items(self, scope: str, filter: dict | None = None, knowledgebase_ids: list[str] | None = None, **kwargs) -> list[dict]:
1106+
def get_all_memory_items(
1107+
self,
1108+
scope: str,
1109+
filter: dict | None = None,
1110+
knowledgebase_ids: list[str] | None = None,
1111+
**kwargs,
1112+
) -> list[dict]:
11001113
"""
11011114
Retrieve all memory items of a specific memory_type.
11021115
@@ -1109,8 +1122,12 @@ def get_all_memory_items(self, scope: str, filter: dict | None = None, knowledge
11091122
Returns:
11101123
list[dict]: Full list of memory items under this scope.
11111124
"""
1112-
logger.info(f"[get_all_memory_items] scope: {scope},filter: {filter},knowledgebase_ids: {knowledgebase_ids}")
1113-
print(f"[get_all_memory_items] scope: {scope},filter: {filter},knowledgebase_ids: {knowledgebase_ids}")
1125+
logger.info(
1126+
f"[get_all_memory_items] scope: {scope},filter: {filter},knowledgebase_ids: {knowledgebase_ids}"
1127+
)
1128+
print(
1129+
f"[get_all_memory_items] scope: {scope},filter: {filter},knowledgebase_ids: {knowledgebase_ids}"
1130+
)
11141131

11151132
user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name
11161133
if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}:
@@ -1316,11 +1333,11 @@ def _index_exists(self, index_name: str) -> bool:
13161333
return False
13171334

13181335
def _build_user_name_and_kb_ids_conditions_cypher(
1319-
self,
1320-
user_name: str | None,
1321-
knowledgebase_ids: list[str] | None,
1322-
default_user_name: str | None = None,
1323-
node_alias: str = "node",
1336+
self,
1337+
user_name: str | None,
1338+
knowledgebase_ids: list[str] | None,
1339+
default_user_name: str | None = None,
1340+
node_alias: str = "node",
13241341
) -> tuple[list[str], dict[str, Any]]:
13251342
"""
13261343
Build user_name and knowledgebase_ids conditions for Cypher queries.
@@ -1354,10 +1371,10 @@ def _build_user_name_and_kb_ids_conditions_cypher(
13541371
return user_name_conditions, params
13551372

13561373
def _build_filter_conditions_cypher(
1357-
self,
1358-
filter: dict | None,
1359-
param_counter_start: int = 0,
1360-
node_alias: str = "node",
1374+
self,
1375+
filter: dict | None,
1376+
param_counter_start: int = 0,
1377+
node_alias: str = "node",
13611378
) -> tuple[list[str], dict[str, Any]]:
13621379
"""
13631380
Build filter conditions for Cypher queries.
@@ -1396,12 +1413,7 @@ def build_filter_condition(condition_dict: dict, param_counter: list) -> tuple[s
13961413
for op, op_value in value.items():
13971414
if op in ("gt", "lt", "gte", "lte"):
13981415
# Map operator to Cypher operator
1399-
cypher_op_map = {
1400-
"gt": ">",
1401-
"lt": "<",
1402-
"gte": ">=",
1403-
"lte": "<="
1404-
}
1416+
cypher_op_map = {"gt": ">", "lt": "<", "gte": ">=", "lte": "<="}
14051417
cypher_op = cypher_op_map[op]
14061418

14071419
# All fields are stored as flat properties in Neo4j
@@ -1412,15 +1424,19 @@ def build_filter_condition(condition_dict: dict, param_counter: list) -> tuple[s
14121424
# Check if field is a date field (created_at, updated_at, etc.)
14131425
# Use datetime() function for date comparisons
14141426
if key in ("created_at", "updated_at") or key.endswith("_at"):
1415-
condition_parts.append(f"{node_alias}.{key} {cypher_op} datetime(${param_name})")
1427+
condition_parts.append(
1428+
f"{node_alias}.{key} {cypher_op} datetime(${param_name})"
1429+
)
14161430
else:
1417-
condition_parts.append(f"{node_alias}.{key} {cypher_op} ${param_name}")
1431+
condition_parts.append(
1432+
f"{node_alias}.{key} {cypher_op} ${param_name}"
1433+
)
14181434
elif op == "contains":
14191435
# Handle contains operator (for array fields like tags, sources)
14201436
param_name = f"filter_{key}_{op}_{param_counter[0]}"
14211437
param_counter[0] += 1
14221438
params[param_name] = op_value
1423-
1439+
14241440
# For array fields, check if element is in array
14251441
if key in ("tags", "sources"):
14261442
condition_parts.append(f"${param_name} IN {node_alias}.{key}")

0 commit comments

Comments
 (0)