|
| 1 | +import unittest |
| 2 | +from unittest.mock import MagicMock, patch |
| 3 | + |
| 4 | +import requests |
| 5 | + |
| 6 | +from shared.database_gen.sqlacodegen_models import Rule |
| 7 | +from tasks.licenses.populate_licenses import ( |
| 8 | + LICENSES_API_URL, |
| 9 | + populate_licenses_task, |
| 10 | +) |
| 11 | + |
| 12 | +# Mock data for GitHub API responses |
| 13 | +MOCK_LICENSE_LIST = [ |
| 14 | + { |
| 15 | + "name": "MIT.json", |
| 16 | + "type": "file", |
| 17 | + "download_url": "http://mockurl/MIT.json", |
| 18 | + }, |
| 19 | + { |
| 20 | + "name": "BSD-3-Clause.json", |
| 21 | + "type": "file", |
| 22 | + "download_url": "http://mockurl/BSD-3-Clause.json", |
| 23 | + }, |
| 24 | + { |
| 25 | + "name": "no-spdx.json", |
| 26 | + "type": "file", |
| 27 | + "download_url": "http://mockurl/no-spdx.json", |
| 28 | + }, |
| 29 | + { |
| 30 | + "name": "README.md", |
| 31 | + "type": "file", |
| 32 | + "download_url": "http://mockurl/README.md", |
| 33 | + }, |
| 34 | +] |
| 35 | + |
| 36 | +MOCK_LICENSE_MIT = { |
| 37 | + "spdx": { |
| 38 | + "licenseId": "MIT", |
| 39 | + "name": "MIT License", |
| 40 | + "crossRef": [{"url": "https://opensource.org/licenses/MIT"}], |
| 41 | + "licenseText": "MIT License text...", |
| 42 | + "licenseTextHtml": "<p>MIT License text...</p>", |
| 43 | + }, |
| 44 | + "permissions": ["commercial-use", "distribution"], |
| 45 | + "conditions": ["include-copyright"], |
| 46 | + "limitations": [], |
| 47 | +} |
| 48 | + |
| 49 | +MOCK_LICENSE_BSD = { |
| 50 | + "spdx": { |
| 51 | + "licenseId": "BSD-3-Clause", |
| 52 | + "name": "BSD 3-Clause License", |
| 53 | + "crossRef": [{"url": "https://opensource.org/licenses/BSD-3-Clause"}], |
| 54 | + "licenseText": "BSD license text...", |
| 55 | + "licenseTextHtml": "<p>BSD license text...</p>", |
| 56 | + }, |
| 57 | + "permissions": ["commercial-use"], |
| 58 | + "conditions": [], |
| 59 | + "limitations": ["liability", "warranty"], |
| 60 | +} |
| 61 | + |
| 62 | +MOCK_LICENSE_NO_SPDX = {"licenseId": "NO-SPDX-ID", "name": "No SPDX License"} |
| 63 | + |
| 64 | + |
| 65 | +class TestPopulateLicenses(unittest.TestCase): |
| 66 | + def _mock_requests_get(self, mock_get): |
| 67 | + """Helper to configure mock for requests.get.""" |
| 68 | + mock_responses = { |
| 69 | + LICENSES_API_URL: MagicMock(json=lambda: MOCK_LICENSE_LIST), |
| 70 | + "http://mockurl/MIT.json": MagicMock(json=lambda: MOCK_LICENSE_MIT), |
| 71 | + "http://mockurl/BSD-3-Clause.json": MagicMock( |
| 72 | + json=lambda: MOCK_LICENSE_BSD |
| 73 | + ), |
| 74 | + "http://mockurl/no-spdx.json": MagicMock(json=lambda: MOCK_LICENSE_NO_SPDX), |
| 75 | + } |
| 76 | + |
| 77 | + def get_side_effect(url, timeout=None): |
| 78 | + if url in mock_responses: |
| 79 | + response = mock_responses[url] |
| 80 | + response.raise_for_status.return_value = None |
| 81 | + return response |
| 82 | + raise requests.exceptions.RequestException(f"URL not mocked: {url}") |
| 83 | + |
| 84 | + mock_get.side_effect = get_side_effect |
| 85 | + |
| 86 | + @patch("tasks.licenses.populate_licenses.requests.get") |
| 87 | + def test_populate_licenses_success(self, mock_get): |
| 88 | + """Test successful population of licenses.""" |
| 89 | + # Arrange |
| 90 | + self._mock_requests_get(mock_get) |
| 91 | + mock_db_session = MagicMock() |
| 92 | + mock_db_session.get.return_value = None # Simulate no existing licenses |
| 93 | + |
| 94 | + # Mock the rules query to return only the rules that are requested. |
| 95 | + all_mock_rules = { |
| 96 | + "commercial-use": Rule(name="commercial-use"), |
| 97 | + "distribution": Rule(name="distribution"), |
| 98 | + "include-copyright": Rule(name="include-copyright"), |
| 99 | + "liability": Rule(name="liability"), |
| 100 | + "warranty": Rule(name="warranty"), |
| 101 | + } |
| 102 | + |
| 103 | + def filter_side_effect(filter_condition): |
| 104 | + # This simulates the `Rule.name.in_(...)` filter by inspecting the |
| 105 | + # requested names from the filter condition's right-hand side. |
| 106 | + requested_names = filter_condition.right.value |
| 107 | + mock_query_result = [ |
| 108 | + all_mock_rules[name] |
| 109 | + for name in requested_names |
| 110 | + if name in all_mock_rules |
| 111 | + ] |
| 112 | + mock_filter = MagicMock() |
| 113 | + mock_filter.all.return_value = mock_query_result |
| 114 | + return mock_filter |
| 115 | + |
| 116 | + mock_db_session.query.return_value.filter.side_effect = filter_side_effect |
| 117 | + |
| 118 | + # Act |
| 119 | + populate_licenses_task(dry_run=False, db_session=mock_db_session) |
| 120 | + |
| 121 | + # Assert |
| 122 | + self.assertEqual(mock_db_session.merge.call_count, 2) |
| 123 | + mock_db_session.rollback.assert_not_called() |
| 124 | + |
| 125 | + # Check that merge was called with correctly constructed License objects |
| 126 | + call_args_list = mock_db_session.merge.call_args_list |
| 127 | + merged_licenses = [arg.args[0] for arg in call_args_list] |
| 128 | + |
| 129 | + mit_license = next((lic for lic in merged_licenses if lic.id == "MIT"), None) |
| 130 | + self.assertIsNotNone(mit_license) |
| 131 | + self.assertEqual(mit_license.name, "MIT License") |
| 132 | + self.assertTrue(mit_license.is_spdx) |
| 133 | + self.assertEqual(len(mit_license.rules), 3) |
| 134 | + |
| 135 | + @patch("tasks.licenses.populate_licenses.requests.get") |
| 136 | + def test_populate_licenses_dry_run(self, mock_get): |
| 137 | + """Test that no database changes are made during a dry run.""" |
| 138 | + # Arrange |
| 139 | + self._mock_requests_get(mock_get) |
| 140 | + mock_db_session = MagicMock() |
| 141 | + |
| 142 | + # Act |
| 143 | + populate_licenses_task(dry_run=True, db_session=mock_db_session) |
| 144 | + |
| 145 | + # Assert |
| 146 | + mock_db_session.get.assert_not_called() |
| 147 | + mock_db_session.merge.assert_not_called() |
| 148 | + mock_db_session.rollback.assert_not_called() |
| 149 | + |
| 150 | + @patch("tasks.licenses.populate_licenses.requests.get") |
| 151 | + def test_request_exception_handling(self, mock_get): |
| 152 | + """Test handling of a requests exception.""" |
| 153 | + # Arrange |
| 154 | + mock_get.side_effect = requests.exceptions.RequestException("Network Error") |
| 155 | + mock_db_session = MagicMock() |
| 156 | + |
| 157 | + # Act & Assert |
| 158 | + with self.assertRaises(requests.exceptions.RequestException): |
| 159 | + populate_licenses_task(dry_run=False, db_session=mock_db_session) |
| 160 | + |
| 161 | + mock_db_session.merge.assert_not_called() |
| 162 | + # Rollback is not called because the exception happens before the db try/except block |
| 163 | + mock_db_session.rollback.assert_not_called() |
| 164 | + |
| 165 | + |
| 166 | +if __name__ == "__main__": |
| 167 | + unittest.main() |
0 commit comments