Skip to content

Commit 7068f2f

Browse files
committed
add unit tests and fix lint
1 parent 6e45d1d commit 7068f2f

File tree

2 files changed

+233
-1
lines changed

2 files changed

+233
-1
lines changed
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
import unittest
2+
from unittest.mock import patch, MagicMock
3+
4+
from tasks.licenses.license_matcher import (
5+
get_parameters,
6+
get_csv_response,
7+
process_feed,
8+
match_licenses_task,
9+
match_license_handler,
10+
process_all_feeds,
11+
)
12+
13+
14+
class TestLicenseMatcher(unittest.TestCase):
    """Unit tests for the helpers imported from ``tasks.licenses.license_matcher``.

    Where a test needs them, collaborators (``resolve_license``,
    ``process_feed``, ``match_licenses_task``) and the database session are
    replaced with mocks, so the tests in this class do not use a real database.
    """

    def test_get_parameters_defaults(self):
        # An empty payload must fall back to the default for every parameter.
        parsed = get_parameters({})
        dry_run, only_unmatched, feed_stable_id, content_type = parsed

        self.assertFalse(dry_run)
        self.assertTrue(only_unmatched)
        self.assertIsNone(feed_stable_id)
        self.assertEqual(content_type, "application/json")

    def test_get_parameters_values(self):
        # Explicit payload values must be returned unchanged.
        request = dict(
            dry_run=True,
            only_unmatched=False,
            feed_stable_id="feed-123",
            content_type="text/csv",
        )
        parsed = get_parameters(request)
        dry_run, only_unmatched, feed_stable_id, content_type = parsed

        self.assertTrue(dry_run)
        self.assertFalse(only_unmatched)
        self.assertEqual(feed_stable_id, "feed-123")
        self.assertEqual(content_type, "text/csv")

    def test_get_csv_response(self):
        # A single match record is rendered as CSV text with a header row.
        record = {
            "feed_id": "id1",
            "feed_stable_id": "stable1",
            "feed_data_type": "gtfs",
            "feed_license_url": "http://example.com/license1",
            "matched_license_id": "MIT",
            "matched_spdx_id": "MIT",
            "confidence": 0.99,
            "match_type": "exact",
            "matched_name": "MIT License",
            "matched_catalog_url": "http://example.com/license1",
            "matched_source": "db.license",
        }
        rendered = get_csv_response([record])
        header_line = rendered.splitlines()[0]

        # The current implementation emits "md_url" and "feed_license_url"
        # fused into one header cell; this test pins that behaviour as-is.
        expected_header_parts = (
            "md_urlfeed_license_url",
            "feed_id,feed_stable_id,feed_data_type",
        )
        for part in expected_header_parts:
            self.assertIn(part, header_line)

        expected_body_parts = (
            "https://mobilitydatabase.org/feeds/stable1",
            "MIT",
        )
        for part in expected_body_parts:
            self.assertIn(part, rendered)

    @patch("tasks.licenses.license_matcher.resolve_license")
    def test_process_feed_with_match(self, mock_resolve):
        # When the resolver yields a match, the result describes it and the
        # feed's license_id is updated (dry_run is off).
        feed = MagicMock(
            id="feed1",
            stable_id="stable1",
            data_type="gtfs",
            license_url="http://example.com/license",
            license_id=None,
        )
        best_match = MagicMock(
            license_id="MIT",
            spdx_id="MIT",
            confidence=0.95,
            match_type="exact",
            matched_name="MIT License",
            matched_catalog_url="http://example.com/license",
            matched_source="db.license",
        )
        mock_resolve.return_value = [best_match]

        outcome = process_feed(feed, dry_run=False, db_session=MagicMock())

        self.assertIsNotNone(outcome)
        self.assertEqual(outcome["matched_license_id"], "MIT")
        self.assertEqual(feed.license_id, "MIT")

    @patch("tasks.licenses.license_matcher.resolve_license")
    def test_process_feed_no_match(self, mock_resolve):
        # An empty resolver result means process_feed reports nothing.
        mock_resolve.return_value = []
        feed = MagicMock(
            id="feed2",
            stable_id="stable2",
            data_type="gtfs",
            license_url="http://example.com/license2",
        )

        outcome = process_feed(feed, dry_run=True, db_session=MagicMock())

        self.assertIsNone(outcome)

    @patch("tasks.licenses.license_matcher.process_feed")
    def test_match_licenses_task_single_feed(self, mock_process_feed):
        # With a feed_stable_id, the task looks up one feed and processes it.
        mock_process_feed.return_value = {"feed_id": "f1"}
        feed = MagicMock(stable_id="stable1")

        # Chainable query double: filter() returns itself, first() the feed.
        query = MagicMock()
        query.first.return_value = feed
        query.filter.return_value = query

        session = MagicMock()
        session.query.return_value = query

        outcome = match_licenses_task(
            dry_run=True,
            only_unmatched=True,
            feed_stable_id="stable1",
            db_session=session,
        )

        self.assertEqual(outcome, [{"feed_id": "f1"}])
        mock_process_feed.assert_called_once()

    @patch("tasks.licenses.license_matcher.process_feed")
    def test_match_license_handler_csv(self, mock_process_feed):
        # Requesting "text/csv" makes the handler return CSV text.
        row = {
            "feed_id": "f1",
            "feed_stable_id": "stable1",
            "feed_data_type": "gtfs",
            "feed_license_url": "http://example.com/license",
            "matched_license_id": "MIT",
            "matched_spdx_id": "MIT",
            "confidence": 1.0,
            "match_type": "exact",
            "matched_name": "MIT License",
            "matched_catalog_url": "http://example.com/license",
            "matched_source": "db.license",
        }
        mock_process_feed.return_value = row

        task_patch = patch(
            "tasks.licenses.license_matcher.match_licenses_task",
            return_value=[row],
        )
        with task_patch:
            request = dict(
                dry_run=True,
                feed_stable_id="stable1",
                content_type="text/csv",
            )
            csv_output = match_license_handler(request)
            self.assertIn("feed_stable_id", csv_output.splitlines()[0])
            for expected in ("stable1", "MIT"):
                self.assertIn(expected, csv_output)

    @patch("tasks.licenses.license_matcher.process_feed")
    def test_match_license_handler_json(self, mock_process_feed):
        # Without a content_type the handler returns the task's list as-is.
        row = {"feed_id": "f1"}
        mock_process_feed.return_value = row

        task_patch = patch(
            "tasks.licenses.license_matcher.match_licenses_task",
            return_value=[row],
        )
        with task_patch:
            request = dict(dry_run=True, feed_stable_id="stable1")
            outcome = match_license_handler(request)
            self.assertEqual(outcome, [{"feed_id": "f1"}])

    @patch("tasks.licenses.license_matcher.resolve_license")
    def test_process_all_feeds_sequential(self, mock_resolve):
        # Two feeds served in a single batch are each matched and updated.
        first_feed = MagicMock(
            id="a",
            stable_id="sA",
            data_type="gtfs",
            license_url="http://example.com/l1",
            license_id=None,
        )
        second_feed = MagicMock(
            id="b",
            stable_id="sB",
            data_type="gtfs",
            license_url="http://example.com/l2",
            license_id=None,
        )

        # Stand-ins for the match objects returned by resolve_license.
        mit_match = MagicMock(
            license_id="MIT",
            spdx_id="MIT",
            confidence=0.9,
            match_type="exact",
            matched_name="MIT",
            matched_catalog_url="u1",
            matched_source="db.license",
        )
        bsd_match = MagicMock(
            license_id="BSD",
            spdx_id="BSD",
            confidence=0.8,
            match_type="exact",
            matched_name="BSD",
            matched_catalog_url="u2",
            matched_source="db.license",
        )
        # One resolver result per feed, consumed in call order.
        mock_resolve.side_effect = [[mit_match], [bsd_match]]

        class QueryStub:
            # Query double: each all() call serves the next prepared batch,
            # then empty lists once the batches run out.

            def __init__(self, batches):
                self.batches = batches
                self.calls = 0

            def _chain(self, *a, **k):
                return self

            # Builder-style methods ignore their arguments and stay chainable.
            filter = _chain
            order_by = _chain
            limit = _chain

            def all(self):
                position = self.calls
                self.calls += 1
                if position < len(self.batches):
                    return self.batches[position]
                return []

        session = MagicMock()
        session.flush.return_value = None
        session.expunge_all.return_value = None
        session.query.return_value = QueryStub([[first_feed, second_feed], []])

        matches = process_all_feeds(
            dry_run=False, only_unmatched=True, db_session=session
        )

        self.assertEqual(len(matches), 2)
        self.assertEqual(first_feed.license_id, "MIT")
        self.assertEqual(second_feed.license_id, "BSD")
227+
228+
229+
# Run the tests with unittest's runner when this module is executed directly.
if __name__ == "__main__":
    unittest.main()

functions-python/tasks_executor/tests/tasks/populate_licenses_and_rules/test_populate_licenses.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,9 @@ def filter_side_effect(filter_condition):
125125

126126
# Inspect the License objects added
127127
added_licenses = [call.args[0] for call in mock_db_session.add.call_args_list]
128-
mit_license = next((lic for lic in added_licenses if getattr(lic, "id", None) == "MIT"), None)
128+
mit_license = next(
129+
(lic for lic in added_licenses if getattr(lic, "id", None) == "MIT"), None
130+
)
129131
self.assertIsNotNone(mit_license)
130132
self.assertEqual(getattr(mit_license, "name", None), "MIT License")
131133
self.assertTrue(getattr(mit_license, "is_spdx", False))

0 commit comments

Comments
 (0)