Skip to content

Commit bb52bcd

Browse files
committed
fix(cli): Verify chromadb connection by checking openapi title
1 parent c8741c9 commit bb52bcd

File tree

2 files changed

+32
-63
lines changed

2 files changed

+32
-63
lines changed

src/vectorcode/common.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import asyncio
22
import contextlib
33
import hashlib
4+
import json
45
import logging
56
import os
67
import socket
78
import subprocess
89
import sys
10+
import traceback
911
from asyncio.subprocess import Process
1012
from dataclasses import dataclass
1113
from typing import Any, AsyncGenerator, Optional
@@ -46,16 +48,19 @@ async def get_collections(
4648

4749

4850
async def try_server(base_url: str):
49-
for ver in ("v1", "v2"): # v1 for legacy, v2 for latest chromadb.
50-
heartbeat_url = f"{base_url}/api/{ver}/heartbeat"
51-
try:
52-
async with httpx.AsyncClient() as client:
53-
response = await client.get(url=heartbeat_url)
54-
logger.debug(f"Heartbeat {heartbeat_url} returned {response=}")
55-
if response.status_code == 200:
56-
return True
57-
except (httpx.ConnectError, httpx.ConnectTimeout):
58-
pass
51+
openapi_url = f"{base_url}/openapi.json"
52+
try:
53+
async with httpx.AsyncClient() as client:
54+
response = await client.get(url=openapi_url)
55+
logger.debug(f"Fetching openapi.json from {openapi_url}: {response=}")
56+
if response.status_code != 200:
57+
return False
58+
openapi_json = json.loads(response.content.decode())
59+
if openapi_json:
60+
return openapi_json.get("info", {}).get("title", "").lower() == "chroma"
61+
except Exception as e:
62+
logger.info(f"Failed to connect to chromadb at {base_url}")
63+
logger.debug(traceback.format_exception(e))
5964
return False
6065

6166

tests/test_common.py

Lines changed: 17 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -97,59 +97,6 @@ def test_get_embedding_function_init_exception():
9797
)
9898

9999

100-
@pytest.mark.asyncio
101-
async def test_try_server_versions():
102-
# Test successful v1 response
103-
with patch("httpx.AsyncClient") as mock_client:
104-
mock_response = MagicMock()
105-
mock_response.status_code = 200
106-
mock_client.return_value.__aenter__.return_value.get.return_value = (
107-
mock_response
108-
)
109-
assert await try_server("http://localhost:8300") is True
110-
mock_client.return_value.__aenter__.return_value.get.assert_called_once_with(
111-
url="http://localhost:8300/api/v1/heartbeat"
112-
)
113-
114-
# Test fallback to v2 when v1 fails
115-
with patch("httpx.AsyncClient") as mock_client:
116-
mock_response_v1 = MagicMock()
117-
mock_response_v1.status_code = 404
118-
mock_response_v2 = MagicMock()
119-
mock_response_v2.status_code = 200
120-
mock_client.return_value.__aenter__.return_value.get.side_effect = [
121-
mock_response_v1,
122-
mock_response_v2,
123-
]
124-
assert await try_server("http://localhost:8300") is True
125-
assert mock_client.return_value.__aenter__.return_value.get.call_count == 2
126-
127-
# Test both versions fail
128-
with patch("httpx.AsyncClient") as mock_client:
129-
mock_response_v1 = MagicMock()
130-
mock_response_v1.status_code = 404
131-
mock_response_v2 = MagicMock()
132-
mock_response_v2.status_code = 500
133-
mock_client.return_value.__aenter__.return_value.get.side_effect = [
134-
mock_response_v1,
135-
mock_response_v2,
136-
]
137-
assert await try_server("http://localhost:8300") is False
138-
139-
# Test connection error cases
140-
with patch("httpx.AsyncClient") as mock_client:
141-
mock_client.return_value.__aenter__.return_value.get.side_effect = (
142-
httpx.ConnectError("Cannot connect")
143-
)
144-
assert await try_server("http://localhost:8300") is False
145-
146-
with patch("httpx.AsyncClient") as mock_client:
147-
mock_client.return_value.__aenter__.return_value.get.side_effect = (
148-
httpx.ConnectTimeout("Connection timeout")
149-
)
150-
assert await try_server("http://localhost:8300") is False
151-
152-
153100
def test_verify_ef():
154101
# Mocking AsyncCollection and Config
155102
mock_collection = MagicMock()
@@ -190,10 +137,18 @@ async def test_try_server_mocked(mock_socket):
190137
with patch("httpx.AsyncClient") as mock_client:
191138
mock_response = MagicMock()
192139
mock_response.status_code = 200
140+
mock_response.content = b'{"info":{"title": "Chroma"}}'
193141
mock_client.return_value.__aenter__.return_value.get.return_value = (
194142
mock_response
195143
)
196144
assert await try_server("http://localhost:8000") is True
145+
with patch("httpx.AsyncClient") as mock_client:
146+
mock_response = MagicMock()
147+
mock_response.status_code = 404
148+
mock_client.return_value.__aenter__.return_value.get.return_value = (
149+
mock_response
150+
)
151+
assert await try_server("http://localhost:8000") is False
197152

198153
# Mocking httpx.AsyncClient to raise a ConnectError
199154
with patch("httpx.AsyncClient") as mock_client:
@@ -202,6 +157,15 @@ async def test_try_server_mocked(mock_socket):
202157
)
203158
assert await try_server("http://localhost:8000") is False
204159

160+
with patch("httpx.AsyncClient") as mock_client:
161+
mock_response = MagicMock()
162+
mock_response.status_code = 200
163+
mock_response.content = b'{"info":{"title": "Dummy"}}'
164+
mock_client.return_value.__aenter__.return_value.get.return_value = (
165+
mock_response
166+
)
167+
assert await try_server("http://localhost:8000") is False
168+
205169
# Mocking httpx.AsyncClient to raise a ConnectTimeout
206170
with patch("httpx.AsyncClient") as mock_client:
207171
mock_client.return_value.__aenter__.return_value.get.side_effect = (

0 commit comments

Comments
 (0)