Skip to content

Commit daf200a

Browse files
committed
Update unit tests. Fix get_stops_from_trip
1 parent b07334f commit daf200a

File tree

6 files changed

+105
-36
lines changed

6 files changed

+105
-36
lines changed

functions-python/helpers/tests/test_transform.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
to_float,
99
get_safe_value,
1010
get_safe_float,
11+
get_safe_int,
1112
)
1213

1314

@@ -154,3 +155,21 @@ def test_default_value(self):
154155
self.assertEqual(get_safe_float(row, "value", default_value=4.56), 4.56)
155156
row = {"value": None}
156157
self.assertEqual(get_safe_float(row, "value", default_value=7.89), 7.89)
158+
159+
160+
class TestGetSafeInt(unittest.TestCase):
    """Unit tests for get_safe_int: valid, invalid, missing, empty and default cases."""

    def test_valid_int(self):
        row = {"value": "42"}
        self.assertEqual(get_safe_int(row, "value"), 42)

    def test_invalid_int(self):
        row = {"value": "abc"}
        self.assertIsNone(get_safe_int(row, "value"))

    def test_missing_key(self):
        row = {}
        self.assertIsNone(get_safe_int(row, "value"))

    def test_empty_string(self):
        row = {"value": ""}
        self.assertIsNone(get_safe_int(row, "value"))

    def test_default_value(self):
        # Mirrors the get_safe_float default-value test: the default must be
        # returned both for an unparsable value and for a None value.
        row = {"value": "abc"}
        self.assertEqual(get_safe_int(row, "value", default_value=7), 7)
        row = {"value": None}
        self.assertEqual(get_safe_int(row, "value", default_value=9), 9)

functions-python/helpers/transform.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def to_float(value, default_value: Optional[float] = None) -> Optional[float]:
8989
return default_value
9090

9191

92-
def get_safe_value(row, column_name, default_value=None) -> Optional[str]:
92+
def get_safe_value(row, column_name, default_value: str = None) -> Optional[str]:
9393
"""
9494
Get a safe value from the row. If the value is missing or empty, return the default value.
9595
"""
@@ -105,12 +105,23 @@ def get_safe_value(row, column_name, default_value=None) -> Optional[str]:
105105
return f"{value}".strip()
106106

107107

108-
def get_safe_float(
    row, column_name, default_value: Optional[float] = None
) -> Optional[float]:
    """
    Get a safe float value from the row. If the value is missing or cannot be converted to float,
    return the default value.

    The raw value is read through get_safe_value (called without a default), then passed to
    float(). A ValueError or TypeError raised by the conversion yields default_value.
    """
    safe_value = get_safe_value(row, column_name)
    try:
        return float(safe_value)
    except (ValueError, TypeError):
        return default_value
117+
118+
119+
def get_safe_int(
    row, column_name, default_value: Optional[int] = None
) -> Optional[int]:
    """
    Get a safe int value from the row. If the value is missing or cannot be converted to int,
    return the default value.

    The raw value is read through get_safe_value (called without a default), then passed to
    int(). A ValueError or TypeError raised by the conversion yields default_value; note that
    int() rejects decimal strings such as "4.2", so those also yield default_value.
    """
    safe_value = get_safe_value(row, column_name)
    try:
        return int(safe_value)
    except (ValueError, TypeError):
        return default_value

functions-python/pmtiles_builder/src/csv_cache.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
AGENCY_FILE = "agency.txt"
3131

3232

33-
class Shapes(TypedDict):
33+
class ShapeTrips(TypedDict):
3434
shape_id: str
3535
trip_ids: List[str]
3636

@@ -56,12 +56,12 @@ def __init__(
5656
self.workdir = workdir
5757

5858
self.file_data = {}
59-
self.trip_to_stops = None
59+
self.trip_to_stops: Dict[str, List[str]] = None
6060
self.route_to_trip = None
61-
self.route_to_shape: Dict[str, Dict[str, List[str]]] = None
61+
self.route_to_shape: Dict[str, Dict[str, ShapeTrips]] = None
6262
self.stop_to_route = None
6363
self.stop_to_coordinates = None
64-
self.trips_no_shapes = Dict[str, List[str]]
64+
self.trips_no_shapes_per_route: Dict[str, List[str]] = {}
6565

6666
self.logger.info("Using work directory: %s", self.workdir)
6767

@@ -111,7 +111,7 @@ def get_trip_from_route(self, route_id):
111111
self.route_to_trip.setdefault(route_id, trip_id)
112112
return self.route_to_trip.get(route_id, "")
113113

114-
def get_shape_from_route(self, route_id) -> str:
114+
def get_shape_from_route(self, route_id) -> List[ShapeTrips]:
115115
"""
116116
Returns the shapes of a given route_id, keyed by shape_id, each with its associated trip_ids, from the trips file.
117117
The relationship from the route to the shape is via the trips file.
@@ -120,38 +120,63 @@ def get_shape_from_route(self, route_id) -> str:
120120
121121
Returns:
122122
A dict mapping each shape_id of the route to its ShapeTrips entry, or an empty list if the route has no shapes.
123-
Example return value: [{'shape_id1': ['trip1', 'trip2']}, {'shape_id2': ['trip3']}]
123+
Example return value: {'shape_id1': {'shape_id': 'shape_id1', 'trip_ids': ['trip1', 'trip2']},
124+
'shape_id2': {'shape_id': 'shape_id2', 'trip_ids': ['trip3']}}
124125
"""
125126
if self.route_to_shape is None:
126127
self.route_to_shape = {}
127128
for row in self.get_file(TRIPS_FILE):
128-
route_id = row["route_id"]
129-
shape_id = row["shape_id"]
130-
trip_id = row["trip_id"]
129+
route_id = get_safe_value(row, "route_id")
130+
shape_id = get_safe_value(row, "shape_id")
131+
trip_id = get_safe_value(row, "trip_id")
131132
if route_id and trip_id:
132133
if shape_id:
133-
route_shapes = self.route_to_shape.get(route_id, {})
134+
route_shapes = self.route_to_shape.get(route_id, None)
135+
if route_shapes is None:
136+
route_shapes = {}
137+
self.route_to_shape[route_id] = route_shapes
134138
if shape_id not in route_shapes:
135-
route_shapes[shape_id] = []
136-
route_shapes[shape_id].append(trip_id)
137-
self.route_to_shape[route_id] = route_shapes
139+
shape_trips = {"shape_id": shape_id, "trip_ids": []}
140+
route_shapes[shape_id] = shape_trips
141+
else:
142+
shape_trips = route_shapes[shape_id]
143+
shape_trips["trip_ids"].append(trip_id)
138144
else:
139-
trip_no_shapes = self.trip_to_stops.get(route_id, [])
145+
# Registering the trip without a shape for this route for later retrieval.
146+
trip_no_shapes = (
147+
self.trips_no_shapes_per_route.get(route_id)
148+
if route_id in self.trips_no_shapes_per_route
149+
else None
150+
)
151+
if trip_no_shapes is None:
152+
trip_no_shapes = []
153+
self.trips_no_shapes_per_route[route_id] = trip_no_shapes
140154
trip_no_shapes.append(trip_id)
141-
self.trips_no_shapes[route_id] = trip_no_shapes
142155
return self.route_to_shape.get(route_id, [])
143156

144157
def get_trips_without_shape_from_route(self, route_id) -> List[str]:
    """
    Return the trip ids recorded for the given route that have no shape.

    Reads self.trips_no_shapes_per_route, which is filled while get_shape_from_route scans
    the trips file. A route with no entry yields an empty list.
    """
    # Single lookup instead of a membership test followed by indexing.
    return self.trips_no_shapes_per_route.get(route_id, [])
148163

149164
def get_stops_from_trip(self, trip_id):
    """
    Return the list of stop ids for the given trip, or an empty list if the trip is unknown.

    The trip -> stops dictionary is built on the first call by reading every row of the
    stop times file, and is reused on later calls. Rows whose trip_id or stop_id is missing
    or empty are skipped. Stops are kept in the order the rows are read; they are not sorted.
    """
    if self.trip_to_stops is None:
        self.trip_to_stops = {}
        for row in self.get_file(STOP_TIMES_FILE):
            # Use names distinct from the trip_id argument: reusing it here would make the
            # final lookup return the stops of the last row read instead of the caller's trip.
            row_trip_id = get_safe_value(row, "trip_id")
            row_stop_id = get_safe_value(row, "stop_id")
            if row_trip_id and row_stop_id:
                self.trip_to_stops.setdefault(row_trip_id, []).append(row_stop_id)
    return self.trip_to_stops.get(trip_id, [])
156181

157182
def get_coordinates_for_stop(self, stop_id) -> tuple[float, float] | None:

functions-python/pmtiles_builder/src/gtfs_stops_to_geojson.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def create_routes_map(routes_data):
1515
"""Creates a dictionary of routes from route data."""
1616
routes = {}
1717
for row in routes_data:
18-
route_id = row.get("route_id")
18+
route_id = get_safe_value(row, "route_id")
1919
if route_id:
2020
routes[route_id] = row
2121
return routes
@@ -26,16 +26,16 @@ def build_stop_to_routes(stop_times_data, trips_data):
2626
# Build trip_id -> route_id mapping
2727
trip_to_route = {}
2828
for row in trips_data:
29-
trip_id = row.get("trip_id")
30-
route_id = row.get("route_id")
29+
trip_id = get_safe_value(row, "trip_id")
30+
route_id = get_safe_value(row, "route_id")
3131
if trip_id and route_id:
3232
trip_to_route[trip_id] = route_id
3333

3434
# Build stop_id -> set of route_ids
3535
stop_to_routes = defaultdict(set)
3636
for row in stop_times_data:
37-
trip_id = row.get("trip_id")
38-
stop_id = row.get("stop_id")
37+
trip_id = get_safe_value(row, "trip_id")
38+
stop_id = get_safe_value(row, "stop_id")
3939
if trip_id and stop_id:
4040
route_id = trip_to_route.get(trip_id)
4141
if route_id:

functions-python/pmtiles_builder/src/main.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,14 @@
3939
STOPS_FILE,
4040
AGENCY_FILE,
4141
SHAPES_FILE,
42+
ShapeTrips,
4243
)
4344
from gtfs_stops_to_geojson import convert_stops_to_geojson
4445
from shared.database_gen.sqlacodegen_models import Gtfsdataset, Gtfsfeed
4546
from shared.helpers.logger import get_logger, init_logger
4647
from shared.helpers.runtime_metrics import track_metrics
4748
from shared.database.database import with_db_session
48-
from shared.helpers.transform import get_safe_value
49+
from shared.helpers.transform import get_safe_value, get_safe_float
4950

5051
init_logger()
5152

@@ -349,7 +350,12 @@ def _create_shapes_index(self) -> dict:
349350
if not line:
350351
break
351352
row = dict(zip(columns, next(csv.reader([line]))))
352-
sid = row["shape_id"]
353+
sid = get_safe_value(row, "shape_id")
354+
if not sid:
355+
self.logger.warning(
356+
"Missing shape_id at line %s, skipping.", row
357+
)
358+
continue
353359
shapes_index.setdefault(sid, []).append(pos)
354360
count += 1
355361
if count % 1000000 == 0:
@@ -371,11 +377,21 @@ def _get_shape_points(self, shape_id, index):
371377
f.seek(pos)
372378
line = f.readline()
373379
row = dict(zip(index["columns"], next(csv.reader([line]))))
380+
shape_pt_lon = get_safe_float(row, "shape_pt_lon")
381+
shape_pt_lat = get_safe_float(row, "shape_pt_lat")
382+
shape_pt_sequence = get_safe_float(row, "shape_pt_sequence", 0)
383+
if shape_pt_lon is None or shape_pt_lat is None:
384+
self.logger.warning(
385+
"Invalid coordinates for shape_id %s at position %d, skipping.",
386+
shape_id,
387+
pos,
388+
)
389+
continue
374390
points.append(
375391
(
376-
float(row["shape_pt_lon"]),
377-
float(row["shape_pt_lat"]),
378-
int(row["shape_pt_sequence"]),
392+
shape_pt_lon,
393+
shape_pt_lat,
394+
shape_pt_sequence,
379395
)
380396
)
381397
points.sort(key=lambda x: x[2])
@@ -444,7 +460,7 @@ def _create_routes_geojson(self):
444460
"Processed route %d (route_id: %s)", i, route_id
445461
)
446462

447-
# geojson_file.write("\n]}")
463+
geojson_file.write("\n]}")
448464

449465
if missing_coordinates_routes:
450466
self.logger.info(
@@ -457,7 +473,7 @@ def _create_routes_geojson(self):
457473
raise Exception(f"Failed to create routes GeoJSON: {e}") from e
458474

459475
def get_route_coordinates(self, route_id, shapes_index) -> List[RouteCoordinates]:
460-
shapes: Dict[str, List[str]] = self.csv_cache.get_shape_from_route(route_id)
476+
shapes: Dict[str, ShapeTrips] = self.csv_cache.get_shape_from_route(route_id)
461477
result: List[RouteCoordinates] = []
462478
if shapes:
463479
for shape_id, trip_ids in shapes.items():
@@ -545,7 +561,6 @@ def _create_routes_json(self):
545561
routes = []
546562
for row in self.csv_cache.get_file(ROUTES_FILE):
547563
route_id = get_safe_value(row, "route_id", "")
548-
shape_ids = self.csv_cache.get_shape_from_route(route_id)
549564
route = {
550565
"routeId": route_id,
551566
"routeName": get_safe_value(row, "route_long_name", "")
@@ -554,7 +569,6 @@ def _create_routes_json(self):
554569
"color": f"#{get_safe_value(row, 'route_color', '000000')}",
555570
"textColor": f"#{get_safe_value(row, 'route_text_color', 'FFFFFF')}",
556571
"routeType": f"{get_safe_value(row, 'route_type', '')}",
557-
"shapes_ids": shape_ids,
558572
}
559573
routes.append(route)
560574

functions-python/pmtiles_builder/src/scripts/pmtiles_builder_verifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def download_feed_files(feed_stable_id: str, dataset_stable_id: str):
3434
logging.info(f"Downloading {url}")
3535
filename = f"{dataset_stable_id}/extracted/{file}"
3636
try:
37-
download_to_local(feed_stable_id, url, filename, True)
37+
download_to_local(feed_stable_id, url, filename, False)
3838
except Exception as e:
3939
logging.warning(f"Failed to download {file}: {e}")
4040

0 commit comments

Comments
 (0)