Skip to content

Commit 3b07213

Browse files
fix: Corrected payload recovery and saving on Chroma.
1 parent 60f5ad2 commit 3b07213

File tree

1 file changed

+22
-13
lines changed

1 file changed

+22
-13
lines changed

src/memos/vec_dbs/chroma.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import base64
22
import json
3+
34
from typing import Any
45

56
from memos.configs.vec_db import ChromaVecDBConfig
@@ -8,6 +9,7 @@
89
from memos.vec_dbs.base import BaseVecDB
910
from memos.vec_dbs.item import VecDBItem
1011

12+
1113
logger = get_logger(__name__)
1214

1315

@@ -50,7 +52,7 @@ def serialize_metadata(d: dict) -> dict:
5052
"""Return a copy of the dict with list/dict values converted to JSON strings."""
5153
result = {}
5254
for key, value in d.items():
53-
if isinstance(value, (dict, list)):
55+
if isinstance(value, dict | list):
5456
result[key] = json.dumps(value)
5557
else:
5658
result[key] = value
@@ -64,7 +66,7 @@ def deserialize_metadata(d: dict) -> dict:
6466
if isinstance(value, str):
6567
try:
6668
parsed = json.loads(value)
67-
if isinstance(parsed, (dict, list)):
69+
if isinstance(parsed, dict | list):
6870
result[key] = parsed
6971
else:
7072
result[key] = value
@@ -111,7 +113,7 @@ def collection_exists(self, name: str) -> bool:
111113
return False
112114

113115
def search(
114-
self, query_vector: list[float], top_k: int, filter: dict[str, Any] | None = None
116+
self, query_vector: list[float], top_k: int, filter: dict[str, Any] | None = None
115117
) -> list[VecDBItem]:
116118
"""
117119
Search for similar items in the database.
@@ -133,12 +135,14 @@ def search(
133135
logger.info(f"ChromaDb search completed with {len(response)} results.")
134136
return [
135137
VecDBItem(
136-
id=response["ids"][idx],
137-
vector=response["embeddings"][idx] if response["embeddings"] else None,
138-
payload=self.deserialize_metadata(response["metadatas"][idx]) if response["metadatas"] else None,
139-
score=response["distances"][idx] if response["distances"] else None,
138+
id=response["ids"][0][idx],
139+
vector=response["embeddings"][0][idx] if response["embeddings"] else None,
140+
payload=self.deserialize_metadata(response["metadatas"][0][idx])
141+
if response["metadatas"]
142+
else None,
143+
score=response["distances"][0][idx] if response["distances"][0] else None,
140144
)
141-
for idx, _ in enumerate(response["ids"])
145+
for idx, _ in enumerate(response["ids"][0])
142146
]
143147

144148
def get_by_id(self, id: str) -> VecDBItem | None:
@@ -165,7 +169,9 @@ def get_by_ids(self, ids: list[str]) -> list[VecDBItem]:
165169
return [
166170
VecDBItem(
167171
id=response["ids"][idx],
168-
vector=self.deserialize_metadata(response["embeddings"][idx]) if response["embeddings"] else None,
172+
vector=self.deserialize_metadata(response["embeddings"][idx])
173+
if response["embeddings"]
174+
else None,
169175
payload=response["metadatas"][idx] if response["metadatas"] else None,
170176
)
171177
for idx, _ in enumerate(response["ids"])
@@ -192,7 +198,9 @@ def get_by_filter(self, filter: dict[str, Any], limit: int = 100) -> list[VecDBI
192198
return [
193199
VecDBItem(
194200
id=response["ids"][idx],
195-
vector=self.deserialize_metadata(response["embeddings"][idx]) if response["embeddings"] else None,
201+
vector=self.deserialize_metadata(response["embeddings"][idx])
202+
if response["embeddings"]
203+
else None,
196204
payload=response["metadatas"][idx] if response["metadatas"] else None,
197205
)
198206
for idx, _ in enumerate(response["ids"])
@@ -227,7 +235,7 @@ def add(self, data: list[VecDBItem | dict[str, Any]]) -> None:
227235
item = VecDBItem.from_dict(item)
228236
ids.append(str(item.id))
229237
embeddings.append(item.vector)
230-
metadatas.append(self.serialize_metadata(item.payload.get("metadata")))
238+
metadatas.append(self.serialize_metadata(item.payload))
231239
documents.append(item.payload.get("memory"))
232240

233241
self.get_collection().upsert(
@@ -250,7 +258,9 @@ def update(self, id: str, data: VecDBItem | dict[str, Any]) -> None:
250258
)
251259
else:
252260
# For payload-only updates
253-
self.get_collection().upsert(ids=[id], metadatas=[self.serialize_metadata(data.payload.get("metadata"))])
261+
self.get_collection().upsert(
262+
ids=[id], metadatas=[self.serialize_metadata(data.payload.get("metadata"))]
263+
)
254264

255265
def upsert(self, data: list[VecDBItem | dict[str, Any]]) -> None:
256266
"""
@@ -274,4 +284,3 @@ def ensure_payload_indexes(self, fields: list[str]) -> None:
274284
fields (list[str]): List of field names to index (as keyword).
275285
"""
276286
# chromadb does not implement create index
277-
pass

0 commit comments

Comments
 (0)