22from unittest .mock import MagicMock , patch
33
44import requests
5- from sqlalchemy .dialects .postgresql import JSONB
6- from sqlalchemy .ext .compiler import compiles
7- from sqlalchemy .sql .functions import now
8- from shared .database .database import Session
9- from shared .database_gen .sqlacodegen_models import License , Rule
10- from tasks .licenses .populate_licenses import populate_licenses_task
11-
12-
13- # This compilation rule is necessary to make the JSONB type, which is PostgreSQL-specific,
14- # compatible with the in-memory SQLite database used for testing. It tells SQLAlchemy
15- # to treat JSONB as TEXT when running against a SQLite backend.
16- @compiles (JSONB , "sqlite" )
17- def compile_jsonb_for_sqlite (element , compiler , ** kw ):
18- return "TEXT"
19-
20-
21- # This compilation rule translates the PostgreSQL-specific `now()` function into
22- # the SQLite-compatible `CURRENT_TIMESTAMP` function during test schema creation.
23- @compiles (now , "sqlite" )
24- def compile_now_for_sqlite (element , compiler , ** kw ):
25- return "CURRENT_TIMESTAMP"
5+ from sqlalchemy .exc import SQLAlchemyError
266
7+ from shared .database_gen .sqlacodegen_models import Rule
8+ from tasks .licenses .populate_licenses import (
9+ LICENSES_API_URL ,
10+ populate_licenses_task ,
11+ )
2712
2813# Mock data for GitHub API responses
2914MOCK_LICENSE_LIST = [
@@ -79,34 +64,10 @@ def compile_now_for_sqlite(element, compiler, **kw):
7964
8065
8166class TestPopulateLicenses (unittest .TestCase ):
82- def setUp (self ):
83- # Create an in-memory SQLite database for testing
84- self .session = Session (bind = self .engine )
85-
86- def tearDown (self ):
87- self .session .close ()
88-
89- @classmethod
90- def setUpClass (cls ):
91- from sqlalchemy import create_engine
92-
93- from shared .database_gen .sqlacodegen_models import Base
94-
95- cls .engine = create_engine ("sqlite:///:memory:" )
96- Base .metadata .create_all (cls .engine )
97-
98- @classmethod
99- def tearDownClass (cls ):
100- from shared .database_gen .sqlacodegen_models import Base
101-
102- Base .metadata .drop_all (cls .engine )
103-
10467 def _mock_requests_get (self , mock_get ):
10568 """Helper to configure mock for requests.get."""
10669 mock_responses = {
107- "https://api.github.com/repos/MobilityData/licenses-aas/contents/data/licenses" : MagicMock (
108- json = lambda : MOCK_LICENSE_LIST
109- ),
70+ LICENSES_API_URL : MagicMock (json = lambda : MOCK_LICENSE_LIST ),
11071 "http://mockurl/MIT.json" : MagicMock (json = lambda : MOCK_LICENSE_MIT ),
11172 "http://mockurl/BSD-3-Clause.json" : MagicMock (
11273 json = lambda : MOCK_LICENSE_BSD
@@ -125,64 +86,86 @@ def get_side_effect(url, timeout=None):
12586
12687 @patch ("tasks.licenses.populate_licenses.requests.get" )
12788 def test_populate_licenses_success (self , mock_get ):
89+ """Test successful population of licenses."""
90+ # Arrange
12891 self ._mock_requests_get (mock_get )
129-
130- # Pre-populate rules
131- rules_to_add = [
132- Rule (id = "commercial-use" , name = "commercial-use" , type = "permission" ),
133- Rule (id = "distribution" , name = "distribution" , type = "permission" ),
134- Rule (id = "include-copyright" , name = "include-copyright" , type = "condition" ),
135- Rule (id = "liability" , name = "liability" , type = "limitation" ),
136- Rule (id = "warranty" , name = "warranty" , type = "limitation" ),
92+ mock_db_session = MagicMock ()
93+ mock_db_session .get .return_value = None # Simulate no existing licenses
94+
95+ # Mock the rules query
96+ mock_rules = [
97+ Rule (name = "commercial-use" ),
98+ Rule (name = "distribution" ),
99+ Rule (name = "include-copyright" ),
100+ Rule (name = "liability" ),
101+ Rule (name = "warranty" ),
137102 ]
138- self .session .add_all (rules_to_add )
139- self .session .commit ()
103+ mock_db_session .query .return_value .filter .return_value .all .return_value = (
104+ mock_rules
105+ )
106+
107+ # Act
108+ populate_licenses_task (dry_run = False , db_session = mock_db_session )
140109
141- populate_licenses_task (dry_run = False , db_session = self .session )
110+ # Assert
111+ self .assertEqual (mock_db_session .merge .call_count , 2 )
112+ mock_db_session .rollback .assert_not_called ()
142113
143- licenses = self .session .query (License ).order_by (License .id ).all ()
144- self .assertEqual (len (licenses ), 2 )
114+ # Check that merge was called with correctly constructed License objects
115+ call_args_list = mock_db_session .merge .call_args_list
116+ merged_licenses = [arg .args [0 ] for arg in call_args_list ]
145117
146- # Check MIT License
147- mit_license = licenses [1 ]
148- self .assertEqual (mit_license .id , "MIT" )
118+ mit_license = next ((lic for lic in merged_licenses if lic .id == "MIT" ), None )
119+ self .assertIsNotNone (mit_license )
149120 self .assertEqual (mit_license .name , "MIT License" )
150121 self .assertTrue (mit_license .is_spdx )
151122 self .assertEqual (len (mit_license .rules ), 3 )
152- rule_names = sorted ([rule .name for rule in mit_license .rules ])
153- self .assertEqual (
154- rule_names , ["commercial-use" , "distribution" , "include-copyright" ]
155- )
156123
157- # Check BSD License
158- bsd_license = licenses [0 ]
159- self .assertEqual (bsd_license .id , "BSD-3-Clause" )
160- self .assertEqual (bsd_license .name , "BSD 3-Clause License" )
161- self .assertEqual (len (bsd_license .rules ), 3 )
162- rule_names = sorted ([rule .name for rule in bsd_license .rules ])
163- self .assertEqual (rule_names , ["commercial-use" , "liability" , "warranty" ])
124+ @patch ("tasks.licenses.populate_licenses.requests.get" )
125+ def test_populate_licenses_dry_run (self , mock_get ):
126+ """Test that no database changes are made during a dry run."""
127+ # Arrange
128+ self ._mock_requests_get (mock_get )
129+ mock_db_session = MagicMock ()
130+
131+ # Act
132+ populate_licenses_task (dry_run = True , db_session = mock_db_session )
133+
134+ # Assert
135+ mock_db_session .get .assert_not_called ()
136+ mock_db_session .merge .assert_not_called ()
137+ mock_db_session .rollback .assert_not_called ()
164138
165139 @patch ("tasks.licenses.populate_licenses.requests.get" )
166- def test_update_existing_license (self , mock_get ):
140+ def test_request_exception_handling (self , mock_get ):
141+ """Test handling of a requests exception."""
142+ # Arrange
143+ mock_get .side_effect = requests .exceptions .RequestException ("Network Error" )
144+ mock_db_session = MagicMock ()
145+
146+ # Act & Assert
147+ with self .assertRaises (requests .exceptions .RequestException ):
148+ populate_licenses_task (dry_run = False , db_session = mock_db_session )
149+
150+ mock_db_session .merge .assert_not_called ()
151+ # Rollback is not called because the exception happens before the db try/except block
152+ mock_db_session .rollback .assert_not_called ()
153+
154+ @patch ("tasks.licenses.populate_licenses.requests.get" )
155+ def test_database_exception_handling (self , mock_get ):
156+ """Test handling of a database exception during merge."""
157+ # Arrange
167158 self ._mock_requests_get (mock_get )
159+ mock_db_session = MagicMock ()
160+ mock_db_session .get .return_value = None
161+ mock_db_session .merge .side_effect = SQLAlchemyError ("DB connection failed" )
168162
169- # Pre-populate license and a rule
170- existing_license = License (
171- id = "MIT" , name = "Old MIT Name" , url = "http://oldurl.com"
172- )
173- existing_rule = Rule (id = "private-use" , name = "private-use" , type = "permission" )
174- existing_license .rules .append (existing_rule )
175- self .session .add (existing_license )
176- self .session .commit ()
177-
178- # Run the task to update
179- populate_licenses_task (dry_run = False , db_session = self .session )
180-
181- updated_license = self .session .query (License ).filter_by (id = "MIT" ).one ()
182- self .assertEqual (updated_license .name , "MIT License" )
183- self .assertEqual (updated_license .url , "https://opensource.org/licenses/MIT" )
184- # Check that rules are updated, not appended
185- self .assertNotEqual (len (updated_license .rules ), 4 )
163+ # Act & Assert
164+ with self .assertRaises (SQLAlchemyError ):
165+ populate_licenses_task (dry_run = False , db_session = mock_db_session )
166+
167+ self .assertTrue (mock_db_session .merge .called )
168+ mock_db_session .rollback .assert_called_once ()
186169
187170
188171if __name__ == "__main__" :
0 commit comments