11import base64
22import json
3+
34from typing import Any
45
56from memos .configs .vec_db import ChromaVecDBConfig
89from memos .vec_dbs .base import BaseVecDB
910from memos .vec_dbs .item import VecDBItem
1011
12+
1113logger = get_logger (__name__ )
1214
1315
@@ -50,7 +52,7 @@ def serialize_metadata(d: dict) -> dict:
5052 """Return a copy of the dict with list/dict values converted to JSON strings."""
5153 result = {}
5254 for key , value in d .items ():
53- if isinstance (value , ( dict , list ) ):
55+ if isinstance (value , dict | list ):
5456 result [key ] = json .dumps (value )
5557 else :
5658 result [key ] = value
@@ -64,7 +66,7 @@ def deserialize_metadata(d: dict) -> dict:
6466 if isinstance (value , str ):
6567 try :
6668 parsed = json .loads (value )
67- if isinstance (parsed , ( dict , list ) ):
69+ if isinstance (parsed , dict | list ):
6870 result [key ] = parsed
6971 else :
7072 result [key ] = value
@@ -111,7 +113,7 @@ def collection_exists(self, name: str) -> bool:
111113 return False
112114
113115 def search (
114- self , query_vector : list [float ], top_k : int , filter : dict [str , Any ] | None = None
116+ self , query_vector : list [float ], top_k : int , filter : dict [str , Any ] | None = None
115117 ) -> list [VecDBItem ]:
116118 """
117119 Search for similar items in the database.
@@ -133,12 +135,14 @@ def search(
133135 logger .info (f"ChromaDb search completed with { len (response )} results." )
134136 return [
135137 VecDBItem (
136- id = response ["ids" ][idx ],
137- vector = response ["embeddings" ][idx ] if response ["embeddings" ] else None ,
138- payload = self .deserialize_metadata (response ["metadatas" ][idx ]) if response ["metadatas" ] else None ,
139- score = response ["distances" ][idx ] if response ["distances" ] else None ,
138+ id = response ["ids" ][0 ][idx ],
139+ vector = response ["embeddings" ][0 ][idx ] if response ["embeddings" ] else None ,
140+ payload = self .deserialize_metadata (response ["metadatas" ][0 ][idx ])
141+ if response ["metadatas" ]
142+ else None ,
143+ score = response ["distances" ][0 ][idx ] if response ["distances" ][0 ] else None ,
140144 )
141- for idx , _ in enumerate (response ["ids" ])
145+ for idx , _ in enumerate (response ["ids" ][ 0 ] )
142146 ]
143147
144148 def get_by_id (self , id : str ) -> VecDBItem | None :
@@ -165,7 +169,9 @@ def get_by_ids(self, ids: list[str]) -> list[VecDBItem]:
165169 return [
166170 VecDBItem (
167171 id = response ["ids" ][idx ],
168- vector = self .deserialize_metadata (response ["embeddings" ][idx ]) if response ["embeddings" ] else None ,
172+ vector = self .deserialize_metadata (response ["embeddings" ][idx ])
173+ if response ["embeddings" ]
174+ else None ,
169175 payload = response ["metadatas" ][idx ] if response ["metadatas" ] else None ,
170176 )
171177 for idx , _ in enumerate (response ["ids" ])
@@ -192,7 +198,9 @@ def get_by_filter(self, filter: dict[str, Any], limit: int = 100) -> list[VecDBI
192198 return [
193199 VecDBItem (
194200 id = response ["ids" ][idx ],
195- vector = self .deserialize_metadata (response ["embeddings" ][idx ]) if response ["embeddings" ] else None ,
201+ vector = self .deserialize_metadata (response ["embeddings" ][idx ])
202+ if response ["embeddings" ]
203+ else None ,
196204 payload = response ["metadatas" ][idx ] if response ["metadatas" ] else None ,
197205 )
198206 for idx , _ in enumerate (response ["ids" ])
@@ -227,7 +235,7 @@ def add(self, data: list[VecDBItem | dict[str, Any]]) -> None:
227235 item = VecDBItem .from_dict (item )
228236 ids .append (str (item .id ))
229237 embeddings .append (item .vector )
230- metadatas .append (self .serialize_metadata (item .payload . get ( "metadata" ) ))
238+ metadatas .append (self .serialize_metadata (item .payload ))
231239 documents .append (item .payload .get ("memory" ))
232240
233241 self .get_collection ().upsert (
@@ -250,7 +258,9 @@ def update(self, id: str, data: VecDBItem | dict[str, Any]) -> None:
250258 )
251259 else :
252260 # For payload-only updates
253- self .get_collection ().upsert (ids = [id ], metadatas = [self .serialize_metadata (data .payload .get ("metadata" ))])
261+ self .get_collection ().upsert (
262+ ids = [id ], metadatas = [self .serialize_metadata (data .payload .get ("metadata" ))]
263+ )
254264
255265 def upsert (self , data : list [VecDBItem | dict [str , Any ]]) -> None :
256266 """
@@ -274,4 +284,3 @@ def ensure_payload_indexes(self, fields: list[str]) -> None:
274284 fields (list[str]): List of field names to index (as keyword).
275285 """
276286 # chromadb does not implement crete index
277- pass
0 commit comments