1414# limitations under the License.
1515#
1616
17- import json
1817import logging
1918import os
2019import random
2120import time
22- from dataclasses import dataclass , asdict
23- from typing import Optional , List
21+ from typing import Optional
2422
2523import functions_framework
2624import pandas as pd
2927from requests .exceptions import RequestException , HTTPError
3028from sqlalchemy .orm import Session
3129
32- from database_gen .sqlacodegen_models import Gtfsfeed
30+ from database_gen .sqlacodegen_models import Feed
3331from helpers .feed_sync .feed_sync_common import FeedSyncProcessor , FeedSyncPayload
3432from helpers .feed_sync .feed_sync_dispatcher import feed_sync_dispatcher
33+ from helpers .feed_sync .models import TransitFeedSyncPayload
3534from helpers .logger import Logger
3635from helpers .pub_sub import get_pubsub_client , get_execution_id
36+ from typing import Tuple , List
37+ from collections import defaultdict
3738
38- # Logging configuration
39- logging .basicConfig (level = logging .INFO )
4039
4140# Environment variables
4241PUBSUB_TOPIC_NAME = os .getenv ("PUBSUB_TOPIC_NAME" )
4544TRANSITLAND_API_KEY = os .getenv ("TRANSITLAND_API_KEY" )
4645TRANSITLAND_OPERATOR_URL = os .getenv ("TRANSITLAND_OPERATOR_URL" )
4746TRANSITLAND_FEED_URL = os .getenv ("TRANSITLAND_FEED_URL" )
48- spec = ["gtfs" , "gtfs-rt" ]
4947
5048# session instance to reuse connections
5149session = requests .Session ()
5250
5351
def process_feed_urls(feed: dict, urls_in_db: List[str]) -> Tuple[List[str], List[str]]:
    """Pull the feed's downloadable URLs and their entity types.

    Inspects the four known URL slots on a TransitLand feed (static GTFS
    plus the three GTFS-RT kinds), skipping empty slots and URLs already
    present in the database. When one URL serves several entity types, the
    type codes are joined into a single comma-separated string.

    :param feed: TransitLand feed dictionary; expects an ``urls`` mapping.
    :param urls_in_db: URLs already stored (presumably producer URLs from
        the Feed table — confirm against the caller), used to skip known feeds.
    :return: a pair of parallel lists — the new URLs, and for each URL its
        comma-joined entity-type string ("" for the static URL).
    """
    slot_types = (
        ("static_current", ""),
        ("realtime_alerts", "sa"),
        ("realtime_trip_updates", "tu"),
        ("realtime_vehicle_positions", "vp"),
    )

    url_map = feed.get("urls", {})
    collected = defaultdict(list)
    for slot, entity_type in slot_types:
        url = url_map.get(slot)
        # Skip empty slots and URLs we already track.
        if not url or url in urls_in_db:
            continue
        if entity_type:
            logging.info(f"Found URL for entity type: {entity_type}")
        collected[url].append(entity_type)

    valid_urls: List[str] = []
    entity_types: List[str] = []
    for url, types in collected.items():
        logging.info(f"URL = {url}, Entity types = {types}")
        valid_urls.append(url)
        entity_types.append(",".join(types))

    return valid_urls, entity_types
9882
83+
84+ class TransitFeedSyncProcessor (FeedSyncProcessor ):
9985 def process_sync (
100- self , db_session : Optional [ Session ] = None , execution_id : Optional [str ] = None
86+ self , db_session : Session , execution_id : Optional [str ] = None
10187 ) -> List [FeedSyncPayload ]:
10288 """
10389 Process data synchronously to fetch, extract, combine, filter and prepare payloads for publishing
10490 to a queue based on conditions related to the data retrieved from TransitLand API.
10591 """
106- feeds_data = self .get_data (
107- TRANSITLAND_FEED_URL , TRANSITLAND_API_KEY , spec , session
92+ feeds_data_gtfs_rt = self .get_data (
93+ TRANSITLAND_FEED_URL , TRANSITLAND_API_KEY , "gtfs_rt" , session
94+ )
95+ logging .info (
96+ "Fetched %s GTFS-RT feeds from TransitLand API" ,
97+ len (feeds_data_gtfs_rt ["feeds" ]),
98+ )
99+
100+ feeds_data_gtfs = self .get_data (
101+ TRANSITLAND_FEED_URL , TRANSITLAND_API_KEY , "gtfs" , session
102+ )
103+ logging .info (
104+ "Fetched %s GTFS feeds from TransitLand API" , len (feeds_data_gtfs ["feeds" ])
108105 )
109- logging . info ( "Fetched %s feeds from TransitLand API" , len ( feeds_data ["feeds" ]))
106+ feeds_data = feeds_data_gtfs [ " feeds" ] + feeds_data_gtfs_rt ["feeds" ]
110107
111108 operators_data = self .get_data (
112109 TRANSITLAND_OPERATOR_URL , TRANSITLAND_API_KEY , session = session
@@ -115,8 +112,10 @@ def process_sync(
115112 "Fetched %s operators from TransitLand API" ,
116113 len (operators_data ["operators" ]),
117114 )
118-
119- feeds = self .extract_feeds_data (feeds_data )
115+ all_urls = set (
116+ [element [0 ] for element in db_session .query (Feed .producer_url ).all ()]
117+ )
118+ feeds = self .extract_feeds_data (feeds_data , all_urls )
120119 operators = self .extract_operators_data (operators_data )
121120
122121 # Converts operators and feeds to pandas DataFrames
@@ -135,16 +134,18 @@ def process_sync(
135134 # Filtered out rows where 'feed_url' is missing
136135 combined_df = combined_df [combined_df ["feed_url" ].notna ()]
137136
138- # Group by 'feed_id ' and concatenate 'operator_name' while keeping first values of other columns
137+ # Group by 'stable_id ' and concatenate 'operator_name' while keeping first values of other columns
139138 df_grouped = (
140- combined_df .groupby ("feed_id " )
139+ combined_df .groupby ("stable_id " )
141140 .agg (
142141 {
143142 "operator_name" : lambda x : ", " .join (x ),
144143 "feeds_onestop_id" : "first" ,
144+ "feed_id" : "first" ,
145145 "feed_url" : "first" ,
146146 "operator_feed_id" : "first" ,
147147 "spec" : "first" ,
148+ "entity_types" : "first" ,
148149 "country" : "first" ,
149150 "state_province" : "first" ,
150151 "city_name" : "first" ,
@@ -173,11 +174,6 @@ def process_sync(
173174 filtered_df = filtered_df .drop_duplicates (
174175 subset = ["feed_url" ]
175176 ) # Drop duplicates
176- filtered_df = filtered_df [filtered_df ["feed_url" ].apply (self .check_url_status )]
177- logging .info (
178- "Filtered out %s feeds with invalid URLs" ,
179- len (df_grouped ) - len (filtered_df ),
180- )
181177
182178 # Convert filtered DataFrame to dictionary format
183179 combined_data = filtered_df .to_dict (orient = "records" )
@@ -187,7 +183,7 @@ def process_sync(
187183 for data in combined_data :
188184 external_id = data ["feeds_onestop_id" ]
189185 feed_url = data ["feed_url" ]
190- source = "TLD "
186+ source = "tld "
191187
192188 if not self .check_external_id (db_session , external_id , source ):
193189 payload_type = "new"
@@ -201,6 +197,8 @@ def process_sync(
201197 # prepare payload
202198 payload = TransitFeedSyncPayload (
203199 external_id = external_id ,
200+ stable_id = data ["stable_id" ],
201+ entity_types = data ["entity_types" ],
204202 feed_id = data ["feed_id" ],
205203 execution_id = execution_id ,
206204 feed_url = feed_url ,
@@ -212,7 +210,7 @@ def process_sync(
212210 country = data ["country" ],
213211 state_province = data ["state_province" ],
214212 city_name = data ["city_name" ],
215- source = "TLD " ,
213+ source = "tld " ,
216214 payload_type = payload_type ,
217215 )
218216 payloads .append (FeedSyncPayload (external_id = external_id , payload = payload ))
@@ -277,25 +275,39 @@ def get_data(
277275 logging .info ("Finished fetching data." )
278276 return all_data
279277
280- def extract_feeds_data (self , feeds_data : dict ) -> List [dict ]:
278+ def extract_feeds_data (self , feeds_data : dict , urls_in_db : List [ str ] ) -> List [dict ]:
281279 """
282280 This function extracts relevant data from the Transitland feeds endpoint containing feeds information.
283281 Returns a list of dictionaries representing each feed.
284282 """
285283 feeds = []
286- for feed in feeds_data ["feeds" ]:
287- feed_url = feed ["urls" ].get ("static_current" )
288- feeds .append (
289- {
290- "feed_id" : feed ["id" ],
291- "feed_url" : feed_url ,
292- "spec" : feed ["spec" ].lower (),
293- "feeds_onestop_id" : feed ["onestop_id" ],
294- "auth_info_url" : feed ["authorization" ].get ("info_url" ),
295- "auth_param_name" : feed ["authorization" ].get ("param_name" ),
296- "type" : feed ["authorization" ].get ("type" ),
297- }
298- )
284+ for feed in feeds_data :
285+ feed_urls , entity_types = process_feed_urls (feed , urls_in_db )
286+ logging .info ("Feed %s has %s valid URL(s)" , feed ["id" ], len (feed_urls ))
287+ logging .info ("Feed %s entity types: %s" , feed ["id" ], entity_types )
288+ if len (feed_urls ) == 0 :
289+ logging .warning ("Feed URL not found for feed %s" , feed ["id" ])
290+ continue
291+
292+ for feed_url , entity_types in zip (feed_urls , entity_types ):
293+ if entity_types is not None and len (entity_types ) > 0 :
294+ stable_id = f"{ feed ['id' ]} -{ entity_types .replace (',' , '_' )} "
295+ else :
296+ stable_id = feed ["id" ]
297+ logging .info ("Stable ID: %s" , stable_id )
298+ feeds .append (
299+ {
300+ "feed_id" : feed ["id" ],
301+ "stable_id" : stable_id ,
302+ "feed_url" : feed_url ,
303+ "entity_types" : entity_types if len (entity_types ) > 0 else None ,
304+ "spec" : feed ["spec" ].lower (),
305+ "feeds_onestop_id" : feed ["onestop_id" ],
306+ "auth_info_url" : feed ["authorization" ].get ("info_url" ),
307+ "auth_param_name" : feed ["authorization" ].get ("param_name" ),
308+ "type" : feed ["authorization" ].get ("type" ),
309+ }
310+ )
299311 return feeds
300312
301313 def extract_operators_data (self , operators_data : dict ) -> List [dict ]:
@@ -309,16 +321,15 @@ def extract_operators_data(self, operators_data: dict) -> List[dict]:
309321 places = operator ["agencies" ][0 ]["places" ]
310322 place = places [1 ] if len (places ) > 1 else places [0 ]
311323
312- operator_data = {
313- "operator_name" : operator .get ("name" ),
314- "operator_feed_id" : operator ["feeds" ][0 ]["id" ]
315- if operator .get ("feeds" )
316- else None ,
317- "country" : place .get ("adm0_name" ) if place else None ,
318- "state_province" : place .get ("adm1_name" ) if place else None ,
319- "city_name" : place .get ("city_name" ) if place else None ,
320- }
321- operators .append (operator_data )
324+ for related_feed in operator .get ("feeds" , []):
325+ operator_data = {
326+ "operator_name" : operator .get ("name" ),
327+ "operator_feed_id" : related_feed ["id" ],
328+ "country" : place .get ("adm0_name" ) if place else None ,
329+ "state_province" : place .get ("adm1_name" ) if place else None ,
330+ "city_name" : place .get ("city_name" ) if place else None ,
331+ }
332+ operators .append (operator_data )
322333 return operators
323334
324335 def check_external_id (
@@ -328,12 +339,12 @@ def check_external_id(
328339 Checks if the external_id exists in the public.externalid table for the given source.
329340 :param db_session: SQLAlchemy session
330341 :param external_id: The external_id (feeds_onestop_id) to check
331- :param source: The source to filter by (e.g., 'TLD ' for TransitLand)
342+ :param source: The source to filter by (e.g., 'tld ' for TransitLand)
332343 :return: True if the feed exists, False otherwise
333344 """
334345 results = (
335- db_session .query (Gtfsfeed )
336- .filter (Gtfsfeed .externalids .any (associated_id = external_id ))
346+ db_session .query (Feed )
347+ .filter (Feed .externalids .any (associated_id = external_id ))
337348 .all ()
338349 )
339350 return results is not None and len (results ) > 0
@@ -345,12 +356,12 @@ def get_mbd_feed_url(
345356 Retrieves the feed_url from the public.feed table in the mbd for the given external_id.
346357 :param db_session: SQLAlchemy session
347358 :param external_id: The external_id (feeds_onestop_id) from TransitLand
348- :param source: The source to filter by (e.g., 'TLD ' for TransitLand)
359+ :param source: The source to filter by (e.g., 'tld ' for TransitLand)
349360 :return: feed_url in mbd if exists, otherwise None
350361 """
351362 results = (
352- db_session .query (Gtfsfeed )
353- .filter (Gtfsfeed .externalids .any (associated_id = external_id ))
363+ db_session .query (Feed )
364+ .filter (Feed .externalids .any (associated_id = external_id ))
354365 .all ()
355366 )
356367 return results [0 ].producer_url if results else None
0 commit comments