diff --git a/api/src/shared/common/db_utils.py b/api/src/shared/common/db_utils.py index a2f646759..e1949aabd 100644 --- a/api/src/shared/common/db_utils.py +++ b/api/src/shared/common/db_utils.py @@ -1,5 +1,6 @@ import logging import os +import re from typing import Iterator, List, Dict, Optional from geoalchemy2 import WKTElement @@ -30,7 +31,7 @@ from .entity_type_enum import EntityType from .error_handling import raise_internal_http_validation_error, invalid_bounding_coordinates, invalid_bounding_method from .iter_utils import batched -from ..feed_filters.gbfs_feed_filter import GbfsFeedFilter, GbfsVersionFilter +from shared.feed_filters.gbfs_feed_filter import GbfsFeedFilter, GbfsVersionFilter def get_gtfs_feeds_query( @@ -511,3 +512,70 @@ def get_gbfs_feeds_query( ) ) return query + + +def normalize_url(url_column) -> str: + """ + Normalize a URL by removing the protocol (http:// or https://), 'www.' prefix, and trailing slash. + This function generates a SQLAlchemy expression that can be used in queries. + Args: + url_column: The SQLAlchemy column representing the URL. + Returns: + A SQLAlchemy expression that normalizes the URL. + """ + return func.regexp_replace( + func.regexp_replace( + func.regexp_replace(url_column, r"^https?://", "", "gi"), + r"^www\.", + "", + "gi", + ), + r"/$", + "", + "g", + ) + + +def normalize_url_str(url: str | None) -> str: + """Normalize a license URL for matching. 
+ Steps: + - Trim whitespace and quotes + - Remove BOM characters + - Strip fragments and query parameters + - Remove scheme (http/https) and www prefix + - Lowercase the host + """ + u = (url or "").strip().strip("'\"").replace("\ufeff", "") + u = re.sub(r"#.*$", "", u) + u = re.sub(r"\?.*$", "", u) + u = re.sub(r"^https?://", "", u, flags=re.I) + u = re.sub(r"^www\.", "", u, flags=re.I) + # remove trailing slashes + u = re.sub(r"/+$", "", u) + if "/" in u: + host, rest = u.split("/", 1) + return host.lower() + "/" + rest + return u.lower() + + +def get_feed_query_by_normalized_url(url: str, db_session: Session) -> Query: + """ + Get a query to find the feed by normalized URL and exclude deprecated feeds. + Args: + url: The URL to normalize and search for. + db_session: SQLAlchemy session. + """ + return db_session.query(Feed).filter( + normalize_url_str(url) == func.lower(func.trim(normalize_url(Feed.producer_url))), + Feed.status != "deprecated", + ) + + +def get_feed_by_normalized_url(url: str, db_session: Session) -> Feed | None: + """ + Query the feed by normalized URL and exclude deprecated feeds. + Args: + url: The URL to normalize and search for. + db_session: SQLAlchemy session. 
+ """ + return get_feed_query_by_normalized_url(url, db_session).first() diff --git a/api/src/shared/common/license_utils.py b/api/src/shared/common/license_utils.py new file mode 100644 index 000000000..89d24ca65 --- /dev/null +++ b/api/src/shared/common/license_utils.py @@ -0,0 +1,365 @@ +import logging +import re +from dataclasses import dataclass +from difflib import SequenceMatcher +from sqlalchemy.orm import Session +from sqlalchemy import select, func +from typing import List, Tuple, Optional + +from shared.common.db_utils import normalize_url, normalize_url_str +from shared.database_gen.sqlacodegen_models import License + + +@dataclass +class MatchingLicense: + """Response structure for license URL resolution.""" + + license_id: str + license_url: str + normalized_url: str + match_type: str + confidence: float + spdx_id: str | None = None + matched_name: str | None = None + matched_catalog_url: str | None = None + matched_source: str | None = None + notes: str | None = None + regional_id: str | None = None + + +# The COMMON_PATTERNS list contains tuples of (regex pattern, SPDX ID). +# It is used for heuristic matching of license URLs. 
def resolve_commons_creative_license(url: str) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Resolve a Creative Commons license URL to an SPDX ID, a note and a regional ID.

    The URL is first normalized with ``normalize_url_str``; presentation-only
    suffixes (``/deed``, ``/deed.<lang>``, ``/legalcode``, ``/legalcode.<lang>``)
    are then stripped because they are pages of a license, not distinct licenses.

    Recognized shapes:
      - ``creativecommons.org/publicdomain/zero/1.0`` -> ``CC0-1.0``
      - ``creativecommons.org/licenses/<code>/<version>[/<locale>]`` where
          * ``<code>`` is one of: by, by-sa, by-nd, by-nc, by-nc-sa, by-nc-nd
          * ``<version>`` is a dotted number such as 1.0, 2.0, 2.1, 2.5, 3.0, 4.0
          * ``<locale>`` is an optional 2-5 character jurisdiction port (e.g. 'jp', 'de')

    Locale ports are ignored when building the SPDX ID (the canonical family and
    version are used); the port is reported in the note and in ``regional_id``.

    Version handling:
      - 1.0, 2.0, 2.5, 3.0 and 4.0 are used as-is; 2.1 is mapped to 2.0 with a note.
      - Any other numeric version is mapped to the highest known version that is
        lower than or equal to it (falling back to 1.0); a non-numeric version
        falls back to 4.0. Both cases add a note.

    Args:
        url: The license URL to resolve.

    Returns:
        A ``(spdx_id, note, regional_id)`` tuple:
          - spdx_id: SPDX identifier if resolved, else None.
          - note: explanation of any locale/version normalization, else None.
          - regional_id: ported variant id (e.g. 'CC-BY-2.1-jp') when a locale
            segment is present, else None.
        All three are None when the URL is not a recognized CC license URL.

    Examples:
        - https://creativecommons.org/licenses/by/2.1/jp/ -> spdx_id 'CC-BY-2.0',
          regional_id 'CC-BY-2.1-jp', note mentioning the port and the version change
        - https://creativecommons.org/licenses/by/4.0/ -> ('CC-BY-4.0', None, None)
        - https://creativecommons.org/publicdomain/zero/1.0/legalcode.en -> ('CC0-1.0', None, None)
    """
    # normalize_url_str strips quotes/whitespace, fragment, query, scheme, "www."
    # and trailing slashes, and lowercases the host.
    n = normalize_url_str(url)

    # Remove presentation-only CC suffixes like '/legalcode', '/legalcode.xx', '/deed', '/deed.xx'.
    # Underscores are allowed so language tags such as 'deed.pt_BR' are stripped too.
    n = re.sub(r"/legalcode(\.[a-zA-Z_\-]+)?$", "", n, flags=re.I)
    n = re.sub(r"/deed(\.[a-zA-Z_\-]+)?$", "", n, flags=re.I)

    # --- CC0 special case -----------------------------------------------------
    if re.search(r"creativecommons\.org/publicdomain/zero/1\.0/?$", n, re.I):
        return "CC0-1.0", None, None

    # --- General CC licenses --------------------------------------------------
    # Capture family code, version, and optional locale (jurisdiction port).
    # Locale is usually 2 letters, but 2-5 characters are allowed (e.g., 'pt-br').
    m = re.search(
        r"creativecommons\.org/licenses/([a-z\-]+)/([\d\.]+)(?:/([a-z\-]{2,5}))?/?$",
        n,
        re.I,
    )
    if not m:
        return None, None, None

    code = m.group(1).lower()  # e.g., 'by', 'by-sa', 'by-nc-nd'
    ver_in = m.group(2)  # e.g., '2.5'
    locale = m.group(3)  # e.g., 'jp', 'fr', 'de', or None

    # Map CC family code to SPDX base
    family_map = {
        "by": "CC-BY",
        "by-sa": "CC-BY-SA",
        "by-nd": "CC-BY-ND",
        "by-nc": "CC-BY-NC",
        "by-nc-sa": "CC-BY-NC-SA",
        "by-nc-nd": "CC-BY-NC-ND",
    }
    base = family_map.get(code)
    if not base:
        return None, None, None

    note_parts = []

    # A locale/jurisdiction port is reported but not used for the SPDX ID.
    if locale:
        note_parts.append(
            f"Detected locale/jurisdiction port '{locale}'. SPDX does not list ported CC licenses; using canonical ID."
        )

    # Direct map for commonly-seen CC versions and the special 2.1 -> 2.0 case.
    direct_version_map = {
        "1.0": "1.0",
        "2.0": "2.0",
        "2.1": "2.0",  # CC 2.1 ports are mapped to the closest supported version (2.0)
        "2.5": "2.5",
        "3.0": "3.0",
        "4.0": "4.0",
    }

    if ver_in in direct_version_map:
        ver_out = direct_version_map[ver_in]
        if ver_out != ver_in:
            note_parts.append(f"Normalized version {ver_in} → {ver_out} to match SPDX-supported versions.")
    else:
        # Fallback: choose the closest *lower or equal* known version.
        known = ["4.0", "3.0", "2.5", "2.0", "1.0"]  # descending, so the first hit is the highest
        try:
            vin = float(ver_in)
        except ValueError:
            # Non-numeric (e.g. '4.0.') — use the most modern canonical version as a pragmatic default
            ver_out = "4.0"
        else:
            ver_out = next((kv for kv in known if vin >= float(kv)), known[-1])

        note_parts.append(f"Unrecognized CC version '{ver_in}'. Chose closest canonical version '{ver_out}' for SPDX.")

    spdx_id = f"{base}-{ver_out}"
    regional_id = f"{base}-{ver_in}-{locale.lower()}" if locale else None
    note = " ".join(note_parts) if note_parts else None
    return spdx_id, note, regional_id
def resolve_license(
    license_url: str,
    allow_fuzzy: bool = True,
    fuzzy_threshold: float = 0.94,
    db_session: Session | None = None,
) -> List[MatchingLicense]:
    """Resolve a license URL to one or more SPDX candidates using multiple strategies.

    Strategies (in order of precedence; the first one that matches wins):
        1) Exact match in DB (db.license)        -> [exact]
        2) Creative Commons resolver (cc-resolver) -> [heuristic]
        3) Generic heuristics (pattern-heuristics) -> [heuristic]
        4) Fuzzy (same host candidates)            -> [fuzzy...]
        5) No match                                -> []

    Args:
        license_url (str): The license URL to resolve.
        allow_fuzzy (bool): Whether to allow fuzzy matching.
        fuzzy_threshold (float): Minimum similarity ratio for fuzzy match.
        db_session (Session | None): SQLAlchemy DB session. Without it the exact and
            fuzzy strategies are skipped, and a Creative Commons match is returned
            with the SPDX ID as ``license_id`` (it cannot be checked against the DB).

    Returns:
        List[MatchingLicense]: Ordered list of resolution results. Empty if no match,
        or if a Creative Commons SPDX ID was resolved but is missing from the DB.
    """
    url_str = str(license_url)
    url_normalized = normalize_url_str(url_str)
    url_host = extract_host(url_normalized)

    # 1) Exact hit in DB (compare normalized strings of known licenses)
    exact_match: License | None = find_exact_match_license_url(url_normalized, db_session) if db_session else None
    if exact_match:
        return [
            MatchingLicense(
                license_id=exact_match.id,
                license_url=url_str,
                normalized_url=url_normalized,
                spdx_id=exact_match.id,
                match_type="exact",
                confidence=1.0,
                matched_name=exact_match.name,
                matched_catalog_url=exact_match.url,
                matched_source="db.license",
            )
        ]

    # 2) Creative Commons resolver
    common_creative_match, notes, regional_id = resolve_commons_creative_license(url_str)
    if common_creative_match:
        license_id = common_creative_match
        # db_session is optional: only verify the SPDX ID against the DB when a
        # session is available instead of dereferencing None.
        if db_session is not None:
            cc_license: License | None = (
                db_session.query(License).filter(License.id == common_creative_match).one_or_none()
            )
            if not cc_license:
                logging.warning("CC license SPDX ID %s not found in DB", common_creative_match)
                return []
            license_id = cc_license.id
        return [
            MatchingLicense(
                license_id=license_id,
                license_url=url_str,
                normalized_url=url_normalized,
                spdx_id=common_creative_match,
                match_type="heuristic",
                confidence=0.99,
                # Fill in matched_name with SPDX ID for lack of better info
                matched_name=common_creative_match,
                matched_catalog_url=None,
                matched_source="cc-resolver",
                notes=notes,
                regional_id=regional_id,
            )
        ]

    # 3) Generic heuristics
    heuristic_match = heuristic_spdx(url_str)
    if heuristic_match:
        return [
            MatchingLicense(
                license_id=heuristic_match,
                license_url=url_str,
                normalized_url=url_normalized,
                spdx_id=heuristic_match,
                match_type="heuristic",
                confidence=0.95,
                matched_name=heuristic_match,
                matched_source="pattern-heuristics",
            )
        ]

    # 4) Fuzzy (same host candidates only)
    if allow_fuzzy and url_host and db_session is not None:
        fuzzy_results = resolve_fuzzy_match(
            url_str=url_str,
            url_host=url_host,
            url_normalized=url_normalized,
            fuzzy_threshold=fuzzy_threshold,
            db_session=db_session,
        )
        if fuzzy_results:
            return fuzzy_results

    # 5) No match
    return []
("\ufeffhttps://example.com/license?x=1", "example.com/license"), + # Strip fragment + ("http://example.com/path#frag", "example.com/path"), + # Strip query + ("https://example.com/path?param=value", "example.com/path"), + # Remove trailing slashes + ("https://www.example.com/path///", "example.com/path"), + # Host only with scheme and www; trailing slash removed; host lowercased + ("http://www.EXAMPLE.com/", "example.com"), + # Path case preserved (only host lowercased) + ("https://Example.com/Case/Sensitive", "example.com/Case/Sensitive"), + # None becomes empty string + (None, ""), + # Blank / whitespace-only becomes empty string + (" ", ""), + # Quotes without scheme + ('"Example.com/path"', "example.com/path"), + ], +) +def test_normalize_url_str(raw, expected): + """Test normalize_url_str utility for all documented normalization steps. + Steps verified: + - Trim whitespace and quotes + - Remove BOM characters + - Strip fragments and query parameters + - Remove scheme (http/https) and www prefix + - Lowercase the host (only host) + - Remove trailing slashes + - Preserve path case + - Handle None / empty inputs + """ + assert normalize_url_str(raw) == expected diff --git a/api/tests/utils/test_license_utils.py b/api/tests/utils/test_license_utils.py new file mode 100644 index 000000000..03c333668 --- /dev/null +++ b/api/tests/utils/test_license_utils.py @@ -0,0 +1,217 @@ +"""Unit tests for license_utils module.""" +import unittest +from unittest.mock import MagicMock, patch + +from shared.common.license_utils import ( + extract_host, + resolve_commons_creative_license, + heuristic_spdx, + fuzzy_ratio, + resolve_fuzzy_match, + resolve_license, + find_exact_match_license_url, + MatchingLicense, +) +from shared.database_gen.sqlacodegen_models import License + + +class TestLicenseUtils(unittest.TestCase): + """Test cases for license-related helper functions.""" + + def setUp(self): + self.session = MagicMock() + + # --- extract_host --- + def 
test_extract_host_basic(self): + self.assertEqual(extract_host("example.com/path/to"), "example.com") + self.assertEqual(extract_host("example.com"), "example.com") + self.assertEqual(extract_host("http://example.com"), "example.com") + self.assertEqual(extract_host(" https://example.com"), "example.com") + self.assertEqual(extract_host(""), "") + + # --- resolve_commons_creative_license --- + def test_resolve_commons_creative_license_cc0(self): + url = "https://creativecommons.org/publicdomain/zero/1.0/" + spdx, note, regional_id = resolve_commons_creative_license(url) + self.assertEqual(spdx, "CC0-1.0") + self.assertIsNone(note) + self.assertIsNone(regional_id) + + def test_resolve_commons_creative_license_by_variants(self): + # BY with deed / legalcode suffixes and locale code + urls = [ + "https://creativecommons.org/licenses/by/4.0/", + "https://creativecommons.org/licenses/by/4.0/deed.en", + "https://creativecommons.org/licenses/by/4.0/legalcode", + ] + for u in urls: + spdx, note, regional_id = resolve_commons_creative_license(u) + self.assertEqual(spdx, "CC-BY-4.0") + self.assertIsNone(note) + self.assertIsNone(regional_id) + + def test_resolve_commons_creative_license_non_match(self): + spdx, note, regional_id = resolve_commons_creative_license("https://example.com/no/license") + self.assertIsNone(spdx) + self.assertIsNone(note) + self.assertIsNone(regional_id) + + def test_resolve_commons_creative_license_all_flavors(self): + cases = { + "https://creativecommons.org/licenses/by-sa/3.0/": "CC-BY-SA-3.0", + "https://creativecommons.org/licenses/by-nd/3.0/": "CC-BY-ND-3.0", + "https://creativecommons.org/licenses/by-nc/4.0/": "CC-BY-NC-4.0", + "https://creativecommons.org/licenses/by-nc-sa/4.0/": "CC-BY-NC-SA-4.0", + "https://creativecommons.org/licenses/by-nc-nd/4.0/": "CC-BY-NC-ND-4.0", + } + for url, expected in cases.items(): + spdx, note, regional_id = resolve_commons_creative_license(url) + self.assertEqual(spdx, expected) + self.assertIsNone(note) + 
self.assertIsNone(regional_id) + + def test_resolve_commons_creative_license_jp_variant(self): + # New behavior: return base SPDX and a note when locale/jurisdiction variants encountered + spdx, note, regional_id = resolve_commons_creative_license("https://creativecommons.org/licenses/by/2.1/jp/") + # Current implementation maps to CC-BY-2.0 with a note + self.assertEqual(spdx, "CC-BY-2.0") + self.assertIsNotNone(note) + # Be lenient about message contents + self.assertIn("jp", note.lower()) + self.assertEqual(regional_id, "CC-BY-2.1-jp") + + # --- heuristic_spdx --- + def test_heuristic_spdx_patterns(self): + self.assertEqual(heuristic_spdx("https://opensource.org/licenses/MIT"), "MIT") + self.assertEqual(heuristic_spdx("http://opensource.org/licenses/Apache-2.0"), "Apache-2.0") + self.assertEqual(heuristic_spdx("https://opendatacommons.org/licenses/odbl/1.0/"), "ODbL-1.0") + + def test_heuristic_spdx_no_match(self): + self.assertIsNone(heuristic_spdx("https://example.com/custom-license")) + + # --- fuzzy_ratio --- + def test_fuzzy_ratio_similarity(self): + a = "https://example.com/license/alpha" + b = "https://example.com/license/alpha" # identical + self.assertAlmostEqual(fuzzy_ratio(a, b), 1.0, places=5) + c = "https://example.com/license/beta" + ratio = fuzzy_ratio(a, c) + self.assertTrue(0 < ratio < 1) + + # --- resolve_fuzzy_match --- + def _make_license(self, id_: str, url: str, name: str = None, type_: str = "standard") -> License: + return License(id=id_, type=type_, name=name or id_, url=url) + + def test_resolve_fuzzy_match_no_session_or_host(self): + self.assertEqual(resolve_fuzzy_match("x", "", "x", 0.8, None), []) + + def test_resolve_fuzzy_match_host_and_threshold(self): + # Prepare licenses + lic1 = self._make_license("MIT", "opensource.org/licenses/MIT", "MIT") + lic2 = self._make_license("Apache-2.0", "opensource.org/licenses/Apache-2.0", "Apache") + lic3 = self._make_license("ODbL-1.0", "opendatacommons.org/licenses/odbl/1.0/", "ODbL") + # Only 
first two share the host 'opensource.org' + self.session.scalars.return_value = [lic1, lic2, lic3] + results = resolve_fuzzy_match( + url_str="https://opensource.org/licenses/mit/", + url_host="opensource.org", + url_normalized="opensource.org/licenses/mit", + fuzzy_threshold=0.6, + db_session=self.session, + max_candidates=2, + ) + self.assertGreaterEqual(len(results), 1) + self.assertTrue(all(r.match_type == "fuzzy" for r in results)) + self.assertTrue(all(r.matched_catalog_url.startswith("opensource.org") for r in results)) + + def test_resolve_fuzzy_match_applies_limit(self): + lic_list = [self._make_license(f"L{i}", f"host.com/path{i}") for i in range(10)] + self.session.scalars.return_value = lic_list + results = resolve_fuzzy_match( + url_str="host.com/path0", + url_host="host.com", + url_normalized="host.com/path0", + fuzzy_threshold=0.0, # accept all + db_session=self.session, + max_candidates=3, + ) + self.assertEqual(len(results), 3) + + # --- resolve_license --- + @patch("shared.common.license_utils.find_exact_match_license_url") + def test_resolve_license_exact(self, mock_find): + lic = self._make_license("MIT", "opensource.org/licenses/MIT", "MIT") + mock_find.return_value = lic + results = resolve_license("https://opensource.org/licenses/MIT", db_session=self.session) + self.assertEqual(len(results), 1) + self.assertEqual(results[0].match_type, "exact") + self.assertEqual(results[0].spdx_id, "MIT") + + @patch("shared.common.license_utils.find_exact_match_license_url", return_value=None) + def test_resolve_license_creative_commons(self, _mock_find): + # Provide session (implementation accesses db_session) but ensure exact path returns None + results = resolve_license("https://creativecommons.org/licenses/by/4.0/", db_session=self.session) + self.assertEqual(len(results), 1) + self.assertEqual(results[0].spdx_id, "CC-BY-4.0") + self.assertEqual(results[0].match_type, "heuristic") + + @patch("shared.common.license_utils.find_exact_match_license_url", 
return_value=None) + def test_resolve_license_generic_heuristic(self, _mock_find): + # Provide URL that matches heuristic patterns + results = resolve_license("https://choosealicense.com/licenses/mit/", db_session=self.session) + self.assertEqual(len(results), 1) + self.assertEqual(results[0].spdx_id, "MIT") + self.assertEqual(results[0].match_type, "heuristic") + + @patch("shared.common.license_utils.find_exact_match_license_url", return_value=None) + def test_resolve_license_fuzzy(self, _mock_find): + target_url = "https://licenses.example.org/pageA" + licA = self._make_license("LIC-A", "licenses.example.org/pageA", "License A") + licB = self._make_license("LIC-B", "licenses.example.org/pageB", "License B") + self.session.scalars.return_value = [licA, licB] + results = resolve_license(target_url, db_session=self.session, fuzzy_threshold=0.8) + self.assertTrue(results) + self.assertTrue(all(r.match_type == "fuzzy" for r in results)) + + @patch("shared.common.license_utils.find_exact_match_license_url", return_value=None) + def test_resolve_license_no_match(self, _mock_find): + # Provide unique host; no fuzzy allowed; should be empty + results = resolve_license( + "https://unknown.example.xyz/some/path", + db_session=self.session, + allow_fuzzy=False, + ) + self.assertEqual(results, []) + + # --- find_exact_match_license_url --- + def test_find_exact_match_license_url_hit(self): + expected_license = self._make_license("MIT", "opensource.org/licenses/MIT", "MIT") + self.session.query.return_value.filter.return_value.first.return_value = expected_license + result = find_exact_match_license_url("opensource.org/licenses/MIT", self.session) + self.assertIs(result, expected_license) + + def test_find_exact_match_license_url_miss(self): + self.session.query.return_value.filter.return_value.first.return_value = None + result = find_exact_match_license_url("opensource.org/licenses/Apache-2.0", self.session) + self.assertIsNone(result) + + # --- MatchingLicense dataclass 
simple instantiation --- + def test_matching_license_dataclass(self): + ml = MatchingLicense( + license_id="L1", + license_url="http://x.com/L1", + normalized_url="x.com/L1", + match_type="exact", + confidence=1.0, + spdx_id="L1", + matched_name="Name", + matched_catalog_url="x.com/L1", + matched_source="db.license", + ) + self.assertEqual(ml.license_id, "L1") + self.assertEqual(ml.match_type, "exact") + self.assertEqual(ml.confidence, 1.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/docs/LICENSES.md b/docs/LICENSES.md new file mode 100644 index 000000000..9b2122681 --- /dev/null +++ b/docs/LICENSES.md @@ -0,0 +1,71 @@ +# Feed Licenses + +This page explains how license information is managed and automatically matched to public transit data feeds within the **Mobility Feed API**. +It also describes where the license data comes from and how the system determines which license applies to a given feed. + + +## Where the License Information Comes From + +License details are sourced from the [**Licenses-AAS**](https://github.com/MobilityData/licenses-aas) project, an open repository that organizes licenses according to: + +* **Permissions** – what users are allowed to do (for example, modify or share the data) +* **Conditions** – what users must do (for example, give credit to the author) +* **Limitations** – what users cannot do (for example, claim warranty or hold authors liable) + +Each license in that repository includes its official SPDX identifier, name, and a link to the full text of the license. + + +## How License Matching Works + +When a feed includes a **license URL**, the system tries to recognize and link it to a known open-data license. +This process ensures that each feed clearly indicates how its data can be used and shared. + +The matching process follows several steps: + +1. **Exact Match** +The system checks if the feed’s license URL exactly matches one from the license catalog. +This is the most reliable form of matching. + +2. 
**Creative Commons Resolver** +If no exact match is found, the system checks whether the URL represents a Creative Commons license +(including international and regional variants such as JP, FR, DE). +When detected, the resolver maps the URL to the correct SPDX ID and adds notes about regional versions if applicable. + +3. **Generic Heuristics** +If the URL follows a recognizable pattern (e.g., opensource.org/licenses/Apache-2.0), +the system applies rule-based heuristics to infer the likely SPDX ID. + +4. **Fuzzy Match (same host only)** +If no deterministic match is found, the system compares the URL against known license URLs from the same domain +using string-similarity scoring. +This step captures minor variations such as trailing slashes, redirects, or small path differences. + +## Regional or Localized Licenses + +If a license URL points to a **localized version** (for example, a country-specific version of a Creative Commons license), +the system identifies the corresponding **standard SPDX license** and adds a note explaining the regional variant. + + +## Understanding the Matching Results + +Each match includes: + +* **License name and ID** – e.g., *MIT License (MIT)* +* **Match type** – how the license was identified (exact, fuzzy, or heuristic) +* **Confidence level** – how certain the system is about the match (higher = stronger match) +* **Notes** – any additional details, such as localized versions or domain inference + + +## License Rules + +Some licenses in the catalog include detailed **rules** that describe what users *can*, *must*, and *cannot* do under that license. +However, not all licenses currently have these rules defined. + +These rules are maintained in the [**Licenses-AAS**](https://github.com/MobilityData/licenses-aas) repository. +If you notice a license missing its rules, you’re encouraged to contribute by adding them, helping improve clarity and consistency for all users of open data. 
+ + +## Keeping License Data Accurate + +The Licenses-AAS project is regularly updated to include new and revised open-data licenses. +This ensures that the Mobility Feed API always reflects the most current and reliable licensing information. diff --git a/functions-python/helpers/query_helper.py b/functions-python/helpers/query_helper.py index c6c3a9a48..073522b2e 100644 --- a/functions-python/helpers/query_helper.py +++ b/functions-python/helpers/query_helper.py @@ -1,4 +1,5 @@ import logging +import re from datetime import datetime from typing import Type @@ -216,23 +217,25 @@ def normalize_url(url_column) -> str: def normalize_url_str(url: str | None) -> str: + """Normalize a license URL for matching. + Steps: + - Trim whitespace and quotes + - Remove BOM characters + - Strip fragments and query parameters + - Remove scheme (http/https) and www prefix + - Lowercase the host """ - Normalize a URL string for Python-side comparison: - - strip whitespace - - remove http:// or https:// - - remove leading www. 
- - remove trailing slash - - lowercase - """ - import re - - if not url: - return "" - s = url.strip() - s = re.sub(r"^https?://", "", s, flags=re.I) - s = re.sub(r"^www\.", "", s, flags=re.I) - s = re.sub(r"/$", "", s) - return s.lower() + u = (url or "").strip().strip("'\"").replace("\ufeff", "") + u = re.sub(r"#.*$", "", u) + u = re.sub(r"\?.*$", "", u) + u = re.sub(r"^https?://", "", u, flags=re.I) + u = re.sub(r"^www\.", "", u, flags=re.I) + # remove trailing slashes + u = re.sub(r"/+$", "", u) + if "/" in u: + host, rest = u.split("/", 1) + return host.lower() + "/" + rest + return u.lower() def get_feed_by_normalized_url(url: str, db_session: Session) -> Feed | None: diff --git a/functions-python/tasks_executor/README.md b/functions-python/tasks_executor/README.md index efdb40654..a2575e811 100644 --- a/functions-python/tasks_executor/README.md +++ b/functions-python/tasks_executor/README.md @@ -89,3 +89,8 @@ To populate licenses: } } ``` + +## Response Content Type + +When the request includes the header `Accept: text/csv`, the server returns the response as a CSV file generated from the handler’s output. +If the header is not provided, the default response content type is `application/json`. 
\ No newline at end of file diff --git a/functions-python/tasks_executor/function_config.json b/functions-python/tasks_executor/function_config.json index b8099a2d9..100a0e39b 100644 --- a/functions-python/tasks_executor/function_config.json +++ b/functions-python/tasks_executor/function_config.json @@ -6,7 +6,7 @@ "memory": "8Gi", "trigger_http": true, "include_folders": ["helpers"], - "include_api_folders": ["database_gen", "database", "common"], + "include_api_folders": ["database_gen", "database", "common", "feed_filters"], "environment_variables": [ { "key": "DATASETS_BUCKET_NAME" diff --git a/functions-python/tasks_executor/requirements.txt b/functions-python/tasks_executor/requirements.txt index 563cd57f6..e6cf5c1ac 100644 --- a/functions-python/tasks_executor/requirements.txt +++ b/functions-python/tasks_executor/requirements.txt @@ -29,5 +29,7 @@ google-cloud-storage python-dotenv==1.0.0 pycountry + # Other utilities pandas +fastapi_filter diff --git a/functions-python/tasks_executor/src/main.py b/functions-python/tasks_executor/src/main.py index a1495b2f0..e96bcfe0b 100644 --- a/functions-python/tasks_executor/src/main.py +++ b/functions-python/tasks_executor/src/main.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - +import csv +import io from typing import Any, Final import flask @@ -24,6 +25,7 @@ from tasks.dataset_files.rebuild_missing_dataset_files import ( rebuild_missing_dataset_files_handler, ) +from tasks.licenses.license_matcher import match_license_handler from tasks.missing_bounding_boxes.rebuild_missing_bounding_boxes import ( rebuild_missing_bounding_boxes_handler, ) @@ -99,6 +101,10 @@ "description": "Populates licenses and license-rules in the database from a predefined JSON source.", "handler": populate_licenses_handler, }, + "match_licenses": { + "description": "Match licenses with feeds.", + "handler": match_license_handler, + }, "sync_transitfeeds_data": { "description": "Syncs data from TransitFeeds to the database.", "handler": sync_transitfeeds_handler, @@ -123,10 +129,40 @@ def get_task(request: flask.Request): task = request_json.get("task") if task not in tasks: raise ValueError("Task not supported: %s", task) + accept_content_type = request.headers.get("Accept", "application/json") payload = request_json.get("payload") if not payload: payload = {} - return task, payload + return task, payload, accept_content_type + + +def _to_csv(data) -> str: + if isinstance(data, str): + return data + if isinstance(data, dict): + output = io.StringIO() + writer = csv.DictWriter(output, fieldnames=list(data.keys())) + writer.writeheader() + writer.writerow(data) + return output.getvalue() + if isinstance(data, list): + if not data: + return "" + # Collect all keys to handle varying dict shapes + keys = set() + for row in data: + if isinstance(row, dict): + keys.update(row.keys()) + fieldnames = sorted(keys) + output = io.StringIO() + writer = csv.DictWriter(output, fieldnames=fieldnames) + writer.writeheader() + for row in data: + if isinstance(row, dict): + writer.writerow({k: row.get(k, "") for k in fieldnames}) + return output.getvalue() + # Fallback: stringify + return str(data) @functions_framework.http @@ -134,12 +170,23 @@ def tasks_executor(request: 
flask.Request) -> flask.Response: task: Any payload: Any try: - task, payload = get_task(request) + task, payload, accept_content_type = get_task(request) except ValueError as error: return flask.make_response(flask.jsonify({"error": str(error)}), 400) # Execute task handler = tasks[task]["handler"] try: - return flask.make_response(flask.jsonify(handler(payload=payload)), 200) + result = handler(payload=payload) + if accept_content_type == "text/csv": + csv_body = _to_csv(result) + response = flask.make_response(csv_body, 200) + response.headers["Content-Type"] = "text/csv; charset=utf-8" + response.headers[ + "Content-Disposition" + ] = "attachment; filename=task_result.csv" + return response + + # Default JSON response + return flask.make_response(flask.jsonify(result), 200) except Exception as error: return flask.make_response(flask.jsonify({"error": str(error)}), 500) diff --git a/functions-python/tasks_executor/src/tasks/licenses/license_matcher.py b/functions-python/tasks_executor/src/tasks/licenses/license_matcher.py new file mode 100644 index 000000000..b5afcdab9 --- /dev/null +++ b/functions-python/tasks_executor/src/tasks/licenses/license_matcher.py @@ -0,0 +1,148 @@ +import logging + +from sqlalchemy import asc, func +from sqlalchemy.orm import Session + +from shared.common.license_utils import resolve_license, MatchingLicense +from shared.database.database import with_db_session +from shared.database_gen.sqlacodegen_models import Feed, FeedLicenseChange +from shared.helpers.runtime_metrics import track_metrics + + +def get_parameters(payload): + dry_run = payload.get("dry_run", False) + only_unmatched = payload.get("only_unmatched", True) + feed_stable_id = payload.get("feed_stable_id", None) + return dry_run, only_unmatched, feed_stable_id + + +def match_license_handler(payload): + """ + Handler for matching licenses with feeds. + + Args: + payload (dict): Incoming payload data. 
+ + """ + (dry_run, only_unmatched, feed_stable_id) = get_parameters(payload) + return match_licenses_task(dry_run, only_unmatched, feed_stable_id) + + +def assign_feed_license(feed: Feed, license_match: MatchingLicense): + """Assign the matched license to the feed and log the change if license is different.""" + if license_match.license_id != feed.license_id: + logging.info( + "New license match for feed %s: %s", + feed.stable_id, + license_match.license_id, + ) + feed.license_id = license_match.license_id + feed.license_notes = license_match.notes + feed_license_change: FeedLicenseChange = FeedLicenseChange( + feed_id=feed.id, + changed_at=None, # will be set by DB default + feed_license_url=feed.license_url, + matched_license_id=license_match.license_id, + confidence=license_match.confidence, + match_type=license_match.match_type, + matched_name=license_match.matched_name, + matched_catalog_url=license_match.matched_catalog_url, + matched_source=license_match.matched_source, + notes=license_match.notes, + regional_id=license_match.regional_id, + ) + feed.feed_license_changes.append(feed_license_change) + else: + logging.info("Feed %s license unchanged: %s", feed.stable_id, feed.license_id) + + +def process_feed(feed, dry_run, db_session): + """Process a single feed to match its license.""" + result = None + license_matches = resolve_license(feed.license_url, db_session=db_session) + if license_matches: + license_first_match = sorted( + license_matches, key=lambda x: x.confidence, reverse=True + )[0] + result = { + "feed_id": feed.id, + "feed_stable_id": feed.stable_id, + "feed_data_type": feed.data_type, + "feed_license_url": feed.license_url, + "matched_license_id": license_first_match.license_id, + "matched_spdx_id": license_first_match.spdx_id, + "confidence": license_first_match.confidence, + "match_type": license_first_match.match_type, + "matched_name": license_first_match.matched_name, + "matched_catalog_url": license_first_match.matched_catalog_url, + 
"matched_source": license_first_match.matched_source, + "notes": license_first_match.notes, + "regional_id": license_first_match.regional_id, + } + if not dry_run: + assign_feed_license(feed, license_first_match) + return result + + +@track_metrics(metrics=("time", "memory", "cpu")) +@with_db_session +def match_licenses_task( + dry_run: bool, + only_unmatched: bool, + feed_stable_id: str = None, + db_session: Session = None, +): + result = [] + if feed_stable_id: + feed = db_session.query(Feed).filter(Feed.stable_id == feed_stable_id).first() + if not feed: + logging.error("Feed with stable_id %s not found.", feed_stable_id) + raise ValueError(f"Feed with stable_id {feed_stable_id} not found.") + result.append(process_feed(feed, dry_run, db_session)) + else: + result = process_all_feeds(dry_run, only_unmatched, db_session) + return result + + +def process_all_feeds(dry_run: bool, only_unmatched: bool, db_session: Session | None): + result = [] + batch_size = 500 + last_id = None + i = 0 + total_processed = 0 + while True: + logging.info("Processing batch %d", i) + batch_query = db_session.query(Feed).filter( + "" != func.coalesce(Feed.license_url, "") + ) + if last_id is not None: + batch_query = batch_query.filter(Feed.id > last_id) + if only_unmatched: + batch_query = batch_query.filter(Feed.license_id.is_(None)) + batch = batch_query.order_by(asc(Feed.id)).limit(batch_size).all() + if not batch: + break + total_processed += len(batch) + for feed in batch: + feed_match = process_feed(feed, dry_run, db_session) + if feed_match: + result.append(feed_match) + if not dry_run: + # Flush the batch updates to the database + db_session.flush() + + last_id = batch[-1].id + db_session.expunge_all() + logging.info( + "Processed batch %d. Total processed %d, so far matched licenses: %d", + i, + total_processed, + len(result), + ) + i += 1 + logging.info( + "Total processed feeds %d. 
Total matched licenses: %d", + total_processed, + len(result), + ) + return result diff --git a/functions-python/tasks_executor/src/tasks/licenses/populate_licenses.py b/functions-python/tasks_executor/src/tasks/licenses/populate_licenses.py index 2d2551213..b2c65dfe8 100644 --- a/functions-python/tasks_executor/src/tasks/licenses/populate_licenses.py +++ b/functions-python/tasks_executor/src/tasks/licenses/populate_licenses.py @@ -86,7 +86,9 @@ def populate_licenses_task(dry_run, db_session): logging.info("Processing license %s", license_id) license_object = db_session.get(License, license_id) + is_new = False if not license_object: + is_new = True license_object = License(id=license_id) license_object.created_at = datetime.now(timezone.utc) license_object.is_spdx = is_spdx @@ -130,9 +132,10 @@ def populate_licenses_task(dry_run, db_session): len(rules), len(all_rule_names), ) - # Merge the license object into the session. This handles both creating new licenses - # and updating existing ones (upsert), including their rule associations. - db_session.merge(license_object) + # Merge the license object into the session. This handles updating existing licenses (upsert), + # including their rule associations. 
+ if not is_new: + db_session.merge(license_object) logging.info( "Successfully upserted licenses into the database.", diff --git a/functions-python/tasks_executor/tests/license_matcher/test_license_matcher.py b/functions-python/tasks_executor/tests/license_matcher/test_license_matcher.py new file mode 100644 index 000000000..213a4216c --- /dev/null +++ b/functions-python/tasks_executor/tests/license_matcher/test_license_matcher.py @@ -0,0 +1,172 @@ +import unittest +from unittest.mock import patch, MagicMock + +from tasks.licenses.license_matcher import ( + get_parameters, + process_feed, + match_licenses_task, + match_license_handler, + process_all_feeds, +) + + +class TestLicenseMatcher(unittest.TestCase): + def test_get_parameters_defaults(self): + payload = {} + dry_run, only_unmatched, feed_stable_id = get_parameters(payload) + self.assertFalse(dry_run) + self.assertTrue(only_unmatched) + self.assertIsNone(feed_stable_id) + + def test_get_parameters_values(self): + payload = { + "dry_run": True, + "only_unmatched": False, + "feed_stable_id": "feed-123", + } + dry_run, only_unmatched, feed_stable_id = get_parameters(payload) + self.assertTrue(dry_run) + self.assertFalse(only_unmatched) + self.assertEqual(feed_stable_id, "feed-123") + + @patch("tasks.licenses.license_matcher.resolve_license") + def test_process_feed_with_match(self, mock_resolve): + feed = MagicMock() + feed.id = "feed1" + feed.stable_id = "stable1" + feed.data_type = "gtfs" + feed.license_url = "http://example.com/license" + feed.license_id = None + + match_obj = MagicMock() + match_obj.license_id = "MIT" + match_obj.spdx_id = "MIT" + match_obj.confidence = 0.95 + match_obj.match_type = "exact" + match_obj.matched_name = "MIT License" + match_obj.matched_catalog_url = "http://example.com/license" + match_obj.matched_source = "db.license" + mock_resolve.return_value = [match_obj] + + result = process_feed(feed, dry_run=False, db_session=MagicMock()) + self.assertIsNotNone(result) + 
self.assertEqual(result["matched_license_id"], "MIT") + self.assertEqual(feed.license_id, "MIT") + + @patch("tasks.licenses.license_matcher.resolve_license") + def test_process_feed_no_match(self, mock_resolve): + feed = MagicMock() + feed.id = "feed2" + feed.stable_id = "stable2" + feed.data_type = "gtfs" + feed.license_url = "http://example.com/license2" + mock_resolve.return_value = [] + result = process_feed(feed, dry_run=True, db_session=MagicMock()) + self.assertIsNone(result) + + @patch("tasks.licenses.license_matcher.process_feed") + def test_match_licenses_task_single_feed(self, mock_process_feed): + feed = MagicMock() + feed.stable_id = "stable1" + mock_process_feed.return_value = {"feed_id": "f1"} + + query_stub = MagicMock() + query_stub.filter.return_value = query_stub + query_stub.first.return_value = feed + + db_session = MagicMock() + db_session.query.return_value = query_stub + + result = match_licenses_task( + dry_run=True, + only_unmatched=True, + feed_stable_id="stable1", + db_session=db_session, + ) + self.assertEqual(result, [{"feed_id": "f1"}]) + mock_process_feed.assert_called_once() + + @patch("tasks.licenses.license_matcher.process_feed") + def test_match_license_handler_json(self, mock_process_feed): + mock_process_feed.return_value = {"feed_id": "f1"} + with patch( + "tasks.licenses.license_matcher.match_licenses_task", + return_value=[mock_process_feed.return_value], + ): + payload = {"dry_run": True, "feed_stable_id": "stable1"} + result = match_license_handler(payload) + self.assertEqual(result, [{"feed_id": "f1"}]) + + @patch("tasks.licenses.license_matcher.resolve_license") + def test_process_all_feeds_sequential(self, mock_resolve): + # Prepare feeds + feed1 = MagicMock() + feed1.id = "a" + feed1.stable_id = "sA" + feed1.data_type = "gtfs" + feed1.license_url = "http://example.com/l1" + feed1.license_id = None + feed2 = MagicMock() + feed2.id = "b" + feed2.stable_id = "sB" + feed2.data_type = "gtfs" + feed2.license_url = 
"http://example.com/l2" + feed2.license_id = None + + # MatchingLicense mocks + m1 = MagicMock() + m1.license_id = "MIT" + m1.spdx_id = "MIT" + m1.confidence = 0.9 + m1.match_type = "exact" + m1.matched_name = "MIT" + m1.matched_catalog_url = "u1" + m1.matched_source = "db.license" + m2 = MagicMock() + m2.license_id = "BSD" + m2.spdx_id = "BSD" + m2.confidence = 0.8 + m2.match_type = "exact" + m2.matched_name = "BSD" + m2.matched_catalog_url = "u2" + m2.matched_source = "db.license" + mock_resolve.side_effect = [[m1], [m2]] + + # Query stub returning one batch then empty + class QueryStub: + def __init__(self, batches): + self.batches = batches + self.calls = 0 + + def filter(self, *a, **k): + return self + + def order_by(self, *a, **k): + return self + + def limit(self, *a, **k): + return self + + def all(self): + if self.calls < len(self.batches): + res = self.batches[self.calls] + else: + res = [] + self.calls += 1 + return res + + db_session = MagicMock() + db_session.query.return_value = QueryStub([[feed1, feed2], []]) + db_session.flush.return_value = None + db_session.expunge_all.return_value = None + + matches = process_all_feeds( + dry_run=False, only_unmatched=True, db_session=db_session + ) + self.assertEqual(len(matches), 2) + self.assertEqual(feed1.license_id, "MIT") + self.assertEqual(feed2.license_id, "BSD") + + +if __name__ == "__main__": + unittest.main() diff --git a/functions-python/tasks_executor/tests/tasks/populate_licenses_and_rules/test_populate_licenses.py b/functions-python/tasks_executor/tests/tasks/populate_licenses_and_rules/test_populate_licenses.py index 173eed499..35f45b852 100644 --- a/functions-python/tasks_executor/tests/tasks/populate_licenses_and_rules/test_populate_licenses.py +++ b/functions-python/tasks_executor/tests/tasks/populate_licenses_and_rules/test_populate_licenses.py @@ -118,19 +118,20 @@ def filter_side_effect(filter_condition): # Act populate_licenses_task(dry_run=False, db_session=mock_db_session) - # Assert - 
self.assertEqual(mock_db_session.merge.call_count, 2) + # Assert: For two SPDX licenses, since they are new (get returns None), we add them, not merge + self.assertEqual(mock_db_session.add.call_count, 2) + mock_db_session.merge.assert_not_called() mock_db_session.rollback.assert_not_called() - # Check that merge was called with correctly constructed License objects - call_args_list = mock_db_session.merge.call_args_list - merged_licenses = [arg.args[0] for arg in call_args_list] - - mit_license = next((lic for lic in merged_licenses if lic.id == "MIT"), None) + # Inspect the License objects added + added_licenses = [call.args[0] for call in mock_db_session.add.call_args_list] + mit_license = next( + (lic for lic in added_licenses if getattr(lic, "id", None) == "MIT"), None + ) self.assertIsNotNone(mit_license) - self.assertEqual(mit_license.name, "MIT License") - self.assertTrue(mit_license.is_spdx) - self.assertEqual(len(mit_license.rules), 3) + self.assertEqual(getattr(mit_license, "name", None), "MIT License") + self.assertTrue(getattr(mit_license, "is_spdx", False)) + self.assertEqual(len(getattr(mit_license, "rules", [])), 3) @patch("tasks.licenses.populate_licenses.requests.get") def test_populate_licenses_dry_run(self, mock_get): diff --git a/functions-python/tasks_executor/tests/test_main.py b/functions-python/tasks_executor/tests/test_main.py index 1c8e4abad..bf8c87e53 100644 --- a/functions-python/tasks_executor/tests/test_main.py +++ b/functions-python/tasks_executor/tests/test_main.py @@ -23,24 +23,39 @@ class TestTasksExecutor(unittest.TestCase): @staticmethod - def create_mock_request(json_data): + def create_mock_request(json_data, headers=None): mock_request = MagicMock(spec=flask.Request) mock_request.get_json.return_value = json_data + mock_request.headers = headers if headers is not None else {} return mock_request def test_get_task_valid(self): request = TestTasksExecutor.create_mock_request( - {"task": "list_tasks", "payload": {"example": 
"data"}} + {"task": "list_tasks", "payload": {"example": "data"}}, + {"Accept": "application/json"}, ) - task, payload = get_task(request) + task, payload, accept_content_type = get_task(request) self.assertEqual(task, "list_tasks") self.assertEqual(payload, {"example": "data"}) + self.assertEqual(accept_content_type, "application/json") def test_get_task_valid_with_no_payload(self): request = TestTasksExecutor.create_mock_request({"task": "list_tasks"}) - task, payload = get_task(request) + task, payload, accept_content_type = get_task(request) self.assertEqual(task, "list_tasks") self.assertEqual(payload, {}) # Default empty payload + self.assertEqual( + accept_content_type, "application/json" + ) # Default content type + + def test_get_task_valid_with_accept_content_type(self): + request = TestTasksExecutor.create_mock_request( + {"task": "list_tasks"}, {"Accept": "text/csv"} + ) + task, payload, accept_content_type = get_task(request) + self.assertEqual(task, "list_tasks") + self.assertEqual(payload, {}) # Default empty payload + self.assertEqual(accept_content_type, "text/csv") def test_get_task_invalid_json(self): request = TestTasksExecutor.create_mock_request(None) diff --git a/liquibase/changelog.xml b/liquibase/changelog.xml index bd23b71ae..79c9f1201 100644 --- a/liquibase/changelog.xml +++ b/liquibase/changelog.xml @@ -84,4 +84,6 @@ + + diff --git a/liquibase/changes/feat_1433.sql b/liquibase/changes/feat_1433.sql new file mode 100644 index 000000000..bff0c4b6f --- /dev/null +++ b/liquibase/changes/feat_1433.sql @@ -0,0 +1,43 @@ +-- Add the 'license_id' and license_notes columns to the 'feed' table if it doesn't exist +ALTER TABLE feed ADD COLUMN IF NOT EXISTS license_id TEXT; +ALTER TABLE feed ADD COLUMN IF NOT EXISTS license_notes TEXT; + +-- Add a foreign key constraint to reference the 'licenses' table +ALTER TABLE feed + ADD CONSTRAINT fk_feed_license + FOREIGN KEY (license_id) REFERENCES license (id) + ON DELETE SET NULL + NOT VALID; + +ALTER 
TABLE feed VALIDATE CONSTRAINT fk_feed_license; + +-- Audit table for feed license matching changes +CREATE TABLE IF NOT EXISTS feed_license_change ( + id BIGSERIAL PRIMARY KEY, + feed_id VARCHAR(255) NOT NULL, + changed_at TIMESTAMPTZ NOT NULL DEFAULT now(), + feed_license_url TEXT, + matched_license_id TEXT, + confidence DOUBLE PRECISION, + match_type TEXT, + matched_name TEXT, + matched_catalog_url TEXT, + matched_source TEXT, + notes TEXT, + regional_id TEXT, + CONSTRAINT feed_license_change_feed_id_fkey FOREIGN KEY (feed_id) REFERENCES feed(id) ON DELETE CASCADE ON UPDATE NO ACTION, + CONSTRAINT feed_license_change_matched_license_id_fkey FOREIGN KEY (matched_license_id) REFERENCES license(id) ON DELETE SET NULL ON UPDATE NO ACTION +); + +-- Helpful indexes +CREATE INDEX IF NOT EXISTS ix_feed_license_id + ON feed (license_id); + +CREATE INDEX IF NOT EXISTS ix_flc_feed_changed_at + ON feed_license_change (feed_id, changed_at DESC); + +CREATE INDEX IF NOT EXISTS ix_flc_matched_license + ON feed_license_change (matched_license_id); + +CREATE INDEX IF NOT EXISTS ix_flc_match_type + ON feed_license_change (match_type); diff --git a/scripts/docker-localdb-rebuild-data.sh b/scripts/docker-localdb-rebuild-data.sh index bd3a7ee57..2082bab8d 100755 --- a/scripts/docker-localdb-rebuild-data.sh +++ b/scripts/docker-localdb-rebuild-data.sh @@ -105,6 +105,7 @@ if [ "$POPULATE_DB" = true ]; then full_path="$(readlink -f $SCRIPT_PATH/../data/$target_csv_file)" $SCRIPT_PATH/populate-db.sh $full_path printf "\n---------\nCompleted: populating catalog data.\n---------\n" + $SCRIPT_PATH/populate-licenses.sh fi if [ "$POPULATE_TEST_DATA" = true ]; then diff --git a/scripts/populate-licenses.sh b/scripts/populate-licenses.sh new file mode 100755 index 000000000..c558e7dde --- /dev/null +++ b/scripts/populate-licenses.sh @@ -0,0 +1,105 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Populate local database with license rules and licenses using the tasks_executor function code. 
+# This script will: +# 1. Run the license rules population task (creates/updates rules) +# 2. Run the licenses population task (links licenses to rules) +# Both steps can be executed in dry-run mode with --dry-run to verify actions. +# +# Usage: +# ./scripts/populate-licenses.sh # real execution +# ./scripts/populate-licenses.sh --dry-run # simulate without DB writes +# +# Requirements: +# - Local database running (docker compose up ...) +# - FEEDS_DATABASE_URL exported in environment or present in functions-python/tasks_executor/.env.local +# - Network access to GitHub (public repo MobilityData/licenses-aas) + +SCRIPT_DIR="$(cd "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" +FX_NAME="tasks_executor" +FX_PATH="${REPO_ROOT}/functions-python/${FX_NAME}" +FX_SRC_PATH="${FX_PATH}/src" + +DRY_RUN=false +while [[ $# -gt 0 ]]; do + case "$1" in + --dry-run) DRY_RUN=true; shift ;; + -h|--help) + grep '^#' "$0" | sed 's/^# //' | sed '1,2d'; exit 0 ;; + *) echo "Unknown argument: $1" >&2; exit 1 ;; + esac +done + +if [[ ! -d "$FX_SRC_PATH" ]]; then + echo "ERROR: tasks_executor source not found at $FX_SRC_PATH" >&2 + exit 1 +fi + +# Ensure virtualenv (reuse function runner conventions) +if [[ ! 
-d "$FX_PATH/venv" ]]; then + echo "INFO: provisioning virtual environment (first run)" + pushd "$FX_PATH" >/dev/null + pip3 install --disable-pip-version-check virtualenv >/dev/null + python3 -m virtualenv venv >/dev/null + venv/bin/python -m pip install --disable-pip-version-check -r requirements.txt >/dev/null + popd >/dev/null +fi + +# Load local env vars if present (e.g., FEEDS_DATABASE_URL) +if [[ -f "$FX_PATH/.env.local" ]]; then + echo "INFO: Loading env vars from $FX_PATH/.env.local" + set -o allexport + # shellcheck disable=SC1090 + source "$FX_PATH/.env.local" + set +o allexport +fi + +# Also load repository-level config/.env.local for DB settings (preferred) +if [[ -f "$REPO_ROOT/config/.env.local" ]]; then + echo "INFO: Loading env vars from $REPO_ROOT/config/.env.local" + set -o allexport + # shellcheck disable=SC1090 + source "$REPO_ROOT/config/.env.local" + set +o allexport +fi + +# If FEEDS_DATABASE_URL is still not set, attempt to construct it from POSTGRES_* vars +if [[ -z "${FEEDS_DATABASE_URL:-}" ]]; then + if [[ -n "${POSTGRES_USER:-}" && -n "${POSTGRES_PASSWORD:-}" && -n "${POSTGRES_DB:-}" ]]; then + DB_HOST="${POSTGRES_HOST:-localhost}" + DB_PORT="${POSTGRES_PORT:-5432}" + FEEDS_DATABASE_URL="postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${DB_HOST}:${DB_PORT}/${POSTGRES_DB}" + export FEEDS_DATABASE_URL + echo "INFO: Constructed FEEDS_DATABASE_URL from POSTGRES_* variables." + else + echo "WARNING: FEEDS_DATABASE_URL is not set and POSTGRES_* vars are incomplete. The script will likely fail to connect to DB." 
>&2 + fi +fi + +# Log target DB (mask password) +if [[ -n "${FEEDS_DATABASE_URL:-}" ]]; then + MASKED_URL="${FEEDS_DATABASE_URL/:${POSTGRES_PASSWORD:-***}/:***}" + echo "INFO: Using FEEDS_DATABASE_URL=${MASKED_URL}" +fi + +PYTHON_BIN="$FX_PATH/venv/bin/python" +export PYTHONPATH="$FX_SRC_PATH" + +# Convert shell boolean to Python boolean literal +if [[ "$DRY_RUN" == "true" ]]; then PY_DRY=True; else PY_DRY=False; fi + +echo "INFO: Running populate_license_rules (dry_run=${DRY_RUN})" +"$PYTHON_BIN" - <