Skip to content

Commit ac085fa

Browse files
committed
tests(chroma0): Add more tests to chroma0
1 parent 28cc7f3 commit ac085fa

File tree

1 file changed

+348
-0
lines changed

1 file changed

+348
-0
lines changed

tests/database/test_chroma0.py

Lines changed: 348 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import tempfile
33
from unittest.mock import AsyncMock, MagicMock, patch
44

5+
import httpx
56
import pytest
67
from chromadb.api.types import QueryResult
78
from chromadb.errors import InvalidCollectionException
@@ -11,7 +12,11 @@
1112
from vectorcode.database import types
1213
from vectorcode.database.chroma0 import (
1314
ChromaDB0Connector,
15+
_Chroma0ClientManager,
1416
_convert_chroma_query_results,
17+
_start_server,
18+
_try_server,
19+
_wait_for_server,
1520
)
1621
from vectorcode.database.errors import CollectionNotFoundError
1722

@@ -21,10 +26,12 @@ def mock_config():
2126
with tempfile.TemporaryDirectory() as tmpdir:
2227
yield Config(
2328
project_root=tmpdir,
29+
embedding_function="default",
2430
db_params={
2531
"db_url": "http://localhost:1234",
2632
"db_path": os.path.join(tmpdir, "db"),
2733
"db_log_path": os.path.join(tmpdir, "log"),
34+
"db_settings": {},
2835
},
2936
)
3037

@@ -457,3 +464,344 @@ async def test_get_chunks_generic_exception(mock_config):
457464
with pytest.raises(Exception) as excinfo:
458465
await connector.get_chunks(os.path.join(mock_config.project_root, "file1"))
459466
assert "test error" in str(excinfo.value)
467+
468+
469+
@pytest.mark.asyncio
async def test_try_server_success():
    """Check that `_try_server` returns True when the mocked GET yields HTTP 200."""
    with patch("httpx.AsyncClient") as client_cls:
        ok_response = AsyncMock()
        ok_response.status_code = 200
        # The stub is installed on the object produced by entering
        # `httpx.AsyncClient(...)` as an async context manager.
        session = client_cls.return_value.__aenter__.return_value
        session.get.return_value = ok_response

        is_up = await _try_server("http://localhost:8000")
        assert is_up is True
480+
481+
482+
@pytest.mark.asyncio
async def test_try_server_failure():
    """Check that `_try_server` returns False when GET raises `httpx.ConnectError`."""
    with patch("httpx.AsyncClient") as client_cls:
        # Make the request issued through the async-context client fail.
        session = client_cls.return_value.__aenter__.return_value
        session.get.side_effect = httpx.ConnectError("test")

        is_up = await _try_server("http://localhost:8000")
        assert is_up is False
491+
492+
493+
@pytest.mark.asyncio
async def test_wait_for_server_success():
    """Check that `_wait_for_server` returns once `_try_server` reports True.

    The patched probe answers False first and True second, so exactly two
    probe calls are expected.
    """
    with patch(
        "vectorcode.database.chroma0._try_server",
        new_callable=AsyncMock,
        side_effect=[False, True],
    ) as probe:
        await _wait_for_server("http://localhost:8000", timeout=1)
        assert probe.call_count == 2
502+
503+
504+
@pytest.mark.asyncio
async def test_wait_for_server_timeout():
    """Check that `_wait_for_server` raises `TimeoutError` if probes never succeed."""
    with (
        patch(
            "vectorcode.database.chroma0._try_server",
            new_callable=AsyncMock,
            return_value=False,
        ),
        pytest.raises(TimeoutError),
    ):
        await _wait_for_server("http://localhost:8000", timeout=0.2)
513+
514+
515+
@pytest.mark.asyncio
async def test_start_server(mock_config):
    """Check that `_start_server` spawns once, waits once, and returns the process.

    Both `asyncio.create_subprocess_exec` and `_wait_for_server` are mocked,
    so no real process is started.
    """
    spawned = AsyncMock()
    with (
        patch(
            "asyncio.create_subprocess_exec",
            new_callable=AsyncMock,
            return_value=spawned,
        ) as spawn,
        patch(
            "vectorcode.database.chroma0._wait_for_server", new_callable=AsyncMock
        ) as wait_ready,
    ):
        result = await _start_server(mock_config)
        assert result == spawned
        spawn.assert_called_once()
        wait_ready.assert_called_once()
528+
529+
530+
@pytest.mark.asyncio
async def test_client_manager_get_client_new_server(mock_config):
    """Check `get_client` when the probe fails and a server has to be started.

    `_try_server` is mocked to return False and `_start_server` to return a
    fake process. The manager is expected to yield the mocked client, list the
    fake process in `get_processes()`, and call `terminate()` on it once when
    `kill_servers()` runs.
    """
    with patch("atexit.register"):
        manager = _Chroma0ClientManager()
        manager.clear()

        server_proc = MagicMock()
        server_proc.returncode = None
        fake_client = AsyncMock()
        fake_client.get_version.return_value = "0.1.0"

        with (
            patch(
                "vectorcode.database.chroma0._try_server",
                new_callable=AsyncMock,
                return_value=False,
            ),
            patch(
                "vectorcode.database.chroma0._start_server",
                new_callable=AsyncMock,
                return_value=server_proc,
            ),
            patch(
                "vectorcode.database.chroma0._Chroma0ClientManager._create_client",
                new_callable=AsyncMock,
                return_value=fake_client,
            ),
        ):
            async with manager.get_client(mock_config, need_lock=False) as client:
                assert client == fake_client
                assert manager.get_processes() == [server_proc]

            manager.kill_servers()
            server_proc.terminate.assert_called_once()
        manager.clear()
563+
564+
565+
@pytest.mark.asyncio
async def test_client_manager_get_client_existing_server(mock_config):
    """Check `get_client` when `_try_server` reports a server is already up.

    The manager should yield the mocked client and record no processes.
    """
    manager = _Chroma0ClientManager()
    manager.clear()

    fake_client = AsyncMock()
    fake_client.get_version.return_value = "0.1.0"

    with (
        patch(
            "vectorcode.database.chroma0._try_server",
            new_callable=AsyncMock,
            return_value=True,
        ),
        patch(
            "vectorcode.database.chroma0._Chroma0ClientManager._create_client",
            new_callable=AsyncMock,
            return_value=fake_client,
        ),
    ):
        async with manager.get_client(mock_config, need_lock=False) as client:
            assert client == fake_client
            assert not manager.get_processes()
    manager.clear()
588+
589+
590+
@pytest.mark.asyncio
async def test_create_client(mock_config):
    """Check that `_create_client` calls `chromadb.AsyncHttpClient` exactly once."""
    manager = _Chroma0ClientManager()
    with patch(
        "chromadb.AsyncHttpClient", new_callable=AsyncMock
    ) as http_client_factory:
        await manager._create_client(mock_config)
        http_client_factory.assert_called_once()
597+
598+
599+
@pytest.mark.asyncio
async def test_client_manager_get_client_with_lock(mock_config):
    """Check that `get_client(need_lock=True)` acquires and releases the lock once.

    `LockManager` is replaced by a mock whose `get_lock()` returns an
    `AsyncMock` lock, and the server start-up path is mocked out.
    """
    with patch("atexit.register"):
        manager = _Chroma0ClientManager()
        manager.clear()

        server_proc = MagicMock()
        server_proc.returncode = None
        fake_client = AsyncMock()
        fake_client.get_version.return_value = "0.1.0"
        fake_lock = AsyncMock()

        with (
            patch(
                "vectorcode.database.chroma0._try_server",
                new_callable=AsyncMock,
                return_value=False,
            ),
            patch(
                "vectorcode.database.chroma0._start_server",
                new_callable=AsyncMock,
                return_value=server_proc,
            ),
            patch(
                "vectorcode.database.chroma0._Chroma0ClientManager._create_client",
                new_callable=AsyncMock,
                return_value=fake_client,
            ),
            patch("vectorcode.database.chroma0.LockManager") as lock_manager_cls,
        ):
            lock_manager_cls.return_value.get_lock.return_value = fake_lock

            async with manager.get_client(mock_config, need_lock=True) as client:
                assert client == fake_client

            # Release can only be verified after the context has exited.
            fake_lock.acquire.assert_called_once()
            fake_lock.release.assert_called_once()

            manager.kill_servers()
        manager.clear()
637+
638+
639+
@pytest.mark.asyncio
async def test_query_no_n_result(mock_config):
    """Check the `n_results` passed to the collection when `n_result` is None.

    `list_collection_content` is mocked to report 10 chunks, and the test
    expects `collection.query` to be called with `n_results=10`.
    """
    connector = ChromaDB0Connector(mock_config)
    connector._configs.query = ["test query"]
    connector._configs.n_result = None
    connector.get_embedding = MagicMock(return_value=[[1.0, 2.0, 3.0]])

    indexed_file = os.path.join(mock_config.project_root, "file1")
    fake_collection = AsyncMock()
    fake_collection.query.return_value = {
        "documents": [["doc1"]],
        "distances": [[0.1]],
        "metadatas": [[{"path": indexed_file}]],
        "ids": [["id1"]],
    }
    connector._create_or_get_collection = AsyncMock(return_value=fake_collection)
    connector.list_collection_content = AsyncMock(
        return_value=MagicMock(chunks=[1] * 10)
    )

    with patch(
        "vectorcode.database.chroma0._convert_chroma_query_results",
        return_value=["converted_results"],
    ):
        await connector.query()
        # call_args is (args, kwargs); only the keyword arguments matter here.
        call_kwargs = fake_collection.query.call_args[1]
        assert call_kwargs["n_results"] == 10
666+
667+
668+
@pytest.mark.asyncio
async def test_create_or_get_collection_exists(mock_config):
    """Check `_create_or_get_collection(allow_create=True)` on an existing collection.

    The mocked client returns a collection whose metadata uses the same
    username and hostname that the patched `os.environ.get` and
    `socket.gethostname` report; the call should return that collection.
    """
    connector = ChromaDB0Connector(mock_config)
    with (
        patch(
            "vectorcode.database.chroma0._Chroma0ClientManager.get_client"
        ) as get_client,
        patch("os.environ.get", return_value="DEFAULT_USER"),
    ):
        existing = AsyncMock()
        existing.metadata = {
            "path": os.path.abspath(str(mock_config.project_root)),
            "hostname": "test-host",
            "created-by": "VectorCode",
            "username": "DEFAULT_USER",
            "embedding_function": "default",
            "hnsw:M": 64,
        }
        fake_client = AsyncMock()
        fake_client.get_or_create_collection.return_value = existing
        # `get_client(...)` is entered with `async with` in this mock path.
        get_client.return_value.__aenter__.return_value = fake_client

        with patch("socket.gethostname", return_value="test-host"):
            result = await connector._create_or_get_collection(
                "collection_path", allow_create=True
            )
            assert result == existing
            fake_client.get_or_create_collection.assert_called_once()
696+
697+
698+
@pytest.mark.asyncio
async def test_list_collection_content_with_id(mock_config):
    """Check `list_collection_content` when called with an explicit `collection_id`.

    The mocked collection holds a single chunk from a single file; the result
    should contain one file and one chunk, and the client should be asked for
    the collection named by the id.
    """
    connector = ChromaDB0Connector(mock_config)
    stored_rows = {
        "metadatas": [
            {
                "path": os.path.join(mock_config.project_root, "file1"),
                "sha256": "hash1",
            }
        ],
        "documents": ["doc1"],
        "ids": ["id1"],
    }
    with patch(
        "vectorcode.database.chroma0._Chroma0ClientManager.get_client"
    ) as get_client:
        fake_collection = AsyncMock()
        fake_collection.get.return_value = stored_rows
        fake_client = AsyncMock()
        fake_client.get_collection.return_value = fake_collection
        get_client.return_value.__aenter__.return_value = fake_client

        listing = await connector.list_collection_content(collection_id="test_id")
        assert len(listing.files) == 1
        assert len(listing.chunks) == 1
        fake_client.get_collection.assert_called_once_with("test_id")
724+
725+
726+
@pytest.mark.asyncio
async def test_query_with_exclude_and_include_chunk(mock_config):
    """Check the `where` filter built for excluded paths plus chunk inclusion.

    With `query_exclude=["file2"]` and `include=[QueryInclude.chunk]`, the
    collection query is expected to receive a `$and` filter combining a
    `$nin` on `path` with a `$gte 0` on `start`.
    """
    connector = ChromaDB0Connector(mock_config)
    connector._configs.query = ["test query"]
    connector._configs.query_exclude = ["file2"]
    connector._configs.include = [QueryInclude.chunk]
    connector.get_embedding = MagicMock(return_value=[[1.0, 2.0, 3.0]])

    fake_collection = AsyncMock()
    fake_collection.query.return_value = {
        "documents": [["doc1"]],
        "distances": [[0.1]],
        "metadatas": [[{"path": "file1"}]],
        "ids": [["id1"]],
    }
    connector._create_or_get_collection = AsyncMock(return_value=fake_collection)

    expected_where = {
        "$and": [{"path": {"$nin": ["file2"]}}, {"start": {"$gte": 0}}]
    }
    with patch(
        "vectorcode.database.chroma0._convert_chroma_query_results",
        return_value=["converted_results"],
    ):
        await connector.query()
        fake_collection.query.assert_called_once()
        call_kwargs = fake_collection.query.call_args[1]
        assert "where" in call_kwargs
        assert call_kwargs["where"] == expected_where
755+
756+
757+
@pytest.mark.asyncio
async def test_create_or_get_collection_metadata_mismatch(mock_config):
    """Check that mismatching collection metadata raises `AssertionError`.

    The stored collection carries username "DIFFERENT_USER" while the patched
    `os.environ.get` returns "DEFAULT_USER".
    """
    connector = ChromaDB0Connector(mock_config)
    with (
        patch(
            "vectorcode.database.chroma0._Chroma0ClientManager.get_client"
        ) as get_client,
        patch("os.environ.get", return_value="DEFAULT_USER"),
    ):
        foreign = AsyncMock()
        foreign.metadata = {
            "path": os.path.abspath(str(mock_config.project_root)),
            "hostname": "test-host",
            "created-by": "VectorCode",
            "username": "DIFFERENT_USER",
            "embedding_function": "default",
            "hnsw:M": 64,
        }
        fake_client = AsyncMock()
        fake_client.get_or_create_collection.return_value = foreign
        get_client.return_value.__aenter__.return_value = fake_client

        with (
            patch("socket.gethostname", return_value="test-host"),
            pytest.raises(AssertionError),
        ):
            await connector._create_or_get_collection(
                "collection_path", allow_create=True
            )
784+
785+
786+
@pytest.mark.asyncio
async def test_delete_no_matching_files(mock_config):
    """Check that `delete` returns 0 when no requested path is in the collection.

    The collection lists only "file1" while `rm_paths` asks for "file2", so
    `collection.delete` must not be called.
    """
    connector = ChromaDB0Connector(mock_config)
    fake_collection = AsyncMock()
    connector._create_or_get_collection = AsyncMock(return_value=fake_collection)
    indexed = MagicMock(files=[MagicMock(path="file1")])
    connector.list_collection_content = AsyncMock(return_value=indexed)
    mock_config.rm_paths = ["file2"]

    with (
        patch("vectorcode.database.chroma0.expand_globs", return_value=["file2"]),
        # Identity stand-in: hand every path back unchanged.
        patch(
            "vectorcode.database.chroma0.expand_path",
            side_effect=lambda path, absolute: path,
        ),
        patch("os.path.isfile", return_value=True),
    ):
        removed = await connector.delete()
        assert removed == 0
        fake_collection.delete.assert_not_called()

0 commit comments

Comments
 (0)