Skip to content

Commit 6d63c7f

Browse files
committed
feat(chroma): a WIP chromadb connector for chroma 1.x
1 parent 6a20a0e commit 6d63c7f

File tree

2 files changed

+295
-0
lines changed

2 files changed

+295
-0
lines changed

src/vectorcode/database/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ def get_database_connector(config: Config) -> DatabaseConnectorBase:
2626
from vectorcode.database.chroma0 import ChromaDB0Connector
2727

2828
cls = ChromaDB0Connector
29+
case "ChromaDBConnector":
30+
from vectorcode.database.chroma import ChromaDBConnector
31+
32+
cls = ChromaDBConnector
2933
case _:
3034
raise ValueError(f"Unrecognised database type: {config.db_type}")
3135

src/vectorcode/database/chroma.py

Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
import asyncio
2+
import contextlib
3+
import logging
4+
import os
5+
import socket
6+
import sys
7+
from typing import Any, Literal, Optional, Sequence, cast
8+
from urllib.parse import urlparse
9+
10+
import chromadb
11+
from filelock import AsyncFileLock
12+
13+
from vectorcode.chunking import Chunk, TreeSitterChunker
14+
from vectorcode.cli_utils import Config, LockManager, QueryInclude
15+
from vectorcode.database import DatabaseConnectorBase
16+
from vectorcode.database.chroma_common import convert_chroma_query_results
17+
from vectorcode.database.errors import CollectionNotFoundError
18+
from vectorcode.database.types import (
19+
CollectionContent,
20+
CollectionInfo,
21+
QueryResult,
22+
ResultType,
23+
VectoriseStats,
24+
)
25+
from vectorcode.database.utils import get_collection_id, get_uuid, hash_file
26+
27+
if not chromadb.__version__.startswith("1."):
28+
logging.error(
29+
f"""
30+
Found ChromaDB {chromadb.__version__}, which is incompatible wiht your VectorCode installation. Please install `vectorcode`.
31+
32+
For example:
33+
uv tool install vectorcode
34+
"""
35+
)
36+
sys.exit(1)
37+
38+
39+
from chromadb import Collection
40+
from chromadb.api import ClientAPI
41+
from chromadb.config import APIVersion, Settings
42+
43+
logger = logging.getLogger(name=__name__)
44+
45+
SupportedClientType = Literal["http"] | Literal["persistent"]
46+
47+
_SUPPORTED_CLIENT_TYPE: set[SupportedClientType] = {"http", "persistent"}
48+
49+
_default_settings: dict[str, Any] = {
50+
"db_url": None,
51+
"db_path": os.path.expanduser("~/.local/share/vectorcode/chromadb/"),
52+
"db_log_path": os.path.expanduser("~/.local/share/vectorcode/"),
53+
"db_settings": {},
54+
"hnsw": {"hnsw:M": 64},
55+
}
56+
57+
58+
class ChromaDBConnector(DatabaseConnectorBase):
59+
def __init__(self, configs: Config):
60+
super().__init__(configs)
61+
params = _default_settings.copy()
62+
params.update(self._configs.db_params.copy())
63+
self._configs.db_params = params
64+
65+
self._lock: AsyncFileLock | None = None
66+
self._client: ClientAPI | None = None
67+
self._client_type: SupportedClientType
68+
69+
def _create_client(self) -> ClientAPI:
70+
global _SUPPORTED_CLIENT_TYPE
71+
settings: dict[str, Any] = {"anonymized_telemetry": False}
72+
db_params = self._configs.db_params
73+
settings.update(db_params["db_settings"])
74+
if db_params.get("db_url"):
75+
parsed_url = urlparse(db_params["db_url"])
76+
77+
settings["chroma_server_host"] = settings.get(
78+
"chroma_server_host", parsed_url.hostname or "127.0.0.1"
79+
)
80+
settings["chroma_server_http_port"] = settings.get(
81+
"chroma_server_http_port", parsed_url.port or 8000
82+
)
83+
settings["chroma_server_ssl_enabled"] = settings.get(
84+
"chroma_server_ssl_enabled", parsed_url.scheme == "https"
85+
)
86+
settings["chroma_server_api_default_path"] = settings.get(
87+
"chroma_server_api_default_path", parsed_url.path or APIVersion.V2
88+
)
89+
settings_obj = Settings(**settings)
90+
logger.info(
91+
f"Created chromadb.HttpClient from the following settings: {settings_obj}"
92+
)
93+
self._client = chromadb.HttpClient(
94+
host=parsed_url.hostname,
95+
port=parsed_url.port,
96+
ssl=parsed_url.scheme == "https",
97+
settings=settings_obj,
98+
)
99+
self._client_type = "http"
100+
else:
101+
logger.info(
102+
f"Created chromadb.PersistentClient at `{db_params['db_path']}` from the following settings: {settings}"
103+
)
104+
os.makedirs(db_params["db_path"], exist_ok=True)
105+
self._client = chromadb.PersistentClient(path=db_params["db_path"])
106+
107+
self._client_type = "persistent"
108+
assert self._client_type in _SUPPORTED_CLIENT_TYPE
109+
return self._client
110+
111+
async def get_client(self) -> ClientAPI:
112+
if self._client is None:
113+
self._create_client()
114+
assert self._client is not None
115+
if self._client_type == "persistent":
116+
async with LockManager().get_lock(
117+
self._configs.db_params["db_path"]
118+
) as lock:
119+
self._lock = lock
120+
return self._client
121+
122+
@contextlib.asynccontextmanager
123+
async def maybe_lock(self):
124+
"""
125+
Acquire a file (dir) lock if using persistent client.
126+
"""
127+
if self._lock is not None:
128+
await self._lock.acquire()
129+
yield
130+
if self._lock is not None:
131+
await self._lock.release()
132+
133+
async def _create_or_get_collection(
134+
self, collection_path: str, allow_create: bool = False
135+
) -> Collection:
136+
"""
137+
This method should be used by ChromaDB methods that are expected to **create a collection when not found**.
138+
For other methods, just use `client.get_collection` and let it fail if the collection doesn't exist.
139+
"""
140+
141+
collection_meta: dict[str, str | int] = {
142+
"path": os.path.abspath(str(self._configs.project_root)),
143+
"hostname": socket.gethostname(),
144+
"created-by": "VectorCode",
145+
"username": os.environ.get(
146+
"USER", os.environ.get("USERNAME", "DEFAULT_USER")
147+
),
148+
"embedding_function": self._configs.embedding_function,
149+
}
150+
db_params = self._configs.db_params
151+
user_hnsw = db_params.get("hnsw", {})
152+
for key in user_hnsw.keys():
153+
meta_field_name: str = key
154+
if not meta_field_name.startswith("hnsw:"):
155+
meta_field_name = f"hnsw:{meta_field_name}"
156+
if user_hnsw.get(key) is not None:
157+
collection_meta[meta_field_name] = user_hnsw[key]
158+
159+
async with self.maybe_lock():
160+
collection_id = get_collection_id(collection_path)
161+
client = await self.get_client()
162+
if not allow_create:
163+
try:
164+
return client.get_collection(collection_id)
165+
except ValueError as e:
166+
raise CollectionNotFoundError(
167+
f"There's no existing collection for {collection_path} in ChromaDB with the following setup: {self._configs.db_params}"
168+
) from e
169+
col = client.get_or_create_collection(
170+
collection_id, metadata=collection_meta
171+
)
172+
for key in collection_meta.keys():
173+
# validate metadata
174+
assert collection_meta[key] == col.metadata.get(key), (
175+
f"Metadata field {key} mismatch!"
176+
)
177+
178+
return col
179+
180+
async def query(self) -> list[QueryResult]:
181+
collection = await self._create_or_get_collection(
182+
str(self._configs.project_root), False
183+
)
184+
185+
assert self._configs.query is not None
186+
assert len(self._configs.query), "Keywords cannot be empty"
187+
keywords_embeddings = self.get_embedding(self._configs.query)
188+
189+
query_count = self._configs.n_result or (await self.count(ResultType.chunk))
190+
query_filter = None
191+
if len(self._configs.query_exclude):
192+
query_filter = cast(
193+
chromadb.Where, {"path": {"$nin": list(self._configs.query_exclude)}}
194+
)
195+
if QueryInclude.chunk in self._configs.include:
196+
if query_filter is None:
197+
query_filter = cast(chromadb.Where, {"start": {"$gte": 0}})
198+
else:
199+
query_filter = cast(
200+
chromadb.Where,
201+
{"$and": [query_filter.copy(), {"start": {"$gte": 0}}]},
202+
)
203+
204+
async with self.maybe_lock():
205+
raw_result = await asyncio.to_thread(
206+
collection.query,
207+
include=[
208+
"metadatas",
209+
"documents",
210+
"distances",
211+
],
212+
query_embeddings=keywords_embeddings,
213+
where=query_filter,
214+
n_results=query_count,
215+
)
216+
return convert_chroma_query_results(raw_result, self._configs.query)
217+
218+
async def vectorise(
219+
self, file_path: str, chunker: TreeSitterChunker | None = None
220+
) -> VectoriseStats:
221+
collection_path = str(self._configs.project_root)
222+
collection = await self._create_or_get_collection(
223+
collection_path, allow_create=True
224+
)
225+
chunker = chunker or TreeSitterChunker(self._configs)
226+
227+
chunks = tuple(chunker.chunk(file_path))
228+
embeddings = self.get_embedding(list(i.text for i in chunks))
229+
if len(embeddings) == 0:
230+
return VectoriseStats(skipped=1)
231+
232+
file_hash = hash_file(file_path)
233+
234+
def chunk_to_meta(chunk: Chunk) -> chromadb.Metadata:
235+
meta: dict[str, int | str] = {"path": file_path, "sha256": file_hash}
236+
if chunk.start:
237+
meta["start"] = chunk.start.row
238+
239+
if chunk.end:
240+
meta["end"] = chunk.end.row
241+
return meta
242+
243+
max_bs = (await self.get_client()).get_max_batch_size()
244+
for batch_start_idx in range(0, len(chunks), max_bs):
245+
batch_chunks = [
246+
chunks[i].text
247+
for i in range(
248+
batch_start_idx, min(batch_start_idx + max_bs, len(chunks))
249+
)
250+
]
251+
batch_embeddings = embeddings[batch_start_idx : batch_start_idx + max_bs]
252+
batch_meta = [
253+
chunk_to_meta(chunks[i])
254+
for i in range(
255+
batch_start_idx, min(batch_start_idx + max_bs, len(chunks))
256+
)
257+
]
258+
async with self.maybe_lock():
259+
await asyncio.to_thread(
260+
collection.add,
261+
documents=batch_chunks,
262+
embeddings=batch_embeddings,
263+
metadatas=batch_meta,
264+
ids=[get_uuid() for _ in batch_chunks],
265+
)
266+
return VectoriseStats(add=1)
267+
268+
async def delete(self) -> int:
269+
return await super().delete()
270+
271+
async def drop(
272+
self, *, collection_id: str | None = None, collection_path: str | None = None
273+
):
274+
return await super().drop(
275+
collection_id=collection_id, collection_path=collection_path
276+
)
277+
278+
async def get_chunks(self, file_path) -> list[Chunk]:
279+
return await super().get_chunks(file_path)
280+
281+
async def list_collection_content(
282+
self,
283+
*,
284+
what: Optional[ResultType] = None,
285+
collection_id: str | None = None,
286+
collection_path: str | None = None,
287+
) -> CollectionContent:
288+
return CollectionContent(files=[], chunks=[])
289+
290+
async def list_collections(self) -> Sequence[CollectionInfo]:
291+
return []

0 commit comments

Comments
 (0)