Skip to content

Commit e2667e5

Browse files
committed
tests(cli): improve test coverage.
1 parent 18520c4 commit e2667e5

File tree

3 files changed

+52
-5
lines changed

3 files changed

+52
-5
lines changed

src/vectorcode/cli_utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config":
127127
default_config = Config()
128128
db_path = config_dict.get("db_path")
129129

130+
expand_envs_in_dict(config_dict)
130131
if db_path is None:
131132
db_path = os.path.expanduser("~/.local/share/vectorcode/chromadb/")
132133
elif not os.path.isdir(db_path):
@@ -470,7 +471,7 @@ async def parse_cli_args(args: Optional[Sequence[str]] = None):
470471

471472

472473
def expand_envs_in_dict(d: dict):
473-
if not isinstance(d, dict):
474+
if not isinstance(d, dict): # pragma: nocover
474475
return
475476
stack = [d]
476477
while stack:
@@ -485,6 +486,7 @@ def expand_envs_in_dict(d: dict):
485486
async def load_config_file(path: str | Path | None = None) -> Config:
486487
"""
487488
Load config object by merging the project-local and the global config files.
489+
`path` can be a _file path_ or a _project-root_ path.
488490
489491
Raises `ValueError` if the config file is not a valid json dictionary.
490492
"""
@@ -554,13 +556,12 @@ def find_project_root(
554556
start_from = start_from.parent
555557

556558

557-
async def get_project_config(project_root: Union[str, Path]) -> Config:
559+
async def get_project_config(project_root: str | Path) -> Config:
558560
"""
559561
Load config file for `project_root`.
560562
Fallback to global config, and then default config.
561563
"""
562-
if not os.path.isabs(project_root):
563-
project_root = os.path.abspath(project_root)
564+
project_root = os.path.abspath(os.path.expanduser(project_root))
564565
exts = ("json5", "json")
565566
config = None
566567
for ext in exts:

src/vectorcode/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ async def async_main():
2424
if cli_args.no_stderr:
2525
sys.stderr = open(os.devnull, "w")
2626

27-
if cli_args.debug:
27+
if cli_args.debug: # pragma: nocover
2828
from vectorcode import debugging
2929

3030
debugging.enable()

tests/test_cli_utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,52 @@ async def test_load_config_file_invalid_json():
207207
await load_config_file(config_path)
208208

209209

210+
@pytest.mark.asyncio
211+
async def test_load_config_file_merging():
212+
with tempfile.TemporaryDirectory() as dummy_home:
213+
global_config_dir = os.path.join(dummy_home, ".config", "vectorcode")
214+
os.makedirs(global_config_dir, exist_ok=True)
215+
with open(os.path.join(global_config_dir, "config.json"), mode="w") as fin:
216+
fin.writelines(['{"embedding_function": "DummyEmbeddingFunction"}'])
217+
218+
with tempfile.TemporaryDirectory(dir=dummy_home) as proj_root:
219+
os.makedirs(os.path.join(proj_root, ".vectorcode"), exist_ok=True)
220+
with open(
221+
os.path.join(proj_root, ".vectorcode", "config.json"), mode="w"
222+
) as fin:
223+
fin.writelines(
224+
['{"embedding_function": "AnotherDummyEmbeddingFunction"}']
225+
)
226+
227+
with patch(
228+
"vectorcode.cli_utils.GLOBAL_CONFIG_DIR", new=str(global_config_dir)
229+
):
230+
assert (
231+
await load_config_file()
232+
).embedding_function == "DummyEmbeddingFunction"
233+
assert (
234+
await load_config_file(proj_root)
235+
).embedding_function == "AnotherDummyEmbeddingFunction"
236+
237+
238+
@pytest.mark.asyncio
239+
async def test_load_config_file_with_envs():
240+
with tempfile.TemporaryDirectory() as proj_root:
241+
os.makedirs(os.path.join(proj_root, ".vectorcode"), exist_ok=True)
242+
with (
243+
open(
244+
os.path.join(proj_root, ".vectorcode", "config.json"), mode="w"
245+
) as fin,
246+
):
247+
fin.writelines(['{"embedding_function": "$DUMMY_EMBEDDING_FUNCTION"}'])
248+
with patch.dict(
249+
os.environ, {"DUMMY_EMBEDDING_FUNCTION": "DummyEmbeddingFunction"}
250+
):
251+
assert (
252+
await load_config_file(proj_root)
253+
).embedding_function == "DummyEmbeddingFunction"
254+
255+
210256
@pytest.mark.asyncio
211257
async def test_load_from_default_config():
212258
for name in ("config.json5", "config.json"):

0 commit comments

Comments
 (0)