Skip to content

Commit 05a5a72

Browse files
committed
added test
1 parent af6317c commit 05a5a72

File tree

4 files changed

+338
-0
lines changed

4 files changed

+338
-0
lines changed

functions-python/tasks_executor/README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,14 @@ To populate license rules:
7878
}
7979
}
8080
```
81+
82+
To populate licenses:
83+
84+
```json
85+
{
86+
"task": "populate_licenses",
87+
"payload": {
88+
"dry_run": true
89+
}
90+
}
91+
```
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import logging
2+
3+
import requests
4+
from shared.database.database import with_db_session
5+
from shared.database_gen.sqlacodegen_models import License, Rule
6+
7+
LICENSES_API_URL = (
8+
"https://api.github.com/repos/MobilityData/licenses-aas/contents/data/licenses"
9+
)
10+
11+
12+
def populate_licenses_handler(payload):
13+
"""
14+
Handler for populating licenses.
15+
16+
Args:
17+
payload (dict): Incoming payload data.
18+
"""
19+
dry_run = get_parameters(payload)
20+
return populate_licenses_task(dry_run)
21+
22+
23+
@with_db_session
24+
def populate_licenses_task(dry_run, db_session):
25+
"""
26+
Populates licenses and their associated rules in the database.
27+
28+
Args:
29+
dry_run (bool): If True, simulates the operation without making changes.
30+
db_session: Database session for executing queries.
31+
"""
32+
logging.info("Starting populate_licenses_task with dry_run=%s", dry_run)
33+
34+
try:
35+
logging.info("Downloading license list from %s", LICENSES_API_URL)
36+
response = requests.get(LICENSES_API_URL, timeout=10)
37+
response.raise_for_status()
38+
files = response.json()
39+
40+
licenses_data = []
41+
for file_info in files:
42+
if file_info["type"] == "file" and file_info["name"].endswith(".json"):
43+
download_url = file_info["download_url"]
44+
logging.info("Downloading license from %s", download_url)
45+
license_response = requests.get(download_url, timeout=10)
46+
license_response.raise_for_status()
47+
licenses_data.append(license_response.json())
48+
49+
logging.info("Loaded %d licenses.", len(licenses_data))
50+
51+
if dry_run:
52+
logging.info("Dry run: would process %d licenses.", len(licenses_data))
53+
for license_data in licenses_data:
54+
logging.info("Dry run: processing license %d", len(licenses_data))
55+
else:
56+
for license_data in licenses_data:
57+
spdx_data = license_data.get("spdx")
58+
if not spdx_data:
59+
is_spdx = False
60+
else:
61+
is_spdx = True
62+
license_id = spdx_data.get("licenseId")
63+
if not license_id:
64+
logging.warning("Skipping record without licenseId.")
65+
continue
66+
67+
logging.info("Processing license %s", license_id)
68+
69+
license_object = db_session.get(License, license_id)
70+
if not license_object:
71+
license_object = License(id=license_id)
72+
license_object.is_spdx = is_spdx
73+
license_object.name = spdx_data.get("name")
74+
cross_ref_list = spdx_data.get("crossRef")
75+
if (
76+
cross_ref_list
77+
and isinstance(cross_ref_list, list)
78+
and cross_ref_list
79+
):
80+
license_object.url = cross_ref_list[0].get("url")
81+
else:
82+
license_object.url = None
83+
84+
license_object.content_txt = spdx_data.get("licenseText")
85+
license_object.content_html = spdx_data.get("licenseTextHtml")
86+
87+
# Clear existing rules to handle updates
88+
license_object.rules = []
89+
90+
all_rule_names = []
91+
for rule_type in ["permissions", "conditions", "limitations"]:
92+
all_rule_names.extend(license_data.get(rule_type, []))
93+
94+
all_rule_names = [
95+
name[:-1] if name.endswith("s") else name for name in all_rule_names
96+
]
97+
98+
if all_rule_names:
99+
rules = (
100+
db_session.query(Rule)
101+
.filter(Rule.name.in_(all_rule_names))
102+
.all()
103+
)
104+
license_object.rules.extend(rules)
105+
if len(rules) != len(all_rule_names):
106+
logging.warning(
107+
"License '%s': Found %d of %d rules in the database.",
108+
license_id,
109+
len(rules),
110+
len(all_rule_names),
111+
)
112+
# Merge the license object into the session. This handles both creating new licenses
113+
# and updating existing ones (upsert), including their rule associations.
114+
db_session.merge(license_object)
115+
116+
logging.info(
117+
"Successfully upserted licenses into the database.",
118+
)
119+
120+
except requests.exceptions.RequestException as e:
121+
logging.error("Failed to download licenses JSON file: %s", e)
122+
raise
123+
124+
125+
def get_parameters(payload):
126+
"""
127+
Get parameters from the payload.
128+
129+
Args:
130+
payload (dict): dictionary containing the payload data.
131+
Returns:
132+
bool: dry_run
133+
"""
134+
return payload.get("dry_run", False)

functions-python/tasks_executor/tests/tasks/populate_license_rules/test_populate_license_rules.py renamed to functions-python/tasks_executor/tests/tasks/populate_licenses_and_rules/test_populate_license_rules.py

File renamed without changes.
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
import logging
2+
import unittest
3+
from unittest.mock import MagicMock, call, patch
4+
5+
import pytest
6+
import requests
7+
from shared.database.database import Session
8+
from shared.database_gen.sqlacodegen_models import License, Rule
9+
from tasks.licenses.populate_licenses import populate_licenses_task
10+
11+
# Mock data for GitHub API responses
12+
MOCK_LICENSE_LIST = [
13+
{
14+
"name": "MIT.json",
15+
"type": "file",
16+
"download_url": "http://mockurl/MIT.json",
17+
},
18+
{
19+
"name": "BSD-3-Clause.json",
20+
"type": "file",
21+
"download_url": "http://mockurl/BSD-3-Clause.json",
22+
},
23+
{
24+
"name": "no-spdx.json",
25+
"type": "file",
26+
"download_url": "http://mockurl/no-spdx.json",
27+
},
28+
{
29+
"name": "README.md",
30+
"type": "file",
31+
"download_url": "http://mockurl/README.md",
32+
},
33+
]
34+
35+
MOCK_LICENSE_MIT = {
36+
"spdx": {
37+
"licenseId": "MIT",
38+
"name": "MIT License",
39+
"crossRef": [{"url": "https://opensource.org/licenses/MIT"}],
40+
"licenseText": "MIT License text...",
41+
"licenseTextHtml": "<p>MIT License text...</p>",
42+
},
43+
"permissions": ["commercial-use", "distribution"],
44+
"conditions": ["include-copyright"],
45+
"limitations": [],
46+
}
47+
48+
MOCK_LICENSE_BSD = {
49+
"spdx": {
50+
"licenseId": "BSD-3-Clause",
51+
"name": "BSD 3-Clause License",
52+
"crossRef": [{"url": "https://opensource.org/licenses/BSD-3-Clause"}],
53+
"licenseText": "BSD license text...",
54+
"licenseTextHtml": "<p>BSD license text...</p>",
55+
},
56+
"permissions": ["commercial-use"],
57+
"conditions": [],
58+
"limitations": ["liability", "warranty"],
59+
}
60+
61+
MOCK_LICENSE_NO_SPDX = {"licenseId": "NO-SPDX-ID", "name": "No SPDX License"}
62+
63+
64+
class TestPopulateLicenses(unittest.TestCase):
65+
def setUp(self):
66+
# Create an in-memory SQLite database for testing
67+
self.session = Session(bind=self.engine)
68+
69+
def tearDown(self):
70+
self.session.close()
71+
72+
@classmethod
73+
def setUpClass(cls):
74+
from sqlalchemy import create_engine
75+
76+
from shared.database_gen.sqlacodegen_models import Base
77+
78+
cls.engine = create_engine("sqlite:///:memory:")
79+
Base.metadata.create_all(cls.engine)
80+
81+
@classmethod
82+
def tearDownClass(cls):
83+
from shared.database_gen.sqlacodegen_models import Base
84+
85+
Base.metadata.drop_all(cls.engine)
86+
87+
def _mock_requests_get(self, mock_get):
88+
"""Helper to configure mock for requests.get."""
89+
mock_responses = {
90+
"https://api.github.com/repos/MobilityData/licenses-aas/contents/data/licenses": MagicMock(
91+
json=lambda: MOCK_LICENSE_LIST
92+
),
93+
"http://mockurl/MIT.json": MagicMock(json=lambda: MOCK_LICENSE_MIT),
94+
"http://mockurl/BSD-3-Clause.json": MagicMock(
95+
json=lambda: MOCK_LICENSE_BSD
96+
),
97+
"http://mockurl/no-spdx.json": MagicMock(json=lambda: MOCK_LICENSE_NO_SPDX),
98+
}
99+
100+
def get_side_effect(url, timeout=None):
101+
if url in mock_responses:
102+
response = mock_responses[url]
103+
response.raise_for_status.return_value = None
104+
return response
105+
raise requests.exceptions.RequestException(f"URL not mocked: {url}")
106+
107+
mock_get.side_effect = get_side_effect
108+
109+
@patch("tasks.licenses.populate_licenses.requests.get")
110+
def test_populate_licenses_success(self, mock_get):
111+
self._mock_requests_get(mock_get)
112+
113+
# Pre-populate rules
114+
rules_to_add = [
115+
Rule(id="commercial-use", name="commercial-use", type="permission"),
116+
Rule(id="distribution", name="distribution", type="permission"),
117+
Rule(id="include-copyright", name="include-copyright", type="condition"),
118+
Rule(id="liability", name="liability", type="limitation"),
119+
Rule(id="warranty", name="warranty", type="limitation"),
120+
]
121+
self.session.add_all(rules_to_add)
122+
self.session.commit()
123+
124+
populate_licenses_task(dry_run=False, db_session=self.session)
125+
126+
licenses = self.session.query(License).order_by(License.id).all()
127+
self.assertEqual(len(licenses), 2)
128+
129+
# Check MIT License
130+
mit_license = licenses[1]
131+
self.assertEqual(mit_license.id, "MIT")
132+
self.assertEqual(mit_license.name, "MIT License")
133+
self.assertTrue(mit_license.is_spdx)
134+
self.assertEqual(len(mit_license.rules), 3)
135+
rule_names = sorted([rule.name for rule in mit_license.rules])
136+
self.assertEqual(
137+
rule_names, ["commercial-use", "distribution", "include-copyright"]
138+
)
139+
140+
# Check BSD License
141+
bsd_license = licenses[0]
142+
self.assertEqual(bsd_license.id, "BSD-3-Clause")
143+
self.assertEqual(bsd_license.name, "BSD 3-Clause License")
144+
self.assertEqual(len(bsd_license.rules), 3)
145+
rule_names = sorted([rule.name for rule in bsd_license.rules])
146+
self.assertEqual(rule_names, ["commercial-use", "liability", "warranty"])
147+
148+
@patch("tasks.licenses.populate_licenses.requests.get")
149+
def test_populate_licenses_dry_run(self, mock_get):
150+
self._mock_requests_get(mock_get)
151+
152+
with self.assertLogs("tasks.licenses.populate_licenses", level="INFO") as cm:
153+
populate_licenses_task(dry_run=True, db_session=self.session)
154+
self.assertIn(
155+
"INFO:tasks.licenses.populate_licenses:Dry run: would process 2 licenses.",
156+
cm.output,
157+
)
158+
159+
licenses_count = self.session.query(License).count()
160+
self.assertEqual(licenses_count, 0)
161+
162+
@patch("tasks.licenses.populate_licenses.requests.get")
163+
def test_populate_licenses_request_exception(self, mock_get):
164+
mock_get.side_effect = requests.exceptions.RequestException("Network Error")
165+
166+
with self.assertRaises(requests.exceptions.RequestException):
167+
populate_licenses_task(dry_run=False, db_session=self.session)
168+
169+
@patch("tasks.licenses.populate_licenses.requests.get")
170+
def test_update_existing_license(self, mock_get):
171+
self._mock_requests_get(mock_get)
172+
173+
# Pre-populate license and a rule
174+
existing_license = License(
175+
id="MIT", name="Old MIT Name", url="http://oldurl.com"
176+
)
177+
existing_rule = Rule(id="private-use", name="private-use", type="permission")
178+
existing_license.rules.append(existing_rule)
179+
self.session.add(existing_license)
180+
self.session.commit()
181+
182+
# Run the task to update
183+
populate_licenses_task(dry_run=False, db_session=self.session)
184+
185+
updated_license = self.session.query(License).filter_by(id="MIT").one()
186+
self.assertEqual(updated_license.name, "MIT License")
187+
self.assertEqual(updated_license.url, "https://opensource.org/licenses/MIT")
188+
# Check that rules are updated, not appended
189+
self.assertNotEqual(len(updated_license.rules), 4)
190+
191+
192+
if __name__ == "__main__":
193+
unittest.main()

0 commit comments

Comments
 (0)