44import os
55import socket
66import sys
7+ from asyncio import Lock
78from typing import Any , Literal , Optional , Sequence , cast
89from urllib .parse import urlparse
910
@@ -84,10 +85,13 @@ def __init__(self, configs: Config):
8485 params ["db_log_path" ] = os .path .expanduser (params ["db_log_path" ])
8586 self ._configs .db_params = params
8687
87- self ._lock : AsyncFileLock | None = None
8888 self ._client : ClientAPI | None = None
8989 self ._client_type : SupportedClientType
9090
91+ # locks for persistent client
92+ self ._file_lock : AsyncFileLock | None = None # inter-process lock
93+ self ._thread_lock : Lock | None = None # inter-thread lock
94+
9195 def _create_client (self ) -> ClientAPI :
9296 global _SUPPORTED_CLIENT_TYPE
9397 settings : dict [str , Any ] = {"anonymized_telemetry" : False }
@@ -113,9 +117,9 @@ def _create_client(self) -> ClientAPI:
113117 f"Created chromadb.HttpClient from the following settings: { settings_obj } "
114118 )
115119 self ._client = chromadb .HttpClient (
116- host = parsed_url . hostname ,
117- port = parsed_url . port ,
118- ssl = parsed_url . scheme == "https" ,
120+ host = settings [ "chroma_server_host" ] ,
121+ port = settings [ "chroma_server_http_port" ] ,
122+ ssl = settings [ "chroma_server_ssl_enabled" ] ,
119123 settings = settings_obj ,
120124 )
121125 self ._client_type = "http"
@@ -138,7 +142,8 @@ async def get_client(self) -> ClientAPI:
138142 async with LockManager ().get_lock (
139143 self ._configs .db_params ["db_path" ]
140144 ) as lock :
141- self ._lock = lock
145+ self ._file_lock = lock
146+ self ._thread_lock = Lock ()
142147 return self ._client
143148
144149 @contextlib .asynccontextmanager
@@ -147,13 +152,17 @@ async def maybe_lock(self):
147152 Acquire a file (dir) lock if using persistent client.
148153 """
149154 locked = False
150- if self ._lock is not None :
151- await self ._lock .acquire ()
155+ if self ._file_lock is not None :
156+ assert self ._thread_lock is not None
157+ await self ._file_lock .acquire ()
158+ await self ._thread_lock .acquire ()
152159 locked = True
153160 yield
154161 if locked :
155- assert self ._lock is not None
156- await self ._lock .release ()
162+ assert self ._thread_lock is not None
163+ assert self ._file_lock is not None
164+ await self ._file_lock .release ()
165+ self ._thread_lock .release ()
157166
158167 async def _create_or_get_collection (
159168 self , collection_path : str , allow_create : bool = False
0 commit comments