2626)
2727from parse_request import parse_request_parameters
2828from shared .common .gcp_utils import create_refresh_materialized_view_task
29- from shared .database .database import with_db_session
29+ from shared .database .database import with_db_session , get_db_timestamp
3030from shared .database_gen .sqlacodegen_models import (
3131 Feed ,
3232 Feedlocationgrouppoint ,
3333 Osmlocationgroup ,
3434 Gtfsdataset ,
3535 Gtfsfeed ,
36+ Gbfsfeed ,
3637)
3738from shared .dataset_service .dataset_service_commons import Status
3839
@@ -134,15 +135,18 @@ def clean_stop_cache(db_session, feed, geometries_to_delete, logger):
134135 db_session .commit ()
135136
136137
138+ @with_db_session
137139def create_geojson_aggregate (
138140 location_groups : List [GeopolygonAggregate ],
139141 total_stops : int ,
140- stable_id : str ,
141142 bounding_box : shapely .Polygon ,
142143 data_type : str ,
143144 logger ,
145+ feed : Gtfsfeed | Gbfsfeed ,
146+ gtfs_dataset : Gtfsdataset = None ,
144147 extraction_urls : List [str ] = None ,
145148 public : bool = True ,
149+ db_session : Session = None ,
146150) -> None :
147151 """Create a GeoJSON file with the aggregated locations. This file will be uploaded to GCS and used for
148152 visualization."""
@@ -197,10 +201,13 @@ def create_geojson_aggregate(
197201 else :
198202 raise ValueError ("The data type must be either 'gtfs' or 'gbfs'." )
199203 bucket = storage_client .bucket (bucket_name )
200- blob = bucket .blob (f"{ stable_id } /geolocation.geojson" )
204+ blob = bucket .blob (f"{ feed . stable_id } /geolocation.geojson" )
201205 blob .upload_from_string (json .dumps (json_data ))
202206 if public :
203207 blob .make_public ()
208+ feed .geolocation_file_created_date = get_db_timestamp (db_session )
209+ if gtfs_dataset :
210+ feed .geolocation_file_dataset = gtfs_dataset
204211 logger .info ("GeoJSON data saved to %s" , blob .name )
205212
206213
@@ -210,10 +217,9 @@ def get_storage_client():
210217 return storage .Client ()
211218
212219
213- @with_db_session
214220@track_metrics (metrics = ("time" , "memory" , "cpu" ))
215221def update_dataset_bounding_box (
216- dataset_id : str , stops_df : pd .DataFrame , db_session : Session
222+ gtfs_dataset : Gtfsdataset , stops_df : pd .DataFrame , db_session : Session
217223) -> shapely .Polygon :
218224 """
219225 Update the bounding box of the dataset using the stops DataFrame.
@@ -231,19 +237,12 @@ def update_dataset_bounding_box(
231237 f")" ,
232238 srid = 4326 ,
233239 )
234- if not dataset_id :
235- return to_shape (bounding_box )
236- gtfs_dataset = (
237- db_session .query (Gtfsdataset )
238- .filter (Gtfsdataset .stable_id == dataset_id )
239- .one_or_none ()
240- )
241240 if not gtfs_dataset :
242- raise ValueError ( f"Dataset { dataset_id } does not exist in the database." )
241+ return to_shape ( bounding_box )
243242 gtfs_feed = db_session .get (Gtfsfeed , gtfs_dataset .feed_id )
244243 if not gtfs_feed :
245244 raise ValueError (
246- f"GTFS feed for dataset { dataset_id } does not exist in the database."
245+ f"GTFS feed for dataset { gtfs_dataset . stable_id } does not exist in the database."
247246 )
248247 gtfs_feed .bounding_box = bounding_box
249248 gtfs_feed .bounding_box_dataset = gtfs_dataset
@@ -252,8 +251,22 @@ def update_dataset_bounding_box(
252251 return to_shape (bounding_box )
253252
254253
def load_dataset(dataset_id: str, db_session: Session) -> Gtfsdataset:
    """Fetch the GTFS dataset whose stable ID equals ``dataset_id``.

    Args:
        dataset_id: Stable ID of the dataset to look up.
        db_session: Database session used to run the query.

    Returns:
        The ``Gtfsdataset`` record matching the given stable ID.

    Raises:
        ValueError: If no dataset with that stable ID is found.
    """
    lookup = db_session.query(Gtfsdataset).filter(
        Gtfsdataset.stable_id == dataset_id
    )
    dataset = lookup.one_or_none()
    if dataset:
        return dataset
    # No matching row: surface it to the caller instead of returning None.
    raise ValueError(f"Dataset with ID {dataset_id} does not exist in the database.")
265+
266+
267+ @with_db_session ()
255268def reverse_geolocation_process (
256- request : flask .Request ,
269+ request : flask .Request , db_session : Session = None
257270) -> Tuple [str , int ] | Tuple [Dict , int ]:
258271 """
259272 Main function to handle reverse geolocation processing.
@@ -331,14 +344,21 @@ def reverse_geolocation_process(
331344
332345 try :
333346 # Update the bounding box of the dataset
334- bounding_box = update_dataset_bounding_box (dataset_id , stops_df )
347+ gtfs_dataset : Gtfsdataset = None
348+ if dataset_id :
349+ gtfs_dataset = load_dataset (dataset_id , db_session )
350+ feed = load_feed (stable_id , data_type , logger , db_session )
351+
352+ bounding_box = update_dataset_bounding_box (gtfs_dataset , stops_df , db_session )
335353
336354 location_groups = reverse_geolocation (
337355 strategy = strategy ,
338356 stable_id = stable_id ,
339357 stops_df = stops_df ,
358+ data_type = data_type ,
340359 logger = logger ,
341360 use_cache = use_cache ,
361+ db_session = db_session ,
342362 )
343363
344364 if not location_groups :
@@ -358,14 +378,19 @@ def reverse_geolocation_process(
358378 create_geojson_aggregate (
359379 list (location_groups .values ()),
360380 total_stops = total_stops ,
361- stable_id = stable_id ,
362381 bounding_box = bounding_box ,
363382 data_type = data_type ,
364383 extraction_urls = extraction_urls ,
365384 logger = logger ,
366385 public = public ,
386+ feed = feed ,
387+ gtfs_dataset = gtfs_dataset ,
388+ db_session = db_session ,
367389 )
368390
391+ # Commit the changes to the database
392+ db_session .commit ()
393+ create_refresh_materialized_view_task ()
369394 logger .info (
370395 "COMPLETED. Processed %s stops for stable ID %s with strategy. "
371396 "Retrieved %s locations." ,
@@ -408,6 +433,7 @@ def reverse_geolocation(
408433 strategy ,
409434 stable_id ,
410435 stops_df ,
436+ data_type ,
411437 logger ,
412438 use_cache ,
413439 db_session : Session = None ,
@@ -417,7 +443,7 @@ def reverse_geolocation(
417443 """
418444 logger .info ("Processing geopolygons with strategy: %s." , strategy )
419445
420- feed = load_feed (stable_id , logger , db_session )
446+ feed = load_feed (stable_id , data_type , logger , db_session )
421447
422448 # Get Geopolygons with Geometry and cached location groups
423449 cache_location_groups , unmatched_stops_df = get_geopolygons_with_geometry (
@@ -453,13 +479,13 @@ def reverse_geolocation(
453479 logger = logger ,
454480 db_session = db_session ,
455481 )
456- create_refresh_materialized_view_task ()
457482 return cache_location_groups
458483
459484
460- def load_feed (stable_id , logger , db_session ):
485+ def load_feed (stable_id , data_type , logger , db_session ) -> Gtfsfeed | Gbfsfeed :
486+ """Load feed from the database using the stable ID and data type."""
461487 feed = (
462- db_session .query (Feed )
488+ db_session .query (Gbfsfeed if data_type == "gbfs" else Gtfsfeed )
463489 .options (joinedload (Feed .feedlocationgrouppoints ))
464490 .filter (Feed .stable_id == stable_id )
465491 .one_or_none ()
@@ -508,5 +534,3 @@ def update_feed_location(
508534 gtfs_rt_feed .locations = feed_locations
509535 if feed_locations :
510536 feed .locations = feed_locations
511- # Commit the changes to the database
512- db_session .commit ()
0 commit comments