Skip to content

Commit dfff649

Browse files
authored
refactor(cli): use db_url to replace host and port. (#143)
* refactor(cli): use `db_url` to replace `host` and `port`. * feat(cli): build chroma url from user-configured host and port.
1 parent dbe1abc commit dfff649

File tree

7 files changed

+121
-105
lines changed

7 files changed

+121
-105
lines changed

docs/cli.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,9 @@ The JSON configuration file may hold the following values:
261261
Then the embedding function object will be initialised as
262262
`OllamaEmbeddingFunction(url="http://127.0.0.1:11434/api/embeddings",
263263
model_name="nomic-embed-text")`. Default: `{}`;
264-
- `host` and `port`: string and integer, Chromadb server host and port. VectorCode will start an
264+
- `db_url`: string, the URL that points to the Chromadb server. VectorCode will start an
265265
HTTP server for Chromadb at a randomly picked free port on `localhost` if your
266-
configured `host:port` is not accessible. This allows the use of `AsyncHttpClient`.
267-
Default: `127.0.0.1:8000`;
266+
configured `db_url` (e.g. `http://host:port`) is not accessible. Default: `http://127.0.0.1:8000`;
268267
- `db_path`: string, Path to local persistent database. This is where the files for
269268
your database will be stored. Default: `~/.local/share/vectorcode/chromadb/`;
270269
- `db_log_path`: string, path to the _directory_ where the built-in chromadb

src/vectorcode/cli_utils.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,7 @@ class Config:
7373
files: list[PathLike] = field(default_factory=list)
7474
project_root: Optional[PathLike] = None
7575
query: Optional[list[str]] = None
76-
host: str = "127.0.0.1"
77-
port: int = 8000
76+
db_url: str = "http://127.0.0.1:8000"
7877
embedding_function: str = "SentenceTransformerEmbeddingFunction" # This should fallback to whatever the default is.
7978
embedding_params: dict[str, Any] = field(default_factory=(lambda: {}))
8079
n_result: int = 1
@@ -105,8 +104,21 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config":
105104
"""
106105
default_config = Config()
107106
db_path = config_dict.get("db_path")
108-
host = config_dict.get("host") or "localhost"
109-
port = config_dict.get("port") or 8000
107+
db_url = config_dict.get("db_url")
108+
if db_url is None:
109+
host = config_dict.get("host")
110+
port = config_dict.get("port")
111+
if host is not None or port is not None:
112+
# TODO: deprecate `host` and `port` in 0.7.0
113+
host = host or "127.0.0.1"
114+
port = port or 8000
115+
db_url = f"http://{host}:{port}"
116+
logger.warning(
117+
f'"host" and "port" are deprecated and will be removed in 0.7.0. Use "db_url" (eg. {db_url}).'
118+
)
119+
else:
120+
db_url = "http://127.0.0.1:8000"
121+
110122
if db_path is None:
111123
db_path = os.path.expanduser("~/.local/share/vectorcode/chromadb/")
112124
elif not os.path.isdir(db_path):
@@ -121,8 +133,7 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config":
121133
"embedding_params": config_dict.get(
122134
"embedding_params", default_config.embedding_params
123135
),
124-
"host": host,
125-
"port": port,
136+
"db_url": db_url,
126137
"db_path": db_path,
127138
"db_log_path": os.path.expanduser(
128139
config_dict.get("db_log_path", default_config.db_log_path)

src/vectorcode/common.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
import socket
66
import subprocess
77
import sys
8-
from typing import AsyncGenerator
8+
from typing import Any, AsyncGenerator
9+
from urllib.parse import urlparse
910

1011
import chromadb
1112
import httpx
1213
from chromadb.api import AsyncClientAPI
1314
from chromadb.api.models.AsyncCollection import AsyncCollection
14-
from chromadb.config import Settings
15+
from chromadb.config import APIVersion, Settings
1516
from chromadb.utils import embedding_functions
1617

1718
from vectorcode.cli_utils import Config, expand_path
@@ -40,26 +41,26 @@ async def get_collections(
4041
yield collection
4142

4243

43-
async def try_server(host: str, port: int):
44+
async def try_server(base_url: str):
4445
for ver in ("v1", "v2"): # v1 for legacy, v2 for latest chromadb.
45-
url = f"http://{host}:{port}/api/{ver}/heartbeat"
46+
heartbeat_url = f"{base_url}/api/{ver}/heartbeat"
4647
try:
4748
async with httpx.AsyncClient() as client:
48-
response = await client.get(url=url)
49-
logger.debug(f"Heartbeat {url} returned {response=}")
49+
response = await client.get(url=heartbeat_url)
50+
logger.debug(f"Heartbeat {heartbeat_url} returned {response=}")
5051
if response.status_code == 200:
5152
return True
5253
except (httpx.ConnectError, httpx.ConnectTimeout):
5354
pass
5455
return False
5556

5657

57-
async def wait_for_server(host, port, timeout=10):
58+
async def wait_for_server(url: str, timeout=10):
5859
# Poll the server until it's ready or timeout is reached
5960

6061
start_time = asyncio.get_event_loop().time()
6162
while True:
62-
if await try_server(host, port):
63+
if await try_server(url):
6364
return
6465

6566
if asyncio.get_event_loop().time() - start_time > timeout:
@@ -82,10 +83,8 @@ async def start_server(configs: Config):
8283
env = os.environ.copy()
8384
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
8485
s.bind(("", 0)) # OS selects a free ephemeral port
85-
configs.port = int(s.getsockname()[1])
86-
logger.warning(
87-
f"Starting bundled ChromaDB server at {configs.host}:{configs.port}."
88-
)
86+
port = int(s.getsockname()[1])
87+
logger.warning(f"Starting bundled ChromaDB server at http://127.0.0.1:{port}.")
8988
env.update({"ANONYMIZED_TELEMETRY": "False"})
9089
process = await asyncio.create_subprocess_exec(
9190
sys.executable,
@@ -95,7 +94,7 @@ async def start_server(configs: Config):
9594
"--host",
9695
"localhost",
9796
"--port",
98-
str(configs.port),
97+
str(port),
9998
"--path",
10099
db_path,
101100
"--log-path",
@@ -105,28 +104,32 @@ async def start_server(configs: Config):
105104
env=env,
106105
)
107106

108-
await wait_for_server(configs.host, configs.port)
107+
await wait_for_server(f"http://127.0.0.1:{port}")
109108
return process
110109

111110

112-
__CLIENT_CACHE: dict[tuple[str, int], AsyncClientAPI] = {}
111+
__CLIENT_CACHE: dict[str, AsyncClientAPI] = {}
113112

114113

115114
async def get_client(configs: Config) -> AsyncClientAPI:
116-
assert configs.host is not None
117-
assert configs.port is not None
118-
client_entry = (configs.host, configs.port)
115+
client_entry = configs.db_url
119116
if __CLIENT_CACHE.get(client_entry) is None:
120-
settings = {"anonymized_telemetry": False}
117+
settings: dict[str, Any] = {"anonymized_telemetry": False}
121118
if isinstance(configs.db_settings, dict):
122119
valid_settings = {
123120
k: v for k, v in configs.db_settings.items() if k in Settings.__fields__
124121
}
125122
settings.update(valid_settings)
123+
parsed_url = urlparse(configs.db_url)
124+
settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1"
125+
settings["chroma_server_http_port"] = parsed_url.port or 8000
126+
settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https"
127+
settings["chroma_server_api_default_path"] = parsed_url.path or APIVersion.V2
128+
settings_obj = Settings(**settings)
126129
__CLIENT_CACHE[client_entry] = await chromadb.AsyncHttpClient(
127-
host=configs.host or "localhost",
128-
port=configs.port or 8000,
129-
settings=Settings(**settings),
130+
settings=settings_obj,
131+
host=str(settings_obj.chroma_server_host),
132+
port=int(settings_obj.chroma_server_http_port or 8000),
130133
)
131134
return __CLIENT_CACHE[client_entry]
132135

src/vectorcode/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ async def async_main():
7474
from vectorcode.common import start_server, try_server
7575

7676
server_process = None
77-
if not await try_server(final_configs.host, final_configs.port):
77+
if not await try_server(final_configs.db_url):
7878
server_process = await start_server(final_configs)
7979

8080
if final_configs.pipe:

tests/subcommands/test_vectorise.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,7 @@ def test_load_files_from_include_no_files(mock_check_tree_files, mock_isfile, tm
268268
@pytest.mark.asyncio
269269
async def test_vectorise(capsys):
270270
configs = Config(
271-
host="test_host",
272-
port=1234,
271+
db_url="http://test_host:1234",
273272
db_path="test_db",
274273
embedding_function="SentenceTransformerEmbeddingFunction",
275274
embedding_params={},
@@ -330,8 +329,7 @@ async def test_vectorise(capsys):
330329
@pytest.mark.asyncio
331330
async def test_vectorise_cancelled():
332331
configs = Config(
333-
host="test_host",
334-
port=1234,
332+
db_url="http://test_host:1234",
335333
db_path="test_db",
336334
embedding_function="SentenceTransformerEmbeddingFunction",
337335
embedding_params={},
@@ -373,8 +371,7 @@ async def mock_chunked_add(*args, **kwargs):
373371
@pytest.mark.asyncio
374372
async def test_vectorise_orphaned_files():
375373
configs = Config(
376-
host="test_host",
377-
port=1234,
374+
db_url="http://test_host:1234",
378375
db_path="test_db",
379376
embedding_function="SentenceTransformerEmbeddingFunction",
380377
embedding_params={},
@@ -443,8 +440,7 @@ def is_file_side_effect(path):
443440
@pytest.mark.asyncio
444441
async def test_vectorise_collection_index_error():
445442
configs = Config(
446-
host="test_host",
447-
port=1234,
443+
db_url="http://test_host:1234",
448444
db_path="test_db",
449445
embedding_function="SentenceTransformerEmbeddingFunction",
450446
embedding_params={},
@@ -470,8 +466,7 @@ async def test_vectorise_collection_index_error():
470466
@pytest.mark.asyncio
471467
async def test_vectorise_verify_ef_false():
472468
configs = Config(
473-
host="test_host",
474-
port=1234,
469+
db_url="http://test_host:1234",
475470
db_path="test_db",
476471
embedding_function="SentenceTransformerEmbeddingFunction",
477472
embedding_params={},
@@ -500,8 +495,7 @@ async def test_vectorise_verify_ef_false():
500495
@pytest.mark.asyncio
501496
async def test_vectorise_gitignore():
502497
configs = Config(
503-
host="test_host",
504-
port=1234,
498+
db_url="http://test_host:1234",
505499
db_path="test_db",
506500
embedding_function="SentenceTransformerEmbeddingFunction",
507501
embedding_params={},
@@ -548,8 +542,7 @@ async def test_vectorise_exclude_file(tmpdir):
548542
exclude_file.write("excluded_file.py\n")
549543

550544
configs = Config(
551-
host="test_host",
552-
port=1234,
545+
db_url="http://test_host:1234",
553546
db_path="test_db",
554547
embedding_function="SentenceTransformerEmbeddingFunction",
555548
embedding_params={},

tests/test_cli_utils.py

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ async def test_config_import_from():
2828
os.makedirs(db_path, exist_ok=True)
2929
config_dict: Dict[str, Any] = {
3030
"db_path": db_path,
31-
"host": "test_host",
32-
"port": 1234,
31+
"db_url": "http://test_host:1234",
3332
"embedding_function": "TestEmbedding",
3433
"embedding_params": {"param1": "value1"},
3534
"chunk_size": 512,
@@ -42,8 +41,7 @@ async def test_config_import_from():
4241
config = await Config.import_from(config_dict)
4342
assert config.db_path == db_path
4443
assert config.db_log_path == os.path.expanduser("~/.local/share/vectorcode/")
45-
assert config.host == "test_host"
46-
assert config.port == 1234
44+
assert config.db_url == "http://test_host:1234"
4745
assert config.embedding_function == "TestEmbedding"
4846
assert config.embedding_params == {"param1": "value1"}
4947
assert config.chunk_size == 512
@@ -54,6 +52,14 @@ async def test_config_import_from():
5452
assert config.db_settings == {"db_setting1": "db_value1"}
5553

5654

55+
@pytest.mark.asyncio
56+
async def test_config_import_from_fallback_host_port():
57+
conf = {"host": "test_host"}
58+
assert (await Config.import_from(conf)).db_url == "http://test_host:8000"
59+
conf = {"port": 114514}
60+
assert (await Config.import_from(conf)).db_url == "http://127.0.0.1:114514"
61+
62+
5763
@pytest.mark.asyncio
5864
async def test_config_import_from_invalid_path():
5965
config_dict: Dict[str, Any] = {"db_path": "/path/does/not/exist"}
@@ -75,22 +81,20 @@ async def test_config_import_from_db_path_is_file():
7581

7682
@pytest.mark.asyncio
7783
async def test_config_merge_from():
78-
config1 = Config(host="host1", port=8001, n_result=5)
79-
config2 = Config(host="host2", port=None, query=["test"])
84+
config1 = Config(db_url="http://host1:8001", n_result=5)
85+
config2 = Config(db_url="http://host2:8002", query=["test"])
8086
merged_config = await config1.merge_from(config2)
81-
assert merged_config.host == "host2"
82-
assert merged_config.port == 8001 # port from config1 should be retained
87+
assert merged_config.db_url == "http://host2:8002"
8388
assert merged_config.n_result == 5
8489
assert merged_config.query == ["test"]
8590

8691

8792
@pytest.mark.asyncio
8893
async def test_config_merge_from_new_fields():
89-
config1 = Config(host="host1", port=8001)
94+
config1 = Config(db_url="http://host1:8001")
9095
config2 = Config(query=["test"], n_result=10, recursive=True)
9196
merged_config = await config1.merge_from(config2)
92-
assert merged_config.host == "host1"
93-
assert merged_config.port == 8001
97+
assert merged_config.db_url == "http://host1:8001"
9498
assert merged_config.query == ["test"]
9599
assert merged_config.n_result == 10
96100
assert merged_config.recursive
@@ -104,8 +108,7 @@ async def test_config_import_from_missing_keys():
104108
# Assert that default values are used
105109
assert config.embedding_function == "SentenceTransformerEmbeddingFunction"
106110
assert config.embedding_params == {}
107-
assert config.host == "localhost"
108-
assert config.port == 8000
111+
assert config.db_url == "http://127.0.0.1:8000"
109112
assert config.db_path == os.path.expanduser("~/.local/share/vectorcode/chromadb/")
110113
assert config.chunk_size == 2500
111114
assert config.overlap_ratio == 0.2
@@ -318,7 +321,7 @@ def test_find_project_root():
318321
async def test_get_project_config_no_local_config():
319322
with tempfile.TemporaryDirectory() as temp_dir:
320323
config = await get_project_config(temp_dir)
321-
assert config.host in {"127.0.0.1", "localhost"}
324+
assert config.chunk_size == Config().chunk_size, "Should load default value."
322325

323326

324327
@pytest.mark.asyncio
@@ -394,36 +397,28 @@ async def test_parse_cli_args_vectorise_no_files():
394397

395398
@pytest.mark.asyncio
396399
async def test_get_project_config_local_config(tmp_path):
397-
# Create a temporary directory and a .vectorcode subdirectory
398400
project_root = tmp_path / "project"
399401
vectorcode_dir = project_root / ".vectorcode"
400402
vectorcode_dir.mkdir(parents=True)
401403

402-
# Create a config.json file inside .vectorcode with some custom settings
403404
config_file = vectorcode_dir / "config.json"
404-
config_file.write_text('{"host": "test_host", "port": 9999}')
405+
config_file.write_text('{"db_url": "http://test_host:9999" }')
405406

406-
# Call get_project_config and check if it returns the custom settings
407407
config = await get_project_config(project_root)
408-
assert config.host == "test_host"
409-
assert config.port == 9999
408+
assert config.db_url == "http://test_host:9999"
410409

411410

412411
@pytest.mark.asyncio
413412
async def test_get_project_config_local_config_json5(tmp_path):
414-
# Create a temporary directory and a .vectorcode subdirectory
415413
project_root = tmp_path / "project"
416414
vectorcode_dir = project_root / ".vectorcode"
417415
vectorcode_dir.mkdir(parents=True)
418416

419-
# Create a config.json file inside .vectorcode with some custom settings
420417
config_file = vectorcode_dir / "config.json5"
421-
config_file.write_text('{"host": "test_host", "port": 9999}')
418+
config_file.write_text('{"db_url": "http://test_host:9999" }')
422419

423-
# Call get_project_config and check if it returns the custom settings
424420
config = await get_project_config(project_root)
425-
assert config.host == "test_host"
426-
assert config.port == 9999
421+
assert config.db_url == "http://test_host:9999"
427422

428423

429424
def test_find_project_root_file_input(tmp_path):
@@ -512,11 +507,9 @@ async def test_config_import_from_hnsw():
512507

513508
@pytest.mark.asyncio
514509
async def test_hnsw_config_merge():
515-
config1 = Config(host="host1", port=8001, hnsw={"space": "ip"})
516-
config2 = Config(host="host2", port=None, hnsw={"ef_construction": 200})
510+
config1 = Config(hnsw={"space": "ip"})
511+
config2 = Config(hnsw={"ef_construction": 200})
517512
merged_config = await config1.merge_from(config2)
518-
assert merged_config.host == "host2"
519-
assert merged_config.port == 8001
520513
assert merged_config.hnsw["space"] == "ip"
521514
assert merged_config.hnsw["ef_construction"] == 200
522515

0 commit comments

Comments
 (0)