|
| 1 | +import asyncio |
1 | 2 | from unittest.mock import AsyncMock, patch |
2 | 3 |
|
3 | 4 | import pytest |
4 | | -from chromadb.api.types import IncludeEnum |
5 | | -from chromadb.errors import InvalidCollectionException |
6 | 5 |
|
7 | 6 | from vectorcode.cli_utils import Config |
| 7 | +from vectorcode.database.types import FileInCollection |
8 | 8 | from vectorcode.subcommands.update import update |
9 | 9 |
|
10 | 10 |
|
@pytest.mark.asyncio
async def test_update_success(tmp_path):
    """Check that `update` re-vectorises only the file whose hash changed.

    Three files are written to ``tmp_path`` and registered in the mocked
    collection with stored hashes.  `hash_file` is patched so that only
    ``file2.py`` reports a hash different from its stored one.  The test then
    expects a return code of 0, exactly one `vectorise_worker` call (for
    ``file2.py``), one orphan check and one `show_stats` call.

    Args:
        tmp_path: pytest's per-test temporary directory, used as the
            project root.
    """
    config = Config(project_root=str(tmp_path), pipe=False)

    # Mock files in the database
    file1_path = tmp_path / "file1.py"
    file1_path.write_text("content1")
    file2_path = tmp_path / "file2.py"
    file2_path.write_text("new content2")  # modified
    file3_path = tmp_path / "file3.py"
    file3_path.write_text("content3")

    collection_files = [
        FileInCollection(path=str(file1_path), sha256="hash1_old"),
        FileInCollection(path=str(file2_path), sha256="hash2_old"),
        FileInCollection(path=str(file3_path), sha256="hash3_old"),
    ]

    # file1.py is unchanged, file2.py is changed, file3.py is unchanged.
    # The hashes are keyed by path rather than supplied as an ordered list of
    # return values, so the test does not depend on the order in which
    # `update` happens to hash the files.
    current_hashes = {
        str(file1_path): "hash1_old",
        str(file2_path): "hash2_new",
        str(file3_path): "hash3_old",
    }

    def fake_hash_file(path, *args, **kwargs):
        """Return the simulated on-disk hash for ``path``."""
        return current_hashes[str(path)]

    with (
        patch("vectorcode.subcommands.update.get_database_connector") as mock_get_db,
        patch(
            "vectorcode.subcommands.update.vectorise_worker", new_callable=AsyncMock
        ) as mock_vectorise_worker,
        patch("vectorcode.subcommands.update.show_stats") as mock_show_stats,
        patch("vectorcode.subcommands.update.hash_file") as mock_hash_file,
    ):
        mock_db = AsyncMock()
        mock_db.list_collection_content.return_value.files = collection_files
        mock_get_db.return_value = mock_db

        mock_hash_file.side_effect = fake_hash_file

        result = await update(config)

        assert result == 0
        mock_db.list_collection_content.assert_called_once()

        # vectorise_worker should only be called for the modified file (file2.py)
        assert mock_vectorise_worker.call_count == 1
        # The file path is the second positional argument of the worker call.
        called_with_file = mock_vectorise_worker.call_args_list[0][0][1]
        assert called_with_file == str(file2_path)

        mock_db.check_orphanes.assert_called_once()
        mock_show_stats.assert_called_once()
40 | 58 |
|
41 | 59 |
|
@pytest.mark.asyncio
async def test_update_force(tmp_path):
    """Check that ``force=True`` re-vectorises every tracked file.

    Two files are registered in the mocked collection.  With ``force=True``
    the test expects one `vectorise_worker` call per file, no call to
    `hash_file`, one orphan check, one `show_stats` call and a return code
    of 0.

    Args:
        tmp_path: pytest's per-test temporary directory, used as the
            project root.
    """
    config = Config(project_root=str(tmp_path), pipe=False, force=True)

    # Create each file on disk and register it in the fake collection.
    tracked_files = []
    for file_name, content, digest in (
        ("file1.py", "content1", "hash1"),
        ("file2.py", "content2", "hash2"),
    ):
        target = tmp_path / file_name
        target.write_text(content)
        tracked_files.append(FileInCollection(path=str(target), sha256=digest))

    fake_db = AsyncMock()
    fake_db.list_collection_content.return_value.files = tracked_files

    with (
        patch(
            "vectorcode.subcommands.update.get_database_connector",
            return_value=fake_db,
        ),
        patch(
            "vectorcode.subcommands.update.vectorise_worker", new_callable=AsyncMock
        ) as worker,
        patch("vectorcode.subcommands.update.show_stats") as stats,
        patch("vectorcode.subcommands.update.hash_file") as hasher,
    ):
        exit_code = await update(config)

    assert exit_code == 0
    fake_db.list_collection_content.assert_called_once()

    # Every tracked file goes through the worker.
    assert worker.call_count == 2
    # hash_file should not be called with force=True
    hasher.assert_not_called()

    fake_db.check_orphanes.assert_called_once()
    stats.assert_called_once()
@pytest.mark.asyncio
async def test_update_cancelled(tmp_path):
    """Check that a cancelled worker makes `update` return 1.

    One tracked file is reported as changed (`hash_file` returns a hash
    different from the stored one) and `vectorise_worker` raises
    `asyncio.CancelledError`.  The test expects a return code of 1 and no
    orphan check.

    Args:
        tmp_path: pytest's per-test temporary directory, used as the
            project root.
    """
    config = Config(project_root=str(tmp_path), pipe=False)

    source_file = tmp_path / "file1.py"
    source_file.write_text("content1")

    fake_db = AsyncMock()
    fake_db.list_collection_content.return_value.files = [
        FileInCollection(path=str(source_file), sha256="hash1_old"),
    ]

    with (
        patch(
            "vectorcode.subcommands.update.get_database_connector",
            return_value=fake_db,
        ),
        # Extra keyword arguments to patch() are forwarded to AsyncMock, so
        # the worker raises CancelledError as soon as it is awaited.
        patch(
            "vectorcode.subcommands.update.vectorise_worker",
            new_callable=AsyncMock,
            side_effect=asyncio.CancelledError,
        ),
        patch("vectorcode.subcommands.update.hash_file", return_value="hash1_new"),
    ):
        exit_code = await update(config)

    assert exit_code == 1
    fake_db.check_orphanes.assert_not_called()
107 | 129 |
|
108 | 130 |
|
@pytest.mark.asyncio
async def test_update_empty_collection(tmp_path):
    """Check that `update` succeeds when the collection has no files.

    The mocked collection reports an empty file list.  The test expects a
    return code of 0, no `vectorise_worker` call, one orphan check and one
    `show_stats` call.

    Args:
        tmp_path: pytest's per-test temporary directory, used as the
            project root.
    """
    config = Config(project_root=str(tmp_path), pipe=False)

    fake_db = AsyncMock()
    fake_db.list_collection_content.return_value.files = []

    with (
        patch(
            "vectorcode.subcommands.update.get_database_connector",
            return_value=fake_db,
        ),
        patch(
            "vectorcode.subcommands.update.vectorise_worker", new_callable=AsyncMock
        ) as worker,
        patch("vectorcode.subcommands.update.show_stats") as stats,
    ):
        exit_code = await update(config)

    assert exit_code == 0
    worker.assert_not_called()
    fake_db.check_orphanes.assert_called_once()
    stats.assert_called_once()
0 commit comments