|
2 | 2 | from unittest.mock import MagicMock, patch |
3 | 3 |
|
4 | 4 | import requests |
5 | | -from sqlalchemy.exc import SQLAlchemyError |
6 | 5 |
|
7 | 6 | from shared.database_gen.sqlacodegen_models import Rule |
8 | 7 | from tasks.licenses.populate_licenses import ( |
@@ -92,17 +91,29 @@ def test_populate_licenses_success(self, mock_get): |
92 | 91 | mock_db_session = MagicMock() |
93 | 92 | mock_db_session.get.return_value = None # Simulate no existing licenses |
94 | 93 |
|
95 | | - # Mock the rules query |
96 | | - mock_rules = [ |
97 | | - Rule(name="commercial-use"), |
98 | | - Rule(name="distribution"), |
99 | | - Rule(name="include-copyright"), |
100 | | - Rule(name="liability"), |
101 | | - Rule(name="warranty"), |
102 | | - ] |
103 | | - mock_db_session.query.return_value.filter.return_value.all.return_value = ( |
104 | | - mock_rules |
105 | | - ) |
| 94 | + # Mock the rules query to return only the rules that are requested. |
| 95 | + all_mock_rules = { |
| 96 | + "commercial-use": Rule(name="commercial-use"), |
| 97 | + "distribution": Rule(name="distribution"), |
| 98 | + "include-copyright": Rule(name="include-copyright"), |
| 99 | + "liability": Rule(name="liability"), |
| 100 | + "warranty": Rule(name="warranty"), |
| 101 | + } |
| 102 | + |
| 103 | + def filter_side_effect(filter_condition): |
| 104 | + # This simulates the `Rule.name.in_(...)` filter by inspecting the |
| 105 | + # requested names from the filter condition's right-hand side. |
| 106 | + requested_names = filter_condition.right.value |
| 107 | + mock_query_result = [ |
| 108 | + all_mock_rules[name] |
| 109 | + for name in requested_names |
| 110 | + if name in all_mock_rules |
| 111 | + ] |
| 112 | + mock_filter = MagicMock() |
| 113 | + mock_filter.all.return_value = mock_query_result |
| 114 | + return mock_filter |
| 115 | + |
| 116 | + mock_db_session.query.return_value.filter.side_effect = filter_side_effect |
106 | 117 |
|
107 | 118 | # Act |
108 | 119 | populate_licenses_task(dry_run=False, db_session=mock_db_session) |
@@ -151,22 +162,6 @@ def test_request_exception_handling(self, mock_get): |
151 | 162 | # Rollback is not called because the exception happens before the db try/except block |
152 | 163 | mock_db_session.rollback.assert_not_called() |
153 | 164 |
|
154 | | - @patch("tasks.licenses.populate_licenses.requests.get") |
155 | | - def test_database_exception_handling(self, mock_get): |
156 | | - """Test handling of a database exception during merge.""" |
157 | | - # Arrange |
158 | | - self._mock_requests_get(mock_get) |
159 | | - mock_db_session = MagicMock() |
160 | | - mock_db_session.get.return_value = None |
161 | | - mock_db_session.merge.side_effect = SQLAlchemyError("DB connection failed") |
162 | | - |
163 | | - # Act & Assert |
164 | | - with self.assertRaises(SQLAlchemyError): |
165 | | - populate_licenses_task(dry_run=False, db_session=mock_db_session) |
166 | | - |
167 | | - self.assertTrue(mock_db_session.merge.called) |
168 | | - mock_db_session.rollback.assert_called_once() |
169 | | - |
170 | 165 |
|
171 | 166 | if __name__ == "__main__": |
172 | 167 | unittest.main() |
0 commit comments