Skip to content

Commit a552071

Browse files
authored
feat[pmtiles]: support multiple shapes per routes (#1365)
1 parent e6a721b commit a552071

File tree

7 files changed

+310
-100
lines changed

7 files changed

+310
-100
lines changed

functions-python/helpers/tests/test_transform.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
to_float,
99
get_safe_value,
1010
get_safe_float,
11+
get_safe_int,
1112
)
1213

1314

@@ -154,3 +155,21 @@ def test_default_value(self):
154155
self.assertEqual(get_safe_float(row, "value", default_value=4.56), 4.56)
155156
row = {"value": None}
156157
self.assertEqual(get_safe_float(row, "value", default_value=7.89), 7.89)
158+
159+
160+
class TestGetSafeInt(unittest.TestCase):
    """Unit tests for get_safe_int: successful conversion, failed conversion and default fallback."""

    def test_valid_int(self):
        row = {"value": "42"}
        self.assertEqual(get_safe_int(row, "value"), 42)

    def test_invalid_int(self):
        row = {"value": "abc"}
        self.assertIsNone(get_safe_int(row, "value"))

    def test_missing_key(self):
        row = {}
        self.assertIsNone(get_safe_int(row, "value"))

    def test_empty_string(self):
        row = {"value": ""}
        self.assertIsNone(get_safe_int(row, "value"))

    def test_non_integer_number(self):
        # int() rejects decimal strings, so the value counts as not convertible.
        row = {"value": "4.2"}
        self.assertIsNone(get_safe_int(row, "value"))

    def test_default_value(self):
        # Mirrors the float tests: the default is returned whenever conversion fails.
        row = {"value": "abc"}
        self.assertEqual(get_safe_int(row, "value", default_value=7), 7)
        row = {"value": None}
        self.assertEqual(get_safe_int(row, "value", default_value=9), 9)

functions-python/helpers/transform.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def to_float(value, default_value: Optional[float] = None) -> Optional[float]:
8989
return default_value
9090

9191

92-
def get_safe_value(row, column_name, default_value=None) -> Optional[str]:
92+
def get_safe_value(row, column_name, default_value: str = None) -> Optional[str]:
9393
"""
9494
Get a safe value from the row. If the value is missing or empty, return the default value.
9595
"""
@@ -105,12 +105,23 @@ def get_safe_value(row, column_name, default_value=None) -> Optional[str]:
105105
return f"{value}".strip()
106106

107107

108-
def get_safe_float(row, column_name, default_value=None) -> Optional[float]:
108+
def get_safe_float(
    row, column_name, default_value: Optional[float] = None
) -> Optional[float]:
    """
    Get a safe float value from the row. If the value is missing or cannot be converted to float,
    return the default value.

    Parameters:
        row: The row to read from; it is passed as-is to get_safe_value.
        column_name: The name of the column to read.
        default_value: The value returned when the conversion fails. Defaults to None.

    Returns:
        The column value converted with float(), or default_value.
    """
    safe_value = get_safe_value(row, column_name)
    try:
        return float(safe_value)
    except (ValueError, TypeError):
        # ValueError: the text is not a number. TypeError: there was no value to convert (None).
        return default_value
117+
118+
119+
def get_safe_int(
    row, column_name, default_value: Optional[int] = None
) -> Optional[int]:
    """
    Get a safe int value from the row. If the value is missing or cannot be converted to int,
    return the default value.

    Parameters:
        row: The row to read from; it is passed as-is to get_safe_value.
        column_name: The name of the column to read.
        default_value: The value returned when the conversion fails. Defaults to None.

    Returns:
        The column value converted with int(), or default_value. Decimal strings such as "4.2"
        are not accepted by int() and therefore yield default_value.
    """
    safe_value = get_safe_value(row, column_name)
    try:
        return int(safe_value)
    except (ValueError, TypeError):
        # ValueError: the text is not an integer. TypeError: there was no value to convert (None).
        return default_value
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
Faker
22
pytest~=7.4.3
33
urllib3-mock
4-
requests-mock
4+
requests-mock
5+
psutil
6+
gcp_storage_emulator

functions-python/pmtiles_builder/src/csv_cache.py

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#
1616
import csv
1717
import os
18+
from typing import TypedDict, List, Dict
19+
1820

1921
from gtfs import stop_txt_is_lat_log_required
2022
from shared.helpers.logger import get_logger
@@ -28,6 +30,11 @@
2830
AGENCY_FILE = "agency.txt"
2931

3032

33+
class ShapeTrips(TypedDict):
    """A shape identifier paired with the identifiers of the trips that reference it."""

    # Identifier of the shape; the same value is used as the key of the per-route mapping.
    shape_id: str
    # Identifiers of the trips whose trips-file row carries this shape_id.
    trip_ids: List[str]
36+
37+
3138
class CsvCache:
3239
"""
3340
CsvCache provides cached access to GTFS CSV files in a specified working directory.
@@ -49,11 +56,12 @@ def __init__(
4956
self.workdir = workdir
5057

5158
self.file_data = {}
52-
self.trip_to_stops = None
59+
self.trip_to_stops: Dict[str, List[str]] = None
5360
self.route_to_trip = None
54-
self.route_to_shape = None
61+
self.route_to_shape: Dict[str, Dict[str, ShapeTrips]] = None
5562
self.stop_to_route = None
5663
self.stop_to_coordinates = None
64+
self.trips_no_shapes_per_route: Dict[str, List[str]] = {}
5765

5866
self.logger.info("Using work directory: %s", self.workdir)
5967

@@ -90,41 +98,63 @@ def _read_csv(self, filename) -> list[dict]:
9098
except Exception as e:
9199
raise Exception(f"Failed to read CSV file {filename}: {e}") from e
92100

93-
def get_trip_from_route(self, route_id):
94-
if self.route_to_trip is None:
95-
self.route_to_trip = {}
96-
for row in self.get_file(TRIPS_FILE):
97-
route_id = row["route_id"]
98-
trip_id = row["trip_id"]
99-
if trip_id:
100-
self.route_to_trip.setdefault(route_id, trip_id)
101-
return self.route_to_trip.get(route_id, "")
102-
103-
def get_shape_from_route(self, route_id) -> str:
101+
def get_shape_from_route(self, route_id) -> Dict[str, ShapeTrips]:
    """
    Returns the shapes associated with a given route_id, each with the trip_ids that use it.
    The relationship from the route to the shape is via the trips file.

    The first call reads the trips file once and indexes every route. During that same pass,
    trips that have no shape_id are recorded per route in self.trips_no_shapes_per_route.

    Parameters:
        route_id (str): The route identifier to look up.

    Returns:
        A dictionary keyed by shape_id whose values are ShapeTrips entries, or an empty
        dictionary when no shape is known for the route.
        Example return value:
            {'shape_id1': {'shape_id': 'shape_id1', 'trip_ids': ['trip1', 'trip2']},
             'shape_id2': {'shape_id': 'shape_id2', 'trip_ids': ['trip3']}}
    """
    if self.route_to_shape is None:
        self.route_to_shape = {}
        for row in self.get_file(TRIPS_FILE):
            # Row-level names are prefixed so they never rebind the route_id parameter, which is
            # still needed for the lookup at the end of this method.
            row_route_id = get_safe_value(row, "route_id")
            row_shape_id = get_safe_value(row, "shape_id")
            row_trip_id = get_safe_value(row, "trip_id")
            if not (row_route_id and row_trip_id):
                continue
            if row_shape_id:
                route_shapes = self.route_to_shape.setdefault(row_route_id, {})
                shape_trips = route_shapes.setdefault(
                    row_shape_id, {"shape_id": row_shape_id, "trip_ids": []}
                )
                shape_trips["trip_ids"].append(row_trip_id)
            else:
                # Registering the trip without a shape for this route for later retrieval.
                self.trips_no_shapes_per_route.setdefault(row_route_id, []).append(
                    row_trip_id
                )
    return self.route_to_shape.get(route_id, {})
138+
139+
def get_trips_without_shape_from_route(self, route_id) -> List[str]:
    """
    Returns the trip_ids of the given route that have no shape_id in the trips file.

    Parameters:
        route_id (str): The route identifier to look up.

    Returns:
        The list of trip_ids without a shape, or an empty list when there are none.
    """
    if self.route_to_shape is None:
        # self.trips_no_shapes_per_route is only filled while get_shape_from_route indexes the
        # trips file. Build that index on demand so the result does not depend on call order.
        self.get_shape_from_route(route_id)
    return self.trips_no_shapes_per_route.get(route_id, [])
121141

122142
def get_stops_from_trip(self, trip_id):
    """
    Returns the stop_ids listed for the given trip_id in the stop times file.

    The first call reads the stop times file once and indexes every trip; later calls are
    answered from that index. Rows with a missing trip_id or stop_id are ignored.

    Parameters:
        trip_id (str): The trip identifier to look up.

    Returns:
        The list of stop_ids in file order, or an empty list when the trip is unknown.
    """
    # Lazy instantiation of the dictionary, because it may never be requested.
    if self.trip_to_stops is None:
        self.trip_to_stops = {}
        for row in self.get_file(STOP_TIMES_FILE):
            # Row-level names are prefixed so they never rebind the trip_id parameter, which is
            # still needed for the lookup at the end of this method.
            row_trip_id = get_safe_value(row, "trip_id")
            row_stop_id = get_safe_value(row, "stop_id")
            if row_trip_id and row_stop_id:
                self.trip_to_stops.setdefault(row_trip_id, []).append(row_stop_id)
    return self.trip_to_stops.get(trip_id, [])
129159

130160
def get_coordinates_for_stop(self, stop_id) -> tuple[float, float] | None:

functions-python/pmtiles_builder/src/gtfs_stops_to_geojson.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from csv_cache import CsvCache, ROUTES_FILE, TRIPS_FILE, STOP_TIMES_FILE, STOPS_FILE
77
from gtfs import stop_txt_is_lat_log_required
88
from shared.helpers.runtime_metrics import track_metrics
9-
from shared.helpers.transform import get_safe_float
9+
from shared.helpers.transform import get_safe_float, get_safe_value
1010

1111
logger = logging.getLogger(__name__)
1212

@@ -15,7 +15,7 @@ def create_routes_map(routes_data):
1515
"""Creates a dictionary of routes from route data."""
1616
routes = {}
1717
for row in routes_data:
18-
route_id = row.get("route_id")
18+
route_id = get_safe_value(row, "route_id")
1919
if route_id:
2020
routes[route_id] = row
2121
return routes
@@ -26,16 +26,16 @@ def build_stop_to_routes(stop_times_data, trips_data):
2626
# Build trip_id -> route_id mapping
2727
trip_to_route = {}
2828
for row in trips_data:
29-
trip_id = row.get("trip_id")
30-
route_id = row.get("route_id")
29+
trip_id = get_safe_value(row, "trip_id")
30+
route_id = get_safe_value(row, "route_id")
3131
if trip_id and route_id:
3232
trip_to_route[trip_id] = route_id
3333

3434
# Build stop_id -> set of route_ids
3535
stop_to_routes = defaultdict(set)
3636
for row in stop_times_data:
37-
trip_id = row.get("trip_id")
38-
stop_id = row.get("stop_id")
37+
trip_id = get_safe_value(row, "trip_id")
38+
stop_id = get_safe_value(row, "stop_id")
3939
if trip_id and stop_id:
4040
route_id = trip_to_route.get(trip_id)
4141
if route_id:
@@ -92,13 +92,13 @@ def convert_stops_to_geojson(csv_cache: CsvCache, output_file):
9292
},
9393
"properties": {
9494
"stop_id": stop_id,
95-
"stop_code": row.get("stop_code", ""),
96-
"stop_name": row.get("stop_name", ""),
97-
"stop_desc": row.get("stop_desc", ""),
98-
"zone_id": row.get("zone_id", ""),
99-
"stop_url": row.get("stop_url", ""),
100-
"wheelchair_boarding": row.get("wheelchair_boarding", ""),
101-
"location_type": row.get("location_type", ""),
95+
"stop_code": get_safe_value(row, "stop_code", ""),
96+
"stop_name": get_safe_value(row, "stop_name", ""),
97+
"stop_desc": get_safe_value(row, "stop_desc", ""),
98+
"zone_id": get_safe_value(row, "zone_id", ""),
99+
"stop_url": get_safe_value(row, "stop_url", ""),
100+
"wheelchair_boarding": get_safe_value(row, "wheelchair_boarding", ""),
101+
"location_type": get_safe_value(row, "location_type", ""),
102102
"route_ids": route_ids,
103103
"route_colors": route_colors,
104104
},

0 commit comments

Comments
 (0)