Skip to content

Commit 6d06ca8

Browse files
authored
fix: avoid exception when country code is not found (#1074)
1 parent f7fb396 commit 6d06ca8

File tree

3 files changed

+38
-8
lines changed

3 files changed

+38
-8
lines changed

api/src/scripts/populate_db_gtfs.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import os
22
import traceback
3-
from typing import TYPE_CHECKING
43
from datetime import datetime
4+
from typing import TYPE_CHECKING
55

66
import pycountry
77
import pytz
88
from sqlalchemy import text
99

10+
from scripts.load_dataset_on_create import publish_all
11+
from scripts.populate_db import DatabasePopulateHelper, set_up_configs
1012
from shared.database.database import generate_unique_id, configure_polymorphic_mappers
1113
from shared.database_gen.sqlacodegen_models import (
1214
Entitytype,
@@ -16,8 +18,6 @@
1618
Redirectingid,
1719
t_feedsearch,
1820
)
19-
from scripts.populate_db import DatabasePopulateHelper, set_up_configs
20-
from scripts.load_dataset_on_create import publish_all
2121
from utils.data_utils import set_up_defaults
2222

2323
if TYPE_CHECKING:
@@ -61,9 +61,11 @@ def get_stable_id(self, row):
6161
return f'mdb-{self.get_safe_value(row, "mdb_source_id", "")}'
6262

6363
def get_country(self, country_code):
64+
country = None
6465
if country_code:
65-
return pycountry.countries.get(alpha_2=country_code).name
66-
return None
66+
country = pycountry.countries.get(alpha_2=country_code)
67+
country = country.name if country else None
68+
return country
6769

6870
def populate_location(self, session, feed, row, stable_id):
6971
"""
@@ -88,7 +90,10 @@ def populate_location(self, session, feed, row, stable_id):
8890
if location
8991
else Location(
9092
id=location_id,
91-
country_code=country_code,
93+
# Country code should be short.
94+
# If too long it might be an error
95+
# (like it could be the country name instead of code).
96+
country_code=country_code if country_code and len(country_code) <= 3 else None,
9297
subdivision_name=subdivision_name,
9398
municipality=municipality,
9499
country=country,

functions-python/reverse_geolocation/src/location_group_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ def __init__(self, location_group: Osmlocationgroup, stops_count: int):
7575

7676
def country(self) -> str:
7777
"""Returns the country name of the LocationGroup."""
78-
return pycountry.countries.get(alpha_2=self.iso_3166_1_code).name
78+
country = pycountry.countries.get(alpha_2=self.iso_3166_1_code)
79+
return country.name if country else None
7980

8081
def location_id(self) -> str:
8182
"""Returns the location ID of the LocationGroup."""

functions-python/reverse_geolocation/tests/test_location_group_utils.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
import unittest
33
from unittest.mock import patch, MagicMock
44

5+
import pytest
56
from faker import Faker
67
from geoalchemy2 import WKTElement
78

9+
from location_group_utils import GeopolygonAggregate
810
from shared.database_gen.sqlacodegen_models import Geopolygon, Osmlocationgroup
911

10-
1112
faker = Faker()
1213

1314

@@ -91,3 +92,26 @@ def test_geopolygon_aggregate(self):
9192
geopolygon_aggregate_2 = GeopolygonAggregate(location_group, 1)
9293
geopolygon_aggregate.merge(geopolygon_aggregate_2)
9394
self.assertEqual(geopolygon_aggregate.stop_count, 2)
95+
96+
97+
@pytest.mark.parametrize(
98+
"values",
99+
[{"value": "CA", "expected": "Canada"}, {"value": "Canada", "expected": None}],
100+
)
101+
def test_location_country(values):
102+
"""
103+
Test the country function with cases with valid and invalid ISO 3166_1 code
104+
"""
105+
geopolygon = Geopolygon(
106+
osm_id=1,
107+
admin_level=2,
108+
name=values.get("value"),
109+
iso_3166_1_code=values.get("value"),
110+
geometry=WKTElement("POINT(-73.5673 45.5017)", srid=4326),
111+
)
112+
location_group = Osmlocationgroup(
113+
group_id="1.1.1", group_name="Canada, Ontario", osms=[geopolygon]
114+
)
115+
geopolygon_aggregate = GeopolygonAggregate(location_group, 1)
116+
117+
assert geopolygon_aggregate.country() == values.get("expected")

0 commit comments

Comments
 (0)