Skip to content

Commit 68e505f

Browse files
authored
fix: pmtiles error when missing latitude and longitude are missing on stops.txt (#1364)
1 parent d1843a6 commit 68e505f

File tree

13 files changed

+386
-101
lines changed

13 files changed

+386
-101
lines changed

functions-python/helpers/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,5 @@ google-cloud-bigquery
2828

2929
# Additional package
3030
pycountry
31-
shapely
31+
shapely
32+
pandas

functions-python/helpers/tests/test_transform.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
1-
from transform import to_boolean, get_nested_value
1+
import unittest
2+
3+
import pandas as pd
4+
5+
from transform import (
6+
to_boolean,
7+
get_nested_value,
8+
to_float,
9+
get_safe_value,
10+
get_safe_float,
11+
)
212

313

414
def test_to_boolean():
@@ -60,3 +70,87 @@ def test_get_nested_value():
6070
# Test case 9: Non-dictionary data
6171
assert get_nested_value("not a dict", ["a", "b", "c"]) is None
6272
assert get_nested_value("not a dict", ["a", "b", "c"], []) == []
73+
74+
75+
class TestToFloat(unittest.TestCase):
76+
def test_valid_float(self):
77+
self.assertEqual(to_float("3.14"), 3.14)
78+
self.assertEqual(to_float(2.5), 2.5)
79+
self.assertEqual(to_float("0"), 0.0)
80+
self.assertEqual(to_float(0), 0.0)
81+
82+
def test_invalid_float(self):
83+
self.assertIsNone(to_float("abc"))
84+
self.assertIsNone(to_float(None))
85+
self.assertIsNone(to_float(""))
86+
87+
def test_default_value(self):
88+
self.assertEqual(to_float("abc", default_value=1.23), 1.23)
89+
self.assertEqual(to_float(None, default_value=4.56), 4.56)
90+
self.assertEqual(to_float("", default_value=7.89), 7.89)
91+
92+
93+
class TestGetSafeValue(unittest.TestCase):
94+
def test_valid_value(self):
95+
row = {"name": " Alice "}
96+
self.assertEqual(get_safe_value(row, "name"), "Alice")
97+
98+
def test_missing_column(self):
99+
row = {"age": 30}
100+
self.assertIsNone(get_safe_value(row, "name"))
101+
102+
def test_empty_string(self):
103+
row = {"name": " "}
104+
self.assertIsNone(get_safe_value(row, "name"))
105+
106+
def test_nan_value(self):
107+
row = {"name": pd.NA}
108+
self.assertIsNone(get_safe_value(row, "name"))
109+
row = {"name": float("nan")}
110+
self.assertIsNone(get_safe_value(row, "name"))
111+
112+
def test_default_value(self):
113+
row = {"name": ""}
114+
self.assertEqual(
115+
get_safe_value(row, "name", default_value="default"), "default"
116+
)
117+
118+
119+
class TestGetSafeFloat(unittest.TestCase):
120+
def test_valid_float(self):
121+
row = {"value": "3.14"}
122+
self.assertEqual(get_safe_float(row, "value"), 3.14)
123+
row = {"value": 2.5}
124+
self.assertEqual(get_safe_float(row, "value"), 2.5)
125+
row = {"value": "0"}
126+
self.assertEqual(get_safe_float(row, "value"), 0.0)
127+
row = {"value": 0}
128+
self.assertEqual(get_safe_float(row, "value"), 0.0)
129+
130+
def test_missing_column(self):
131+
row = {"other": 1.23}
132+
self.assertIsNone(get_safe_float(row, "value"))
133+
134+
def test_empty_string(self):
135+
row = {"value": " "}
136+
self.assertIsNone(get_safe_float(row, "value"))
137+
138+
def test_nan_value(self):
139+
row = {"value": pd.NA}
140+
self.assertIsNone(get_safe_float(row, "value"))
141+
row = {"value": float("nan")}
142+
self.assertIsNone(get_safe_float(row, "value"))
143+
144+
def test_invalid_float(self):
145+
row = {"value": "abc"}
146+
self.assertIsNone(get_safe_float(row, "value"))
147+
row = {"value": None}
148+
self.assertIsNone(get_safe_float(row, "value"))
149+
150+
def test_default_value(self):
151+
row = {"value": ""}
152+
self.assertEqual(get_safe_float(row, "value", default_value=1.23), 1.23)
153+
row = {"value": "abc"}
154+
self.assertEqual(get_safe_float(row, "value", default_value=4.56), 4.56)
155+
row = {"value": None}
156+
self.assertEqual(get_safe_float(row, "value", default_value=7.89), 7.89)

functions-python/helpers/transform.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,40 @@ def to_enum(value, enum_class=None, default_value=None):
7777
except (ValueError, TypeError) as e:
7878
logging.warning("Failed to convert value to enum member: %s", e)
7979
return default_value
80+
81+
82+
def to_float(value, default_value: Optional[float] = None) -> Optional[float]:
83+
"""
84+
Convert a value to a float. If conversion fails, return the default value.
85+
"""
86+
try:
87+
return float(value)
88+
except (ValueError, TypeError):
89+
return default_value
90+
91+
92+
def get_safe_value(row, column_name, default_value=None) -> Optional[str]:
93+
"""
94+
Get a safe value from the row. If the value is missing or empty, return the default value.
95+
"""
96+
import pandas
97+
98+
value = row.get(column_name, None)
99+
if (
100+
value is None
101+
or pandas.isna(value)
102+
or (isinstance(value, str) and value.strip() == "")
103+
):
104+
return default_value
105+
return f"{value}".strip()
106+
107+
108+
def get_safe_float(row, column_name, default_value=None) -> Optional[float]:
109+
"""
110+
Get a safe float value from the row. If the value is missing or cannot be converted to float,
111+
"""
112+
safe_value = get_safe_value(row, column_name)
113+
try:
114+
return float(safe_value)
115+
except (ValueError, TypeError):
116+
return default_value
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import logging
2+
import os
3+
import socket
4+
import subprocess
5+
from typing import Dict
6+
import uuid
7+
from io import BytesIO
8+
9+
import requests
10+
11+
from shared.database.database import with_db_session
12+
from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gbfsfeed
13+
from shared.helpers.runtime_metrics import track_metrics
14+
15+
from google.cloud import storage
16+
from sqlalchemy.orm import Session
17+
18+
19+
EMULATOR_STORAGE_BUCKET_NAME = "verifier"
20+
EMULATOR_HOST = "localhost"
21+
EMULATOR_STORAGE_PORT = 9023
22+
23+
24+
@track_metrics(metrics=("time", "memory", "cpu"))
25+
def download_to_local(
26+
feed_stable_id: str, url: str, filename: str, force_download: bool = False
27+
):
28+
"""
29+
Download a file from a URL and upload it to the Google Cloud Storage emulator.
30+
If the file already exists, it will not be downloaded again.
31+
Args:
32+
url (str): The URL to download the file from.
33+
filename (str): The name of the file to save in the emulator.
34+
"""
35+
if not url:
36+
return
37+
blob_path = f"{feed_stable_id}/{filename}"
38+
client = storage.Client()
39+
bucket = client.bucket(EMULATOR_STORAGE_BUCKET_NAME)
40+
blob = bucket.blob(blob_path)
41+
42+
# Check if the blob already exists in the emulator
43+
if not blob.exists() or force_download:
44+
logging.info(f"Downloading and uploading: {blob_path}")
45+
with requests.get(url, stream=True) as response:
46+
response.raise_for_status()
47+
blob.content_type = "application/json"
48+
# The file is downloaded into memory before uploading to ensure it's seekable.
49+
# Be careful with large files.
50+
data = BytesIO(response.content)
51+
blob.upload_from_file(data, rewind=True)
52+
else:
53+
logging.info(
54+
f"Blob already exists: gs://{EMULATOR_STORAGE_BUCKET_NAME}/{blob_path}"
55+
)
56+
57+
58+
@with_db_session
59+
def create_test_data(feed_stable_id: str, feed_dict: Dict, db_session: Session = None):
60+
"""
61+
Create test data in the database if it does not exist.
62+
This function is used to ensure that the reverse geolocation process has the necessary data to work with.
63+
"""
64+
# Here you would typically interact with your database to create the necessary test data
65+
# For this example, we will just log the action
66+
logging.info(f"Creating test data for {feed_stable_id} with data: {feed_dict}")
67+
model = Gtfsfeed if feed_dict["data_type"] == "gtfs" else Gbfsfeed
68+
local_feed = (
69+
db_session.query(model).filter(model.stable_id == feed_stable_id).one_or_none()
70+
)
71+
if not local_feed:
72+
local_feed = model(
73+
id=uuid.uuid4(),
74+
stable_id=feed_stable_id,
75+
data_type=feed_dict["data_type"],
76+
feed_name="Test Feed",
77+
note="This is a test feed created for reverse geolocation verification.",
78+
producer_url="https://files.mobilitydatabase.org/mdb-2014/mdb-2014-202508120303/mdb-2014-202508120303.zip",
79+
authentication_type="0",
80+
status="active",
81+
)
82+
db_session.add(local_feed)
83+
db_session.commit()
84+
85+
86+
def setup_local_storage_emulator():
87+
"""
88+
Setup the Google Cloud Storage emulator by creating the necessary bucket.
89+
"""
90+
from gcp_storage_emulator.server import create_server
91+
92+
os.environ[
93+
"STORAGE_EMULATOR_HOST"
94+
] = f"http://{EMULATOR_HOST}:{EMULATOR_STORAGE_PORT}"
95+
os.environ["DATASETS_BUCKET_NAME_GBFS"] = EMULATOR_STORAGE_BUCKET_NAME
96+
os.environ["DATASETS_BUCKET_NAME_GTFS"] = EMULATOR_STORAGE_BUCKET_NAME
97+
os.environ["DATASTORE_EMULATOR_HOST"] = "localhost:8081"
98+
server = create_server(
99+
host=EMULATOR_HOST,
100+
port=EMULATOR_STORAGE_PORT,
101+
in_memory=False,
102+
default_bucket=EMULATOR_STORAGE_BUCKET_NAME,
103+
)
104+
server.start()
105+
return server
106+
107+
108+
def shutdown_local_storage_emulator(server):
109+
"""Shutdown the Google Cloud Storage emulator."""
110+
server.stop()
111+
112+
113+
def is_datastore_emulator_running(host=EMULATOR_HOST, port=8081):
114+
"""Check if the Google Cloud Datastore emulator is running."""
115+
try:
116+
with socket.create_connection((host, port), timeout=2):
117+
return True
118+
except OSError:
119+
return False
120+
121+
122+
def start_datastore_emulator(project_id="test-project"):
123+
"""Start the Google Cloud Datastore emulator if it's not already running."""
124+
if not is_datastore_emulator_running():
125+
process = subprocess.Popen(
126+
[
127+
"gcloud",
128+
"beta",
129+
"emulators",
130+
"datastore",
131+
"start",
132+
"--project={}".format(project_id),
133+
"--host-port=localhost:8081",
134+
]
135+
)
136+
return process
137+
return None # Already running
138+
139+
140+
def shutdown_datastore_emulator(process):
141+
"""Shutdown the Google Cloud Datastore emulator."""
142+
if process:
143+
process.terminate()
144+
process.wait()

functions-python/pmtiles_builder/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,5 @@ google-cloud-storage
2525
python-dotenv==1.0.0
2626
tippecanoe
2727
psutil
28+
pandas
2829

functions-python/pmtiles_builder/src/csv_cache.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
#
1616
import csv
1717
import os
18-
from shared.helpers.logger import get_logger
1918

19+
from gtfs import stop_txt_is_lat_log_required
20+
from shared.helpers.logger import get_logger
21+
from shared.helpers.transform import get_safe_value, get_safe_float
2022

2123
STOP_TIMES_FILE = "stop_times.txt"
2224
SHAPES_FILE = "shapes.txt"
@@ -127,10 +129,26 @@ def get_stops_from_trip(self, trip_id):
127129

128130
def get_coordinates_for_stop(self, stop_id) -> tuple[float, float] | None:
129131
if self.stop_to_coordinates is None:
130-
self.stop_to_coordinates = {
131-
s["stop_id"]: (float(s["stop_lon"]), float(s["stop_lat"]))
132-
for s in self.get_file(STOPS_FILE)
133-
}
132+
self.stop_to_coordinates = {}
133+
for s in self.get_file(STOPS_FILE):
134+
self.stop_to_coordinates.get(stop_id, [])
135+
row_stop_id = get_safe_value(s, "stop_id")
136+
row_stop_lon = get_safe_float(s, "stop_lon")
137+
row_stop_lat = get_safe_float(s, "stop_lat")
138+
if row_stop_id is None:
139+
self.logger.warning("Missing stop id: %s", s)
140+
continue
141+
if row_stop_lon is None or row_stop_lat is None:
142+
if stop_txt_is_lat_log_required(s):
143+
self.logger.warning(
144+
"Missing stop latitude and longitude : %s", s
145+
)
146+
else:
147+
self.logger.debug(
148+
"Missing optional stop latitude and longitude : %s", s
149+
)
150+
continue
151+
self.stop_to_coordinates[row_stop_id] = (row_stop_lon, row_stop_lat)
134152
return self.stop_to_coordinates.get(stop_id, None)
135153

136154
def set_workdir(self, workdir):
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from shared.helpers.transform import get_safe_value
2+
3+
# TODO: Move this file to a shared folder
4+
5+
6+
def stop_txt_is_lat_log_required(stop_row):
7+
"""
8+
Conditionally Required:
9+
- Required for locations which are stops (location_type=0), stations (location_type=1)
10+
or entrances/exits (location_type=2).
11+
- Optional for locations which are generic nodes (location_type=3) or boarding areas (location_type=4).
12+
13+
Args:
14+
row (dict): The data row to check.
15+
16+
Returns:
17+
bool: True if both latitude and longitude is required, False otherwise.
18+
"""
19+
location_type = get_safe_value(stop_row, "location_type", "0")
20+
return location_type in ("0", "1", "2")

functions-python/pmtiles_builder/src/gtfs_stops_to_geojson.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from collections import defaultdict
55

66
from csv_cache import CsvCache, ROUTES_FILE, TRIPS_FILE, STOP_TIMES_FILE, STOPS_FILE
7+
from gtfs import stop_txt_is_lat_log_required
78
from shared.helpers.runtime_metrics import track_metrics
9+
from shared.helpers.transform import get_safe_float
810

911
logger = logging.getLogger(__name__)
1012

@@ -60,10 +62,13 @@ def convert_stops_to_geojson(csv_cache: CsvCache, output_file):
6062
if (
6163
"stop_lat" not in row
6264
or "stop_lon" not in row
63-
or not row["stop_lat"]
64-
or not row["stop_lon"]
65+
or get_safe_float(row, "stop_lat") is None
66+
or get_safe_float(row, "stop_lon") is None
6567
):
66-
logger.warning(f"Missing coordinates for stop_id {stop_id}, skipping.")
68+
if stop_txt_is_lat_log_required(row):
69+
logger.warning(
70+
"Missing coordinates for stop_id {%s}, skipping.", stop_id
71+
)
6772
continue
6873

6974
# Routes serving this stop

0 commit comments

Comments
 (0)