|
| 1 | +import asyncio |
| 2 | +import contextlib |
| 3 | +import logging |
| 4 | +import os |
| 5 | +import socket |
| 6 | +import sys |
| 7 | +from typing import Any, Literal, Optional, Sequence, cast |
| 8 | +from urllib.parse import urlparse |
| 9 | + |
| 10 | +import chromadb |
| 11 | +from filelock import AsyncFileLock |
| 12 | + |
| 13 | +from vectorcode.chunking import Chunk, TreeSitterChunker |
| 14 | +from vectorcode.cli_utils import Config, LockManager, QueryInclude |
| 15 | +from vectorcode.database import DatabaseConnectorBase |
| 16 | +from vectorcode.database.chroma_common import convert_chroma_query_results |
| 17 | +from vectorcode.database.errors import CollectionNotFoundError |
| 18 | +from vectorcode.database.types import ( |
| 19 | + CollectionContent, |
| 20 | + CollectionInfo, |
| 21 | + QueryResult, |
| 22 | + ResultType, |
| 23 | + VectoriseStats, |
| 24 | +) |
| 25 | +from vectorcode.database.utils import get_collection_id, get_uuid, hash_file |
| 26 | + |
# Hard requirement: this connector is written against the ChromaDB 1.x API.
# Abort early with an actionable message instead of failing later with
# obscure attribute errors.
if not chromadb.__version__.startswith("1."):
    logging.error(
        # Fixed typo in the user-facing message: "wiht" -> "with".
        f"""
Found ChromaDB {chromadb.__version__}, which is incompatible with your VectorCode installation. Please install `vectorcode`.

For example:
uv tool install vectorcode
"""
    )
    sys.exit(1)
| 37 | + |
| 38 | + |
| 39 | +from chromadb import Collection |
| 40 | +from chromadb.api import ClientAPI |
| 41 | +from chromadb.config import APIVersion, Settings |
| 42 | + |
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# The client flavours this connector knows how to construct.
SupportedClientType = Literal["http"] | Literal["persistent"]

# Runtime counterpart of `SupportedClientType`, used for sanity asserts.
_SUPPORTED_CLIENT_TYPE: set[SupportedClientType] = {"http", "persistent"}

# Fallback values merged into `Config.db_params` when the user configuration
# leaves them unset (user-provided values take precedence).
_default_settings: dict[str, Any] = {
    "db_url": None,
    "db_path": os.path.expanduser("~/.local/share/vectorcode/chromadb/"),
    "db_log_path": os.path.expanduser("~/.local/share/vectorcode/"),
    "db_settings": {},
    "hnsw": {"hnsw:M": 64},
}
| 56 | + |
| 57 | + |
class ChromaDBConnector(DatabaseConnectorBase):
    """Database connector backed by ChromaDB (HTTP or on-disk persistent client)."""

    def __init__(self, configs: Config):
        super().__init__(configs)
        # Merge the module-level defaults with the user's db_params;
        # explicitly configured values win over the defaults.
        self._configs.db_params = {
            **_default_settings,
            **self._configs.db_params,
        }

        # Populated lazily: the file lock (persistent client only), the
        # chromadb client, and which flavour of client was created.
        self._lock: AsyncFileLock | None = None
        self._client: ClientAPI | None = None
        self._client_type: SupportedClientType
| 68 | + |
| 69 | + def _create_client(self) -> ClientAPI: |
| 70 | + global _SUPPORTED_CLIENT_TYPE |
| 71 | + settings: dict[str, Any] = {"anonymized_telemetry": False} |
| 72 | + db_params = self._configs.db_params |
| 73 | + settings.update(db_params["db_settings"]) |
| 74 | + if db_params.get("db_url"): |
| 75 | + parsed_url = urlparse(db_params["db_url"]) |
| 76 | + |
| 77 | + settings["chroma_server_host"] = settings.get( |
| 78 | + "chroma_server_host", parsed_url.hostname or "127.0.0.1" |
| 79 | + ) |
| 80 | + settings["chroma_server_http_port"] = settings.get( |
| 81 | + "chroma_server_http_port", parsed_url.port or 8000 |
| 82 | + ) |
| 83 | + settings["chroma_server_ssl_enabled"] = settings.get( |
| 84 | + "chroma_server_ssl_enabled", parsed_url.scheme == "https" |
| 85 | + ) |
| 86 | + settings["chroma_server_api_default_path"] = settings.get( |
| 87 | + "chroma_server_api_default_path", parsed_url.path or APIVersion.V2 |
| 88 | + ) |
| 89 | + settings_obj = Settings(**settings) |
| 90 | + logger.info( |
| 91 | + f"Created chromadb.HttpClient from the following settings: {settings_obj}" |
| 92 | + ) |
| 93 | + self._client = chromadb.HttpClient( |
| 94 | + host=parsed_url.hostname, |
| 95 | + port=parsed_url.port, |
| 96 | + ssl=parsed_url.scheme == "https", |
| 97 | + settings=settings_obj, |
| 98 | + ) |
| 99 | + self._client_type = "http" |
| 100 | + else: |
| 101 | + logger.info( |
| 102 | + f"Created chromadb.PersistentClient at `{db_params['db_path']}` from the following settings: {settings}" |
| 103 | + ) |
| 104 | + os.makedirs(db_params["db_path"], exist_ok=True) |
| 105 | + self._client = chromadb.PersistentClient(path=db_params["db_path"]) |
| 106 | + |
| 107 | + self._client_type = "persistent" |
| 108 | + assert self._client_type in _SUPPORTED_CLIENT_TYPE |
| 109 | + return self._client |
| 110 | + |
    async def get_client(self) -> ClientAPI:
        """Return the ChromaDB client, creating it on first use.

        For persistent clients, also obtains the ``AsyncFileLock`` object for
        the database directory and stashes it on ``self._lock`` so that
        ``maybe_lock`` can guard later on-disk operations.
        """
        if self._client is None:
            self._create_client()
            assert self._client is not None
            if self._client_type == "persistent":
                # NOTE(review): the `async with` acquires and immediately
                # releases the lock; only the lock *object* is kept for later
                # re-acquisition in `maybe_lock`. Confirm that
                # `LockManager.get_lock` is intended to be used this way.
                async with LockManager().get_lock(
                    self._configs.db_params["db_path"]
                ) as lock:
                    self._lock = lock
        return self._client
| 121 | + |
| 122 | + @contextlib.asynccontextmanager |
| 123 | + async def maybe_lock(self): |
| 124 | + """ |
| 125 | + Acquire a file (dir) lock if using persistent client. |
| 126 | + """ |
| 127 | + if self._lock is not None: |
| 128 | + await self._lock.acquire() |
| 129 | + yield |
| 130 | + if self._lock is not None: |
| 131 | + await self._lock.release() |
| 132 | + |
    async def _create_or_get_collection(
        self, collection_path: str, allow_create: bool = False
    ) -> Collection:
        """
        This method should be used by ChromaDB methods that are expected to **create a collection when not found**.
        For other methods, just use `client.get_collection` and let it fail if the collection doesn't exist.
        """

        # Identifying metadata stored on the collection; also used below to
        # validate that an existing collection matches the current setup.
        collection_meta: dict[str, str | int] = {
            "path": os.path.abspath(str(self._configs.project_root)),
            "hostname": socket.gethostname(),
            "created-by": "VectorCode",
            "username": os.environ.get(
                "USER", os.environ.get("USERNAME", "DEFAULT_USER")
            ),
            "embedding_function": self._configs.embedding_function,
        }
        db_params = self._configs.db_params
        user_hnsw = db_params.get("hnsw", {})
        # Normalise user-supplied HNSW options to the `hnsw:`-prefixed keys
        # ChromaDB expects; `None` values are dropped.
        for key in user_hnsw.keys():
            meta_field_name: str = key
            if not meta_field_name.startswith("hnsw:"):
                meta_field_name = f"hnsw:{meta_field_name}"
            if user_hnsw.get(key) is not None:
                collection_meta[meta_field_name] = user_hnsw[key]

        async with self.maybe_lock():
            collection_id = get_collection_id(collection_path)
            client = await self.get_client()
            if not allow_create:
                try:
                    return client.get_collection(collection_id)
                # NOTE(review): chromadb 1.x may raise its own NotFoundError
                # rather than a plain ValueError here — confirm the exception
                # type against the installed chromadb version.
                except ValueError as e:
                    raise CollectionNotFoundError(
                        f"There's no existing collection for {collection_path} in ChromaDB with the following setup: {self._configs.db_params}"
                    ) from e
            col = client.get_or_create_collection(
                collection_id, metadata=collection_meta
            )
            # NOTE(review): these asserts are stripped under `python -O`, and
            # `col.metadata` could in principle be None for a bare collection.
            for key in collection_meta.keys():
                # validate metadata
                assert collection_meta[key] == col.metadata.get(key), (
                    f"Metadata field {key} mismatch!"
                )

            return col
| 179 | + |
| 180 | + async def query(self) -> list[QueryResult]: |
| 181 | + collection = await self._create_or_get_collection( |
| 182 | + str(self._configs.project_root), False |
| 183 | + ) |
| 184 | + |
| 185 | + assert self._configs.query is not None |
| 186 | + assert len(self._configs.query), "Keywords cannot be empty" |
| 187 | + keywords_embeddings = self.get_embedding(self._configs.query) |
| 188 | + |
| 189 | + query_count = self._configs.n_result or (await self.count(ResultType.chunk)) |
| 190 | + query_filter = None |
| 191 | + if len(self._configs.query_exclude): |
| 192 | + query_filter = cast( |
| 193 | + chromadb.Where, {"path": {"$nin": list(self._configs.query_exclude)}} |
| 194 | + ) |
| 195 | + if QueryInclude.chunk in self._configs.include: |
| 196 | + if query_filter is None: |
| 197 | + query_filter = cast(chromadb.Where, {"start": {"$gte": 0}}) |
| 198 | + else: |
| 199 | + query_filter = cast( |
| 200 | + chromadb.Where, |
| 201 | + {"$and": [query_filter.copy(), {"start": {"$gte": 0}}]}, |
| 202 | + ) |
| 203 | + |
| 204 | + async with self.maybe_lock(): |
| 205 | + raw_result = await asyncio.to_thread( |
| 206 | + collection.query, |
| 207 | + include=[ |
| 208 | + "metadatas", |
| 209 | + "documents", |
| 210 | + "distances", |
| 211 | + ], |
| 212 | + query_embeddings=keywords_embeddings, |
| 213 | + where=query_filter, |
| 214 | + n_results=query_count, |
| 215 | + ) |
| 216 | + return convert_chroma_query_results(raw_result, self._configs.query) |
| 217 | + |
| 218 | + async def vectorise( |
| 219 | + self, file_path: str, chunker: TreeSitterChunker | None = None |
| 220 | + ) -> VectoriseStats: |
| 221 | + collection_path = str(self._configs.project_root) |
| 222 | + collection = await self._create_or_get_collection( |
| 223 | + collection_path, allow_create=True |
| 224 | + ) |
| 225 | + chunker = chunker or TreeSitterChunker(self._configs) |
| 226 | + |
| 227 | + chunks = tuple(chunker.chunk(file_path)) |
| 228 | + embeddings = self.get_embedding(list(i.text for i in chunks)) |
| 229 | + if len(embeddings) == 0: |
| 230 | + return VectoriseStats(skipped=1) |
| 231 | + |
| 232 | + file_hash = hash_file(file_path) |
| 233 | + |
| 234 | + def chunk_to_meta(chunk: Chunk) -> chromadb.Metadata: |
| 235 | + meta: dict[str, int | str] = {"path": file_path, "sha256": file_hash} |
| 236 | + if chunk.start: |
| 237 | + meta["start"] = chunk.start.row |
| 238 | + |
| 239 | + if chunk.end: |
| 240 | + meta["end"] = chunk.end.row |
| 241 | + return meta |
| 242 | + |
| 243 | + max_bs = (await self.get_client()).get_max_batch_size() |
| 244 | + for batch_start_idx in range(0, len(chunks), max_bs): |
| 245 | + batch_chunks = [ |
| 246 | + chunks[i].text |
| 247 | + for i in range( |
| 248 | + batch_start_idx, min(batch_start_idx + max_bs, len(chunks)) |
| 249 | + ) |
| 250 | + ] |
| 251 | + batch_embeddings = embeddings[batch_start_idx : batch_start_idx + max_bs] |
| 252 | + batch_meta = [ |
| 253 | + chunk_to_meta(chunks[i]) |
| 254 | + for i in range( |
| 255 | + batch_start_idx, min(batch_start_idx + max_bs, len(chunks)) |
| 256 | + ) |
| 257 | + ] |
| 258 | + async with self.maybe_lock(): |
| 259 | + await asyncio.to_thread( |
| 260 | + collection.add, |
| 261 | + documents=batch_chunks, |
| 262 | + embeddings=batch_embeddings, |
| 263 | + metadatas=batch_meta, |
| 264 | + ids=[get_uuid() for _ in batch_chunks], |
| 265 | + ) |
| 266 | + return VectoriseStats(add=1) |
| 267 | + |
    async def delete(self) -> int:
        """Delegate deletion entirely to the base implementation.

        NOTE(review): pure pass-through override; it could be removed unless
        kept deliberately as an extension point.
        """
        return await super().delete()
| 270 | + |
    async def drop(
        self, *, collection_id: str | None = None, collection_path: str | None = None
    ):
        """Delegate collection dropping entirely to the base implementation.

        NOTE(review): pure pass-through override; it could be removed unless
        kept deliberately as an extension point.
        """
        return await super().drop(
            collection_id=collection_id, collection_path=collection_path
        )
| 277 | + |
    async def get_chunks(self, file_path) -> list[Chunk]:
        """Delegate chunk retrieval entirely to the base implementation.

        NOTE(review): pure pass-through override; it could be removed unless
        kept deliberately as an extension point.
        """
        return await super().get_chunks(file_path)
| 280 | + |
    async def list_collection_content(
        self,
        *,
        what: Optional[ResultType] = None,
        collection_id: str | None = None,
        collection_path: str | None = None,
    ) -> CollectionContent:
        # TODO: stub — all arguments are currently ignored and an empty
        # result is returned unconditionally.
        return CollectionContent(files=[], chunks=[])
| 289 | + |
    async def list_collections(self) -> Sequence[CollectionInfo]:
        # TODO: stub — unconditionally reports no collections.
        return []
0 commit comments