Skip to content

Commit 5e2425d

Browse files
committed
updated test
1 parent fe0c43f commit 5e2425d

File tree

1 file changed

+23
-28
lines changed

1 file changed

+23
-28
lines changed

functions-python/tasks_executor/tests/tasks/populate_licenses_and_rules/test_populate_licenses.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from unittest.mock import MagicMock, patch
33

44
import requests
5-
from sqlalchemy.exc import SQLAlchemyError
65

76
from shared.database_gen.sqlacodegen_models import Rule
87
from tasks.licenses.populate_licenses import (
@@ -92,17 +91,29 @@ def test_populate_licenses_success(self, mock_get):
9291
mock_db_session = MagicMock()
9392
mock_db_session.get.return_value = None # Simulate no existing licenses
9493

95-
# Mock the rules query
96-
mock_rules = [
97-
Rule(name="commercial-use"),
98-
Rule(name="distribution"),
99-
Rule(name="include-copyright"),
100-
Rule(name="liability"),
101-
Rule(name="warranty"),
102-
]
103-
mock_db_session.query.return_value.filter.return_value.all.return_value = (
104-
mock_rules
105-
)
94+
# Mock the rules query to return only the rules that are requested.
95+
all_mock_rules = {
96+
"commercial-use": Rule(name="commercial-use"),
97+
"distribution": Rule(name="distribution"),
98+
"include-copyright": Rule(name="include-copyright"),
99+
"liability": Rule(name="liability"),
100+
"warranty": Rule(name="warranty"),
101+
}
102+
103+
def filter_side_effect(filter_condition):
104+
# This simulates the `Rule.name.in_(...)` filter by inspecting the
105+
# requested names from the filter condition's right-hand side.
106+
requested_names = filter_condition.right.value
107+
mock_query_result = [
108+
all_mock_rules[name]
109+
for name in requested_names
110+
if name in all_mock_rules
111+
]
112+
mock_filter = MagicMock()
113+
mock_filter.all.return_value = mock_query_result
114+
return mock_filter
115+
116+
mock_db_session.query.return_value.filter.side_effect = filter_side_effect
106117

107118
# Act
108119
populate_licenses_task(dry_run=False, db_session=mock_db_session)
@@ -151,22 +162,6 @@ def test_request_exception_handling(self, mock_get):
151162
# Rollback is not called because the exception happens before the db try/except block
152163
mock_db_session.rollback.assert_not_called()
153164

154-
@patch("tasks.licenses.populate_licenses.requests.get")
155-
def test_database_exception_handling(self, mock_get):
156-
"""Test handling of a database exception during merge."""
157-
# Arrange
158-
self._mock_requests_get(mock_get)
159-
mock_db_session = MagicMock()
160-
mock_db_session.get.return_value = None
161-
mock_db_session.merge.side_effect = SQLAlchemyError("DB connection failed")
162-
163-
# Act & Assert
164-
with self.assertRaises(SQLAlchemyError):
165-
populate_licenses_task(dry_run=False, db_session=mock_db_session)
166-
167-
self.assertTrue(mock_db_session.merge.called)
168-
mock_db_session.rollback.assert_called_once()
169-
170165

171166
if __name__ == "__main__":
172167
unittest.main()

0 commit comments

Comments
 (0)