Skip to content

Commit 29cb6aa

Browse files
committed
feat(db): Add inter-process and inter-thread locks for ChromaDB connector
1 parent be217d6 commit 29cb6aa

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

src/vectorcode/database/chroma.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import socket
66
import sys
7+
from asyncio import Lock
78
from typing import Any, Literal, Optional, Sequence, cast
89
from urllib.parse import urlparse
910

@@ -84,10 +85,13 @@ def __init__(self, configs: Config):
8485
params["db_log_path"] = os.path.expanduser(params["db_log_path"])
8586
self._configs.db_params = params
8687

87-
self._lock: AsyncFileLock | None = None
8888
self._client: ClientAPI | None = None
8989
self._client_type: SupportedClientType
9090

91+
# locks for persistent client
92+
self._file_lock: AsyncFileLock | None = None # inter-process lock
93+
self._thread_lock: Lock | None = None # inter-thread lock
94+
9195
def _create_client(self) -> ClientAPI:
9296
global _SUPPORTED_CLIENT_TYPE
9397
settings: dict[str, Any] = {"anonymized_telemetry": False}
@@ -113,9 +117,9 @@ def _create_client(self) -> ClientAPI:
113117
f"Created chromadb.HttpClient from the following settings: {settings_obj}"
114118
)
115119
self._client = chromadb.HttpClient(
116-
host=parsed_url.hostname,
117-
port=parsed_url.port,
118-
ssl=parsed_url.scheme == "https",
120+
host=settings["chroma_server_host"],
121+
port=settings["chroma_server_http_port"],
122+
ssl=settings["chroma_server_ssl_enabled"],
119123
settings=settings_obj,
120124
)
121125
self._client_type = "http"
@@ -138,7 +142,8 @@ async def get_client(self) -> ClientAPI:
138142
async with LockManager().get_lock(
139143
self._configs.db_params["db_path"]
140144
) as lock:
141-
self._lock = lock
145+
self._file_lock = lock
146+
self._thread_lock = Lock()
142147
return self._client
143148

144149
@contextlib.asynccontextmanager
@@ -147,13 +152,17 @@ async def maybe_lock(self):
147152
Acquire a file (dir) lock if using persistent client.
148153
"""
149154
locked = False
150-
if self._lock is not None:
151-
await self._lock.acquire()
155+
if self._file_lock is not None:
156+
assert self._thread_lock is not None
157+
await self._file_lock.acquire()
158+
await self._thread_lock.acquire()
152159
locked = True
153160
yield
154161
if locked:
155-
assert self._lock is not None
156-
await self._lock.release()
162+
assert self._thread_lock is not None
163+
assert self._file_lock is not None
164+
await self._file_lock.release()
165+
self._thread_lock.release()
157166

158167
async def _create_or_get_collection(
159168
self, collection_path: str, allow_create: bool = False

tests/database/test_chroma.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -700,4 +700,5 @@ async def test_persistent_client(mock_config):
700700
assert connector._client_type == "persistent"
701701
assert os.path.isfile(os.path.join(tmp_db_dir, "vectorcode.lock"))
702702
async with connector.maybe_lock():
703-
assert connector._lock.is_locked
703+
assert connector._file_lock.is_locked
704+
assert connector._thread_lock.locked()

0 commit comments

Comments
 (0)