Skip to content

Commit 47997fa

Browse files
committed
tests(cli): Refactor config and cleanup path handling
1 parent 21657ff commit 47997fa

File tree

2 files changed

+70
-76
lines changed

2 files changed

+70
-76
lines changed

src/vectorcode/cli_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ async def expand_globs(
589589

590590

591591
def cleanup_path(path: str):
592-
if os.path.isabs(path) and os.environ.get("HOME") is not None:
592+
if os.path.isabs(path) and os.environ.get("HOME", "") != "":
593593
return path.replace(os.environ["HOME"], "~")
594594
return path
595595

tests/test_cli_utils.py

Lines changed: 69 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -36,28 +36,24 @@ async def test_config_import_from():
3636
os.makedirs(db_path, exist_ok=True)
3737
config_dict: Dict[str, Any] = {
3838
"db_path": db_path,
39-
"db_url": "http://test_host:1234",
39+
"db_params": {"url": "http://test_host:1234"},
4040
"embedding_function": "TestEmbedding",
4141
"embedding_params": {"param1": "value1"},
4242
"chunk_size": 512,
4343
"overlap_ratio": 0.3,
4444
"query_multiplier": 5,
4545
"reranker": "TestReranker",
4646
"reranker_params": {"reranker_param1": "reranker_value1"},
47-
"db_settings": {"db_setting1": "db_value1"},
4847
}
4948
config = await Config.import_from(config_dict)
50-
assert config.db_path == db_path
51-
assert config.db_log_path == os.path.expanduser("~/.local/share/vectorcode/")
52-
assert config.db_url == "http://test_host:1234"
49+
assert isinstance(config.db_params, dict)
5350
assert config.embedding_function == "TestEmbedding"
5451
assert config.embedding_params == {"param1": "value1"}
5552
assert config.chunk_size == 512
5653
assert config.overlap_ratio == 0.3
5754
assert config.query_multiplier == 5
5855
assert config.reranker == "TestReranker"
5956
assert config.reranker_params == {"reranker_param1": "reranker_value1"}
60-
assert config.db_settings == {"db_setting1": "db_value1"}
6157

6258

6359
@pytest.mark.asyncio
@@ -81,20 +77,20 @@ async def test_config_import_from_db_path_is_file():
8177

8278
@pytest.mark.asyncio
8379
async def test_config_merge_from():
84-
config1 = Config(db_url="http://host1:8001", n_result=5)
85-
config2 = Config(db_url="http://host2:8002", query=["test"])
80+
config1 = Config(db_params={"url": "http://host1:8001"}, n_result=5)
81+
config2 = Config(db_params={"url": "http://host2:8002"}, query=["test"])
8682
merged_config = await config1.merge_from(config2)
87-
assert merged_config.db_url == "http://host2:8002"
83+
assert merged_config.db_params["url"] == "http://host2:8002"
8884
assert merged_config.n_result == 5
8985
assert merged_config.query == ["test"]
9086

9187

9288
@pytest.mark.asyncio
9389
async def test_config_merge_from_new_fields():
94-
config1 = Config(db_url="http://host1:8001")
90+
config1 = Config(db_params={"url": "http://host1:8001"})
9591
config2 = Config(query=["test"], n_result=10, recursive=True)
9692
merged_config = await config1.merge_from(config2)
97-
assert merged_config.db_url == "http://host1:8001"
93+
assert merged_config.db_params["url"] == "http://host1:8001"
9894
assert merged_config.query == ["test"]
9995
assert merged_config.n_result == 10
10096
assert merged_config.recursive
@@ -104,18 +100,17 @@ async def test_config_merge_from_new_fields():
104100
async def test_config_import_from_missing_keys():
105101
config_dict: Dict[str, Any] = {} # Empty dictionary, all keys missing
106102
config = await Config.import_from(config_dict)
103+
default_config = Config()
107104

108105
# Assert that default values are used
109-
assert config.embedding_function == "SentenceTransformerEmbeddingFunction"
110-
assert config.embedding_params == {}
111-
assert config.db_url == "http://127.0.0.1:8000"
112-
assert config.db_path == os.path.expanduser("~/.local/share/vectorcode/chromadb/")
113-
assert config.chunk_size == 2500
114-
assert config.overlap_ratio == 0.2
115-
assert config.query_multiplier == -1
116-
assert config.reranker == "NaiveReranker"
117-
assert config.reranker_params == {}
118-
assert config.db_settings is None
106+
assert config.embedding_function == default_config.embedding_function
107+
assert config.embedding_params == default_config.embedding_params
108+
assert config.db_params == default_config.db_params
109+
assert config.chunk_size == default_config.chunk_size
110+
assert config.overlap_ratio == default_config.overlap_ratio
111+
assert config.query_multiplier == default_config.query_multiplier
112+
assert config.reranker == default_config.reranker
113+
assert config.reranker_params == default_config.reranker_params
119114

120115

121116
def test_expand_envs_in_dict():
@@ -133,6 +128,8 @@ def test_expand_envs_in_dict():
133128
expand_envs_in_dict(d)
134129
assert d["key4"] == "$TEST_VAR2" # Should remain unchanged
135130

131+
expand_envs_in_dict(None)
132+
136133
del os.environ["TEST_VAR"] # Clean up the env
137134

138135

@@ -222,12 +219,12 @@ async def test_load_from_default_config():
222219
config_dir,
223220
)
224221
os.makedirs(config_dir, exist_ok=True)
225-
config_content = '{"db_url": "http://default.url:8000"}'
222+
config_content = '{"db_params": {"url": "http://default.url:8000"}}'
226223
with open(config_path, "w") as fin:
227224
fin.write(config_content)
228225

229226
config = await load_config_file()
230-
assert config.db_url == "http://default.url:8000"
227+
assert isinstance(config.db_params, dict)
231228

232229

233230
@pytest.mark.asyncio
@@ -321,6 +318,7 @@ async def test_cli_arg_parser():
321318
def test_query_include_to_header():
322319
assert QueryInclude.path.to_header() == "Path: "
323320
assert QueryInclude.document.to_header() == "Document:\n"
321+
assert QueryInclude.chunk.to_header() == "Chunk: "
324322

325323

326324
def test_find_project_root():
@@ -402,12 +400,12 @@ async def test_parse_cli_args_vectorise_recursive_dir():
402400

403401
@pytest.mark.asyncio
404402
async def test_parse_cli_args_vectorise_recursive_dir_include_hidden():
405-
with patch("sys.argv", ["vectorcode", "vectorise", "-r", "."]):
403+
with patch("sys.argv", ["vectorcode", "vectorise", "-r", ".", "--include-hidden"]):
406404
config = await parse_cli_args()
407405
assert config.action == CliAction.vectorise
408406
assert config.files == ["."]
409407
assert config.recursive is True
410-
assert config.include_hidden is False
408+
assert config.include_hidden is True
411409

412410

413411
@pytest.mark.asyncio
@@ -425,10 +423,10 @@ async def test_get_project_config_local_config(tmp_path):
425423
vectorcode_dir.mkdir(parents=True)
426424

427425
config_file = vectorcode_dir / "config.json"
428-
config_file.write_text('{"db_url": "http://test_host:9999" }')
426+
config_file.write_text('{"db_params": {"url": "http://test_host:9999"} }')
429427

430428
config = await get_project_config(project_root)
431-
assert config.db_url == "http://test_host:9999"
429+
assert isinstance(config.db_params, dict)
432430

433431

434432
@pytest.mark.asyncio
@@ -438,10 +436,10 @@ async def test_get_project_config_local_config_json5(tmp_path):
438436
vectorcode_dir.mkdir(parents=True)
439437

440438
config_file = vectorcode_dir / "config.json5"
441-
config_file.write_text('{"db_url": "http://test_host:9999" }')
439+
config_file.write_text('{"db_params": {"url": "http://test_host:9999"} }')
442440

443441
config = await get_project_config(project_root)
444-
assert config.db_url == "http://test_host:9999"
442+
assert isinstance(config.db_params, dict)
445443

446444

447445
def test_find_project_root_file_input(tmp_path):
@@ -484,9 +482,10 @@ async def test_parse_cli_args_check():
484482

485483
@pytest.mark.asyncio
486484
async def test_parse_cli_args_init():
487-
with patch("sys.argv", ["vectorcode", "init"]):
485+
with patch("sys.argv", ["vectorcode", "init", "--force"]):
488486
config = await parse_cli_args()
489487
assert config.action == CliAction.init
488+
assert config.force is True
490489

491490

492491
@pytest.mark.asyncio
@@ -527,37 +526,15 @@ async def test_parse_cli_args_files():
527526
assert config.rm_paths == ["foo.txt"]
528527

529528

530-
@pytest.mark.asyncio
531-
async def test_config_import_from_hnsw():
532-
with tempfile.TemporaryDirectory() as temp_dir:
533-
db_path = os.path.join(temp_dir, "test_db")
534-
os.makedirs(db_path, exist_ok=True)
535-
config_dict: Dict[str, Any] = {
536-
"hnsw": {"space": "cosine", "ef_construction": 200, "m": 32}
537-
}
538-
config = await Config.import_from(config_dict)
539-
assert config.hnsw["space"] == "cosine"
540-
assert config.hnsw["ef_construction"] == 200
541-
assert config.hnsw["m"] == 32
542-
543-
544-
@pytest.mark.asyncio
545-
async def test_hnsw_config_merge():
546-
config1 = Config(hnsw={"space": "ip"})
547-
config2 = Config(hnsw={"ef_construction": 200})
548-
merged_config = await config1.merge_from(config2)
549-
assert merged_config.hnsw["space"] == "ip"
550-
assert merged_config.hnsw["ef_construction"] == 200
551-
552-
553529
def test_cleanup_path():
554530
home = os.environ.get("HOME")
555-
if home is None:
556-
return
557-
assert cleanup_path(os.path.join(home, "test_path")) == os.path.join(
558-
"~", "test_path"
559-
)
531+
if home:
532+
assert cleanup_path(os.path.join(home, "test_path")) == os.path.join(
533+
"~", "test_path"
534+
)
560535
assert cleanup_path("/etc/dir") == "/etc/dir"
536+
with patch.dict(os.environ, {"HOME": ""}):
537+
assert cleanup_path("/etc/dir") == "/etc/dir"
561538

562539

563540
def test_shtab():
@@ -576,8 +553,10 @@ def test_shtab():
576553
async def test_filelock():
577554
manager = LockManager()
578555
with tempfile.TemporaryDirectory() as tmp_dir:
579-
manager.get_lock(tmp_dir)
556+
lock = manager.get_lock(tmp_dir)
580557
assert os.path.isfile(os.path.join(tmp_dir, "vectorcode.lock"))
558+
# test getting existing lock
559+
assert lock is manager.get_lock(tmp_dir)
581560

582561

583562
def test_specresolver():
@@ -610,23 +589,38 @@ def test_specresolver_builder():
610589
patch("vectorcode.cli_utils.open"),
611590
):
612591
base_dir = os.path.normpath(os.path.join("foo", "bar"))
613-
assert (
614-
os.path.normpath(
615-
SpecResolver.from_path(os.path.join(base_dir, ".gitignore")).base_dir
616-
)
617-
== base_dir
618-
)
592+
assert os.path.normpath(
593+
SpecResolver.from_path(os.path.join(base_dir, ".gitignore")).base_dir
594+
) == os.path.abspath(base_dir)
619595

620-
assert (
621-
os.path.normpath(
622-
SpecResolver.from_path(
623-
os.path.join(base_dir, ".vectorcode", "vectorcode.exclude")
624-
).base_dir
625-
)
626-
== base_dir
627-
)
628596
assert os.path.normpath(
629597
SpecResolver.from_path(
630-
os.path.join(base_dir, "vectorcode", "vectorcode.exclude")
598+
os.path.join(base_dir, ".vectorcode", "vectorcode.exclude")
599+
).base_dir
600+
) == os.path.abspath(base_dir)
601+
assert os.path.normpath(
602+
SpecResolver.from_path(
603+
os.path.join(base_dir, "vectorcode", "vectorcode.exclude"),
604+
project_root=base_dir,
631605
).base_dir
632-
) == os.path.normpath(".")
606+
) == os.path.abspath(base_dir)
607+
with pytest.raises(ValueError):
608+
SpecResolver.from_path("foo/bar")
609+
610+
611+
@pytest.mark.asyncio
612+
async def test_find_project_root_at_root():
613+
with tempfile.TemporaryDirectory() as temp_dir:
614+
os.makedirs(os.path.join(temp_dir, ".git"))
615+
# in a git repo, find_project_root should not go beyond the git root
616+
assert find_project_root(temp_dir, ".git") == temp_dir
617+
assert find_project_root(temp_dir, ".vectorcode") is None
618+
619+
620+
@pytest.mark.asyncio
621+
async def test_find_project_config_dir_at_root():
622+
with tempfile.TemporaryDirectory() as temp_dir:
623+
git_dir = os.path.join(temp_dir, ".git")
624+
os.makedirs(git_dir)
625+
# in a git repo, find_project_root should not go beyond the git root
626+
assert await find_project_config_dir(temp_dir) == git_dir

0 commit comments

Comments
 (0)