1515#
1616import csv
1717import os
18+ from typing import TypedDict , List , Dict
19+
1820
1921from gtfs import stop_txt_is_lat_log_required
2022from shared .helpers .logger import get_logger
2830AGENCY_FILE = "agency.txt"
2931
3032
33+ class ShapeTrips (TypedDict ):
34+ shape_id : str
35+ trip_ids : List [str ]
36+
37+
3138class CsvCache :
3239 """
3340 CsvCache provides cached access to GTFS CSV files in a specified working directory.
@@ -49,11 +56,12 @@ def __init__(
4956 self .workdir = workdir
5057
5158 self .file_data = {}
52- self .trip_to_stops = None
59+ self .trip_to_stops : Dict [ str , List [ str ]] = None
5360 self .route_to_trip = None
54- self .route_to_shape = None
61+ self .route_to_shape : Dict [ str , Dict [ str , ShapeTrips ]] = None
5562 self .stop_to_route = None
5663 self .stop_to_coordinates = None
64+ self .trips_no_shapes_per_route : Dict [str , List [str ]] = {}
5765
5866 self .logger .info ("Using work directory: %s" , self .workdir )
5967
@@ -90,41 +98,63 @@ def _read_csv(self, filename) -> list[dict]:
9098 except Exception as e :
9199 raise Exception (f"Failed to read CSV file { filename } : { e } " ) from e
92100
93- def get_trip_from_route (self , route_id ):
94- if self .route_to_trip is None :
95- self .route_to_trip = {}
96- for row in self .get_file (TRIPS_FILE ):
97- route_id = row ["route_id" ]
98- trip_id = row ["trip_id" ]
99- if trip_id :
100- self .route_to_trip .setdefault (route_id , trip_id )
101- return self .route_to_trip .get (route_id , "" )
102-
103- def get_shape_from_route (self , route_id ) -> str :
101+ def get_shape_from_route (self , route_id ) -> Dict [str , List [ShapeTrips ]]:
104102 """
105- Returns the first shape_id associated with a given route_id from the trips file.
103+ Returns a list of shape_ids with associated trip_ids information with a given route_id from the trips file.
106104 The relationship from the route to the shape is via the trips file.
107105 Parameters:
108106 route_id (str): The route identifier to look up.
109107
110108 Returns:
111109 The corresponding shape id.
110+ Example return value: [{'shape_id1': { 'shape_id': 'shape_id1', 'trip_ids': ['trip1', 'trip2']}},
111+ {'shape_id': 'shape_id2', 'trip_ids': ['trip3']}}]
112112 """
113113 if self .route_to_shape is None :
114114 self .route_to_shape = {}
115115 for row in self .get_file (TRIPS_FILE ):
116- route_id = row ["route_id" ]
117- shape_id = row ["shape_id" ]
118- if shape_id :
119- self .route_to_shape .setdefault (route_id , shape_id )
120- return self .route_to_shape .get (route_id , "" )
116+ route_id = get_safe_value (row , "route_id" )
117+ shape_id = get_safe_value (row , "shape_id" )
118+ trip_id = get_safe_value (row , "trip_id" )
119+ if route_id and trip_id :
120+ if shape_id :
121+ route_shapes = self .route_to_shape .setdefault (route_id , {})
122+ shape_trips = route_shapes .setdefault (
123+ shape_id , {"shape_id" : shape_id , "trip_ids" : []}
124+ )
125+ shape_trips ["trip_ids" ].append (trip_id )
126+ else :
127+ # Registering the trip without a shape for this route for later retrieval.
128+ trip_no_shapes = (
129+ self .trips_no_shapes_per_route .get (route_id )
130+ if route_id in self .trips_no_shapes_per_route
131+ else None
132+ )
133+ if trip_no_shapes is None :
134+ trip_no_shapes = []
135+ self .trips_no_shapes_per_route [route_id ] = trip_no_shapes
136+ trip_no_shapes .append (trip_id )
137+ return self .route_to_shape .get (route_id , {})
138+
139+ def get_trips_without_shape_from_route (self , route_id ) -> List [str ]:
140+ return self .trips_no_shapes_per_route .get (route_id , [])
121141
122142 def get_stops_from_trip (self , trip_id ):
123- # Lazy instantiation of the dictionary, because we may not need it al all if there is a shape.
124143 if self .trip_to_stops is None :
125144 self .trip_to_stops = {}
126145 for row in self .get_file (STOP_TIMES_FILE ):
127- self .trip_to_stops .setdefault (row ["trip_id" ], []).append (row ["stop_id" ])
146+ trip_id = get_safe_value (row , "trip_id" )
147+ stop_id = get_safe_value (row , "stop_id" )
148+ if trip_id and stop_id :
149+ trip_to_stops = (
150+ self .trip_to_stops .get (trip_id )
151+ if trip_id in self .trip_to_stops
152+ else None
153+ )
154+ if trip_to_stops is None :
155+ trip_to_stops = []
156+ self .trip_to_stops [trip_id ] = trip_to_stops
157+ trip_to_stops .append (stop_id )
128158 return self .trip_to_stops .get (trip_id , [])
129159
130160 def get_coordinates_for_stop (self , stop_id ) -> tuple [float , float ] | None :
0 commit comments