|
| 1 | +import sys |
| 2 | +from unittest.mock import MagicMock, patch |
| 3 | + |
1 | 4 | import pytest |
2 | 5 |
|
3 | 6 | from vectorcode.cli_utils import Config |
4 | 7 | from vectorcode.database import get_database_connector |
5 | | -from vectorcode.database.chroma0 import ChromaDB0Connector |
6 | 8 |
|
7 | 9 |
|
8 | | -def test_get_database_connector(): |
9 | | - assert isinstance( |
10 | | - get_database_connector(Config(db_type="ChromaDB0")), ChromaDB0Connector |
11 | | - ) |
| 10 | +@pytest.mark.parametrize( |
| 11 | + "db_type, module_to_mock, class_name", |
| 12 | + [ |
| 13 | + ("ChromaDB0", "vectorcode.database.chroma0", "ChromaDB0Connector"), |
| 14 | + # To test a new connector, add a tuple here following the same pattern. |
| 15 | + # e.g. ("NewDB", "vectorcode.database.newdb", "NewDBConnector"), |
| 16 | + ], |
| 17 | +) |
| 18 | +def test_get_database_connector(db_type, module_to_mock, class_name): |
| 19 | + """ |
| 20 | + Tests that get_database_connector can correctly return a connector |
| 21 | + for a given db_type. This test is parameterized to be easily |
| 22 | + extensible for new database connectors. |
| 23 | + """ |
| 24 | + mock_connector_class = MagicMock() |
| 25 | + mock_module = MagicMock() |
| 26 | + setattr(mock_module, class_name, mock_connector_class) |
| 27 | + |
| 28 | + # Use patch.dict to temporarily replace the module in sys.modules. |
| 29 | + # This prevents the actual module from being imported, avoiding |
| 30 | + # errors if its dependencies are not installed. |
| 31 | + with patch.dict(sys.modules, {module_to_mock: mock_module}): |
| 32 | + config = Config(db_type=db_type) |
| 33 | + connector = get_database_connector(config) |
| 34 | + |
| 35 | + # Verify that the create method was called on our mock class |
| 36 | + mock_connector_class.create.assert_called_once_with(config) |
| 37 | + |
| 38 | + # Verify that the returned connector is the one from our mock |
| 39 | + assert connector == mock_connector_class.create.return_value |
12 | 40 |
|
13 | 41 |
|
14 | 42 | def test_get_database_connector_invalid_type(): |
|
0 commit comments