Skip to content

Commit db10d56

Browse files
committed
refactor(chroma0): extract some stuff that can be reused by chromadb (1.x)
1 parent d84fa81 commit db10d56

File tree

3 files changed

+105
-96
lines changed

3 files changed

+105
-96
lines changed

src/vectorcode/database/chroma0.py

Lines changed: 27 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,14 @@
99
import sys
1010
from asyncio.subprocess import Process
1111
from dataclasses import dataclass
12-
from typing import Any, Optional, Sequence, cast
12+
from typing import Any, Optional, cast
1313
from urllib.parse import urlparse
1414

1515
import chromadb
16-
17-
if not chromadb.__version__.startswith("0.6.3"): # pragma: nocover
18-
logging.error(
19-
f"""
20-
Found ChromaDB {chromadb.__version__}, which is incompatible wiht your VectorCode installation. Please install vectorcode[chroma0].
21-
22-
For example:
23-
uv tool install vectorcode[chroma0]
24-
"""
25-
)
26-
sys.exit(1)
27-
2816
import httpx
2917
from chromadb.api import AsyncClientAPI
3018
from chromadb.api.models.AsyncCollection import AsyncCollection
31-
from chromadb.api.types import IncludeEnum, QueryResult
3219
from chromadb.config import APIVersion, Settings
33-
from chromadb.errors import InvalidCollectionException
3420
from tree_sitter import Point
3521

3622
from vectorcode.chunking import Chunk, TreeSitterChunker
@@ -41,8 +27,8 @@
4127
expand_globs,
4228
expand_path,
4329
)
44-
from vectorcode.database import types
4530
from vectorcode.database.base import DatabaseConnectorBase
31+
from vectorcode.database.chroma_common import convert_chroma_query_results
4632
from vectorcode.database.errors import CollectionNotFoundError
4733
from vectorcode.database.types import (
4834
CollectionContent,
@@ -56,41 +42,6 @@
5642
_logger = logging.getLogger(name=__name__)
5743

5844

59-
def _convert_chroma_query_results(
60-
chroma_result: QueryResult, queries: Sequence[str]
61-
) -> list[types.QueryResult]:
62-
"""Convert chromadb query result to in-house query results"""
63-
assert chroma_result["documents"] is not None
64-
assert chroma_result["distances"] is not None
65-
assert chroma_result["metadatas"] is not None
66-
assert chroma_result["ids"] is not None
67-
68-
chroma_results_list: list[types.QueryResult] = []
69-
for q_i in range(len(queries)):
70-
q = queries[q_i]
71-
documents = chroma_result["documents"][q_i]
72-
distances = chroma_result["distances"][q_i]
73-
metadatas = chroma_result["metadatas"][q_i]
74-
ids = chroma_result["ids"][q_i]
75-
for doc, dist, meta, _id in zip(documents, distances, metadatas, ids):
76-
chunk = Chunk(text=doc, id=_id)
77-
if meta.get("start"):
78-
chunk.start = Point(int(meta.get("start", 0)), 0)
79-
if meta.get("end"):
80-
chunk.end = Point(int(meta.get("end", 0)), 0)
81-
if meta.get("path"):
82-
chunk.path = str(meta["path"])
83-
chroma_results_list.append(
84-
types.QueryResult(
85-
chunk=chunk,
86-
path=str(meta.get("path", "")),
87-
query=(q,),
88-
scores=(-dist,),
89-
)
90-
)
91-
return chroma_results_list
92-
93-
9445
async def _try_server(base_url: str):
9546
for ver in ("v1", "v2"): # v1 for legacy, v2 for latest chromadb.
9647
heartbeat_url = f"{base_url}/api/{ver}/heartbeat"
@@ -173,6 +124,16 @@ class _Chroma0ClientManager:
173124
__clients: dict[str, _Chroma0ClientModel]
174125

175126
def __new__(cls) -> "_Chroma0ClientManager":
127+
if not chromadb.__version__.startswith("0.6.3"): # pragma: nocover
128+
_logger.error(
129+
f"""
130+
Found ChromaDB {chromadb.__version__}, which is incompatible with your VectorCode installation. Please install vectorcode[chroma0].
131+
132+
For example:
133+
uv tool install vectorcode[chroma0]
134+
"""
135+
)
136+
sys.exit(1)
176137
if cls.singleton is None:
177138
cls.singleton = super().__new__(cls)
178139
cls.singleton.__clients = {}
@@ -296,7 +257,7 @@ async def query(self):
296257
)
297258

298259
collection_path = str(self._configs.project_root)
299-
collection: AsyncCollection = await self._create_or_get_collection(
260+
collection: AsyncCollection = await self._create_or_get_async_collection(
300261
collection_path=collection_path, allow_create=False
301262
)
302263
query_count = self._configs.n_result or (await self.count(ResultType.chunk))
@@ -316,16 +277,16 @@ async def query(self):
316277
query_result = await collection.query(
317278
query_embeddings=keywords_embeddings,
318279
include=[
319-
IncludeEnum.metadatas,
320-
IncludeEnum.documents,
321-
IncludeEnum.distances,
280+
"metadatas",
281+
"documents",
282+
"distances",
322283
],
323284
n_results=query_count,
324285
where=query_filter,
325286
)
326-
return _convert_chroma_query_results(query_result, self._configs.query)
287+
return convert_chroma_query_results(query_result, self._configs.query)
327288

328-
async def _create_or_get_collection(
289+
async def _create_or_get_async_collection(
329290
self, collection_path: str, allow_create: bool = False
330291
) -> AsyncCollection:
331292
"""
@@ -354,6 +315,8 @@ async def _create_or_get_collection(
354315
async with _Chroma0ClientManager().get_client(self._configs, True) as client:
355316
collection_id = get_collection_id(collection_path)
356317
if not allow_create:
318+
from chromadb.errors import InvalidCollectionException
319+
357320
try:
358321
return await client.get_collection(collection_id)
359322
except (InvalidCollectionException, ValueError) as e:
@@ -377,7 +340,7 @@ async def vectorise(
377340
chunker: TreeSitterChunker | None = None,
378341
) -> VectoriseStats:
379342
collection_path = str(self._configs.project_root)
380-
collection = await self._create_or_get_collection(
343+
collection = await self._create_or_get_async_collection(
381344
collection_path, allow_create=True
382345
)
383346
chunker = chunker or TreeSitterChunker(self._configs)
@@ -461,7 +424,7 @@ async def list_collection_content(
461424
"""
462425
if collection_id is None:
463426
collection_path = str(collection_path or self._configs.project_root)
464-
collection = await self._create_or_get_collection((collection_path))
427+
collection = await self._create_or_get_async_collection((collection_path))
465428
else:
466429
async with _Chroma0ClientManager().get_client(
467430
self._configs, False
@@ -470,8 +433,8 @@ async def list_collection_content(
470433
content = CollectionContent()
471434
raw_content = await collection.get(
472435
include=[
473-
IncludeEnum.metadatas,
474-
IncludeEnum.documents,
436+
"metadatas",
437+
"documents",
475438
]
476439
)
477440
metadatas = raw_content.get("metadatas", [])
@@ -510,7 +473,7 @@ async def list_collection_content(
510473

511474
async def delete(self) -> int:
512475
collection_path = str(self._configs.project_root)
513-
collection = await self._create_or_get_collection(collection_path, False)
476+
collection = await self._create_or_get_async_collection(collection_path, False)
514477
rm_paths = self._configs.rm_paths
515478
if isinstance(rm_paths, str):
516479
rm_paths = [rm_paths]
@@ -551,7 +514,7 @@ async def drop(self, *, collection_id=None, collection_path=None):
551514
async def get_chunks(self, file_path) -> list[Chunk]:
552515
file_path = os.path.abspath(file_path)
553516
try:
554-
collection = await self._create_or_get_collection(
517+
collection = await self._create_or_get_async_collection(
555518
collection_path=str(self._configs.project_root), allow_create=False
556519
)
557520
except CollectionNotFoundError:
@@ -564,7 +527,7 @@ async def get_chunks(self, file_path) -> list[Chunk]:
564527

565528
raw_results = await collection.get(
566529
where={"path": file_path},
567-
include=[IncludeEnum.metadatas, IncludeEnum.documents],
530+
include=["metadatas", "documents"],
568531
)
569532
assert raw_results["metadatas"] is not None
570533
assert raw_results["documents"] is not None
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from typing import Sequence, cast
2+
3+
from chromadb.api.types import QueryResult as ChromaQueryResult
4+
from tree_sitter import Point
5+
6+
from vectorcode.chunking import Chunk
7+
from vectorcode.database import types
8+
9+
10+
def convert_chroma_query_results(
    chroma_result: ChromaQueryResult, queries: Sequence[str]
) -> list[types.QueryResult]:
    """Convert a chromadb query result into in-house query results.

    The chromadb result holds one list of hits per query, index-aligned with
    ``queries``. Every hit is flattened into its own ``types.QueryResult``.

    Args:
        chroma_result: Raw chromadb query result. Its ``documents``,
            ``distances``, ``metadatas`` and ``ids`` fields must be present.
        queries: The query strings, in the same order as the per-query lists
            inside ``chroma_result``.

    Returns:
        A flat list with one ``types.QueryResult`` per (query, hit) pair,
        ordered by query and then by the hit order chromadb returned.

    Raises:
        AssertionError: If any of the four required fields is ``None``.
        IndexError: If ``chroma_result`` holds fewer per-query lists than
            there are entries in ``queries``.
    """
    assert chroma_result["documents"] is not None
    assert chroma_result["distances"] is not None
    assert chroma_result["metadatas"] is not None
    assert chroma_result["ids"] is not None

    chroma_results_list: list[types.QueryResult] = []
    for q_i, q in enumerate(queries):
        hits = zip(
            chroma_result["documents"][q_i],
            chroma_result["distances"][q_i],
            chroma_result["metadatas"][q_i],
            chroma_result["ids"][q_i],
        )
        for doc, dist, meta, _id in hits:
            chunk = Chunk(text=doc, id=_id)

            # Compare against None rather than relying on truthiness: a
            # stored row of 0 is falsy and would otherwise be dropped.
            start = meta.get("start")
            if start is not None:
                chunk.start = Point(cast(int, start), 0)
            end = meta.get("end")
            if end is not None:
                chunk.end = Point(cast(int, end), 0)

            if meta.get("path"):
                chunk.path = str(meta["path"])

            chroma_results_list.append(
                types.QueryResult(
                    chunk=chunk,
                    path=str(meta.get("path", "")),
                    query=(q,),
                    # The score is the negated distance; presumably so that
                    # a larger score means a closer match -- confirm with
                    # the consumers of QueryResult.
                    scores=(-dist,),
                )
            )
    return chroma_results_list

0 commit comments

Comments
 (0)