Skip to content

Commit fe0c43f

Browse files
committed
redesigned test
1 parent 6dd63c7 commit fe0c43f

File tree

1 file changed

+74
-91
lines changed

1 file changed

+74
-91
lines changed

functions-python/tasks_executor/tests/tasks/populate_licenses_and_rules/test_populate_licenses.py

Lines changed: 74 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,13 @@
22
from unittest.mock import MagicMock, patch
33

44
import requests
5-
from sqlalchemy.dialects.postgresql import JSONB
6-
from sqlalchemy.ext.compiler import compiles
7-
from sqlalchemy.sql.functions import now
8-
from shared.database.database import Session
9-
from shared.database_gen.sqlacodegen_models import License, Rule
10-
from tasks.licenses.populate_licenses import populate_licenses_task
11-
12-
13-
# This compilation rule is necessary to make the JSONB type, which is PostgreSQL-specific,
14-
# compatible with the in-memory SQLite database used for testing. It tells SQLAlchemy
15-
# to treat JSONB as TEXT when running against a SQLite backend.
16-
@compiles(JSONB, "sqlite")
17-
def compile_jsonb_for_sqlite(element, compiler, **kw):
18-
return "TEXT"
19-
20-
21-
# This compilation rule translates the PostgreSQL-specific `now()` function into
22-
# the SQLite-compatible `CURRENT_TIMESTAMP` function during test schema creation.
23-
@compiles(now, "sqlite")
24-
def compile_now_for_sqlite(element, compiler, **kw):
25-
return "CURRENT_TIMESTAMP"
5+
from sqlalchemy.exc import SQLAlchemyError
266

7+
from shared.database_gen.sqlacodegen_models import Rule
8+
from tasks.licenses.populate_licenses import (
9+
LICENSES_API_URL,
10+
populate_licenses_task,
11+
)
2712

2813
# Mock data for GitHub API responses
2914
MOCK_LICENSE_LIST = [
@@ -79,34 +64,10 @@ def compile_now_for_sqlite(element, compiler, **kw):
7964

8065

8166
class TestPopulateLicenses(unittest.TestCase):
82-
def setUp(self):
83-
# Create an in-memory SQLite database for testing
84-
self.session = Session(bind=self.engine)
85-
86-
def tearDown(self):
87-
self.session.close()
88-
89-
@classmethod
90-
def setUpClass(cls):
91-
from sqlalchemy import create_engine
92-
93-
from shared.database_gen.sqlacodegen_models import Base
94-
95-
cls.engine = create_engine("sqlite:///:memory:")
96-
Base.metadata.create_all(cls.engine)
97-
98-
@classmethod
99-
def tearDownClass(cls):
100-
from shared.database_gen.sqlacodegen_models import Base
101-
102-
Base.metadata.drop_all(cls.engine)
103-
10467
def _mock_requests_get(self, mock_get):
10568
"""Helper to configure mock for requests.get."""
10669
mock_responses = {
107-
"https://api.github.com/repos/MobilityData/licenses-aas/contents/data/licenses": MagicMock(
108-
json=lambda: MOCK_LICENSE_LIST
109-
),
70+
LICENSES_API_URL: MagicMock(json=lambda: MOCK_LICENSE_LIST),
11071
"http://mockurl/MIT.json": MagicMock(json=lambda: MOCK_LICENSE_MIT),
11172
"http://mockurl/BSD-3-Clause.json": MagicMock(
11273
json=lambda: MOCK_LICENSE_BSD
@@ -125,64 +86,86 @@ def get_side_effect(url, timeout=None):
12586

12687
@patch("tasks.licenses.populate_licenses.requests.get")
12788
def test_populate_licenses_success(self, mock_get):
89+
"""Test successful population of licenses."""
90+
# Arrange
12891
self._mock_requests_get(mock_get)
129-
130-
# Pre-populate rules
131-
rules_to_add = [
132-
Rule(id="commercial-use", name="commercial-use", type="permission"),
133-
Rule(id="distribution", name="distribution", type="permission"),
134-
Rule(id="include-copyright", name="include-copyright", type="condition"),
135-
Rule(id="liability", name="liability", type="limitation"),
136-
Rule(id="warranty", name="warranty", type="limitation"),
92+
mock_db_session = MagicMock()
93+
mock_db_session.get.return_value = None # Simulate no existing licenses
94+
95+
# Mock the rules query
96+
mock_rules = [
97+
Rule(name="commercial-use"),
98+
Rule(name="distribution"),
99+
Rule(name="include-copyright"),
100+
Rule(name="liability"),
101+
Rule(name="warranty"),
137102
]
138-
self.session.add_all(rules_to_add)
139-
self.session.commit()
103+
mock_db_session.query.return_value.filter.return_value.all.return_value = (
104+
mock_rules
105+
)
106+
107+
# Act
108+
populate_licenses_task(dry_run=False, db_session=mock_db_session)
140109

141-
populate_licenses_task(dry_run=False, db_session=self.session)
110+
# Assert
111+
self.assertEqual(mock_db_session.merge.call_count, 2)
112+
mock_db_session.rollback.assert_not_called()
142113

143-
licenses = self.session.query(License).order_by(License.id).all()
144-
self.assertEqual(len(licenses), 2)
114+
# Check that merge was called with correctly constructed License objects
115+
call_args_list = mock_db_session.merge.call_args_list
116+
merged_licenses = [arg.args[0] for arg in call_args_list]
145117

146-
# Check MIT License
147-
mit_license = licenses[1]
148-
self.assertEqual(mit_license.id, "MIT")
118+
mit_license = next((lic for lic in merged_licenses if lic.id == "MIT"), None)
119+
self.assertIsNotNone(mit_license)
149120
self.assertEqual(mit_license.name, "MIT License")
150121
self.assertTrue(mit_license.is_spdx)
151122
self.assertEqual(len(mit_license.rules), 3)
152-
rule_names = sorted([rule.name for rule in mit_license.rules])
153-
self.assertEqual(
154-
rule_names, ["commercial-use", "distribution", "include-copyright"]
155-
)
156123

157-
# Check BSD License
158-
bsd_license = licenses[0]
159-
self.assertEqual(bsd_license.id, "BSD-3-Clause")
160-
self.assertEqual(bsd_license.name, "BSD 3-Clause License")
161-
self.assertEqual(len(bsd_license.rules), 3)
162-
rule_names = sorted([rule.name for rule in bsd_license.rules])
163-
self.assertEqual(rule_names, ["commercial-use", "liability", "warranty"])
124+
@patch("tasks.licenses.populate_licenses.requests.get")
125+
def test_populate_licenses_dry_run(self, mock_get):
126+
"""Test that no database changes are made during a dry run."""
127+
# Arrange
128+
self._mock_requests_get(mock_get)
129+
mock_db_session = MagicMock()
130+
131+
# Act
132+
populate_licenses_task(dry_run=True, db_session=mock_db_session)
133+
134+
# Assert
135+
mock_db_session.get.assert_not_called()
136+
mock_db_session.merge.assert_not_called()
137+
mock_db_session.rollback.assert_not_called()
164138

165139
@patch("tasks.licenses.populate_licenses.requests.get")
166-
def test_update_existing_license(self, mock_get):
140+
def test_request_exception_handling(self, mock_get):
141+
"""Test handling of a requests exception."""
142+
# Arrange
143+
mock_get.side_effect = requests.exceptions.RequestException("Network Error")
144+
mock_db_session = MagicMock()
145+
146+
# Act & Assert
147+
with self.assertRaises(requests.exceptions.RequestException):
148+
populate_licenses_task(dry_run=False, db_session=mock_db_session)
149+
150+
mock_db_session.merge.assert_not_called()
151+
# Rollback is not called because the exception happens before the db try/except block
152+
mock_db_session.rollback.assert_not_called()
153+
154+
@patch("tasks.licenses.populate_licenses.requests.get")
155+
def test_database_exception_handling(self, mock_get):
156+
"""Test handling of a database exception during merge."""
157+
# Arrange
167158
self._mock_requests_get(mock_get)
159+
mock_db_session = MagicMock()
160+
mock_db_session.get.return_value = None
161+
mock_db_session.merge.side_effect = SQLAlchemyError("DB connection failed")
168162

169-
# Pre-populate license and a rule
170-
existing_license = License(
171-
id="MIT", name="Old MIT Name", url="http://oldurl.com"
172-
)
173-
existing_rule = Rule(id="private-use", name="private-use", type="permission")
174-
existing_license.rules.append(existing_rule)
175-
self.session.add(existing_license)
176-
self.session.commit()
177-
178-
# Run the task to update
179-
populate_licenses_task(dry_run=False, db_session=self.session)
180-
181-
updated_license = self.session.query(License).filter_by(id="MIT").one()
182-
self.assertEqual(updated_license.name, "MIT License")
183-
self.assertEqual(updated_license.url, "https://opensource.org/licenses/MIT")
184-
# Check that rules are updated, not appended
185-
self.assertNotEqual(len(updated_license.rules), 4)
163+
# Act & Assert
164+
with self.assertRaises(SQLAlchemyError):
165+
populate_licenses_task(dry_run=False, db_session=mock_db_session)
166+
167+
self.assertTrue(mock_db_session.merge.called)
168+
mock_db_session.rollback.assert_called_once()
186169

187170

188171
if __name__ == "__main__":

0 commit comments

Comments
 (0)