Commit 73ea4bc

feat: tl scraping function (#847)
1 parent dc5700c commit 73ea4bc

File tree

18 files changed: +810 -1241 lines (4 files shown below)

functions-python/batch_process_dataset/src/main.py

Lines changed: 8 additions & 1 deletion

```diff
@@ -323,14 +323,21 @@ def process_dataset(cloud_event: CloudEvent):
     dataset_file: DatasetFile = None
     error_message = None
     try:
-        # Extract data from message
+        # Extract data from message
+        logging.info(f"Cloud Event: {cloud_event}")
         data = base64.b64decode(cloud_event.data["message"]["data"]).decode()
         json_payload = json.loads(data)
        logging.info(
             f"[{json_payload['feed_stable_id']}] JSON Payload: {json.dumps(json_payload)}"
         )
         stable_id = json_payload["feed_stable_id"]
         execution_id = json_payload["execution_id"]
+    except Exception as e:
+        error_message = f"[{stable_id}] Error parsing message: [{e}]"
+        logging.error(error_message)
+        logging.error(f"Function completed with error:{error_message}")
+        return
+    try:
         trace_service = DatasetTraceService()
 
         trace = trace_service.get_by_execution_and_stable_ids(execution_id, stable_id)
```
functions-python/batch_process_dataset/tests/test_batch_process_dataset_main.py

Lines changed: 2 additions & 6 deletions

```diff
@@ -394,7 +394,7 @@ def test_process_dataset_normal_execution(
     @patch("batch_process_dataset.src.main.Logger")
     @patch("batch_process_dataset.src.main.DatasetTraceService")
     @patch("batch_process_dataset.src.main.DatasetProcessor")
-    def test_process_dataset_exception(
+    def test_process_dataset_exception_caught(
         self, mock_dataset_processor, mock_dataset_trace, _
     ):
         db_url = os.getenv("TEST_FEEDS_DATABASE_URL", default=default_db_url)
@@ -413,11 +413,7 @@ def test_process_dataset_exception(
         mock_dataset_trace.get_by_execution_and_stable_ids.return_value = 0
 
         # Call the function
-        try:
-            process_dataset(cloud_event)
-            assert False
-        except AttributeError:
-            assert True
+        process_dataset(cloud_event)
 
     @patch("batch_process_dataset.src.main.Logger")
     @patch("batch_process_dataset.src.main.DatasetTraceService")
```

functions-python/extract_location/src/reverse_geolocation/location_extractor.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -207,6 +207,9 @@ def update_location(
         .filter(Gtfsfeed.stable_id == dataset.feed.stable_id)
         .one_or_none()
     )
+    if gtfs_feed is None:
+        logging.error(f"Feed {dataset.feed.stable_id} not found as a GTFS feed.")
+        raise Exception(f"Feed {dataset.feed.stable_id} not found as a GTFS feed.")
 
     for gtfs_rt_feed in gtfs_feed.gtfs_rt_feeds:
         logging.info(f"Updating GTFS-RT feed with stable ID {gtfs_rt_feed.stable_id}")
```

functions-python/feed_sync_dispatcher_transitland/src/main.py

Lines changed: 100 additions & 89 deletions

```diff
@@ -14,13 +14,11 @@
 # limitations under the License.
 #
 
-import json
 import logging
 import os
 import random
 import time
-from dataclasses import dataclass, asdict
-from typing import Optional, List
+from typing import Optional
 
 import functions_framework
 import pandas as pd
@@ -29,14 +27,15 @@
 from requests.exceptions import RequestException, HTTPError
 from sqlalchemy.orm import Session
 
-from database_gen.sqlacodegen_models import Gtfsfeed
+from database_gen.sqlacodegen_models import Feed
 from helpers.feed_sync.feed_sync_common import FeedSyncProcessor, FeedSyncPayload
 from helpers.feed_sync.feed_sync_dispatcher import feed_sync_dispatcher
+from helpers.feed_sync.models import TransitFeedSyncPayload
 from helpers.logger import Logger
 from helpers.pub_sub import get_pubsub_client, get_execution_id
+from typing import Tuple, List
+from collections import defaultdict
 
-# Logging configuration
-logging.basicConfig(level=logging.INFO)
 
 # Environment variables
 PUBSUB_TOPIC_NAME = os.getenv("PUBSUB_TOPIC_NAME")
```
```diff
@@ -45,68 +44,66 @@
 TRANSITLAND_API_KEY = os.getenv("TRANSITLAND_API_KEY")
 TRANSITLAND_OPERATOR_URL = os.getenv("TRANSITLAND_OPERATOR_URL")
 TRANSITLAND_FEED_URL = os.getenv("TRANSITLAND_FEED_URL")
-spec = ["gtfs", "gtfs-rt"]
 
 # session instance to reuse connections
 session = requests.Session()
 
 
-@dataclass
-class TransitFeedSyncPayload:
+def process_feed_urls(feed: dict, urls_in_db: List[str]) -> Tuple[List[str], List[str]]:
     """
-    Data class for transit feed sync payloads.
+    Extracts the valid feed URLs and their corresponding entity types from the feed dictionary. If the same URL
+    corresponds to multiple entity types, the types are concatenated with a comma.
     """
+    url_keys_to_types = {
+        "static_current": "",
+        "realtime_alerts": "sa",
+        "realtime_trip_updates": "tu",
+        "realtime_vehicle_positions": "vp",
+    }
 
-    external_id: str
-    feed_id: str
-    feed_url: Optional[str] = None
-    execution_id: Optional[str] = None
-    spec: Optional[str] = None
-    auth_info_url: Optional[str] = None
-    auth_param_name: Optional[str] = None
-    type: Optional[str] = None
-    operator_name: Optional[str] = None
-    country: Optional[str] = None
-    state_province: Optional[str] = None
-    city_name: Optional[str] = None
-    source: Optional[str] = None
-    payload_type: Optional[str] = None
+    urls = feed.get("urls", {})
+    url_to_entity_types = defaultdict(list)
 
-    def to_dict(self):
-        return asdict(self)
+    for key, entity_type in url_keys_to_types.items():
+        if (url := urls.get(key)) and (url not in urls_in_db):
+            if entity_type:
+                logging.info(f"Found URL for entity type: {entity_type}")
+            url_to_entity_types[url].append(entity_type)
 
-    def to_json(self):
-        return json.dumps(self.to_dict())
+    valid_urls = []
+    entity_types = []
 
+    for url, types in url_to_entity_types.items():
+        valid_urls.append(url)
+        logging.info(f"URL = {url}, Entity types = {types}")
+        entity_types.append(",".join(types))
 
-class TransitFeedSyncProcessor(FeedSyncProcessor):
-    def check_url_status(self, url: str) -> bool:
-        """
-        Checks if a URL returns a valid response status code.
-        """
-        try:
-            logging.info(f"Checking URL: {url}")
-            if url is None or len(url) == 0:
-                logging.warning("URL is empty. Skipping check.")
-                return False
-            response = requests.head(url, timeout=25)
-            logging.info(f"URL status code: {response.status_code}")
-            return response.status_code < 400
-        except requests.RequestException as e:
-            logging.warning(f"Failed to reach {url}: {e}")
-            return False
+    return valid_urls, entity_types
 
+
+class TransitFeedSyncProcessor(FeedSyncProcessor):
     def process_sync(
-        self, db_session: Optional[Session] = None, execution_id: Optional[str] = None
+        self, db_session: Session, execution_id: Optional[str] = None
     ) -> List[FeedSyncPayload]:
         """
         Process data synchronously to fetch, extract, combine, filter and prepare payloads for publishing
         to a queue based on conditions related to the data retrieved from TransitLand API.
         """
-        feeds_data = self.get_data(
-            TRANSITLAND_FEED_URL, TRANSITLAND_API_KEY, spec, session
+        feeds_data_gtfs_rt = self.get_data(
+            TRANSITLAND_FEED_URL, TRANSITLAND_API_KEY, "gtfs_rt", session
+        )
+        logging.info(
+            "Fetched %s GTFS-RT feeds from TransitLand API",
+            len(feeds_data_gtfs_rt["feeds"]),
+        )
+
+        feeds_data_gtfs = self.get_data(
+            TRANSITLAND_FEED_URL, TRANSITLAND_API_KEY, "gtfs", session
         )
+        logging.info(
+            "Fetched %s GTFS feeds from TransitLand API", len(feeds_data_gtfs["feeds"])
+        )
-        logging.info("Fetched %s feeds from TransitLand API", len(feeds_data["feeds"]))
+        feeds_data = feeds_data_gtfs["feeds"] + feeds_data_gtfs_rt["feeds"]
 
         operators_data = self.get_data(
             TRANSITLAND_OPERATOR_URL, TRANSITLAND_API_KEY, session=session
```
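A worked example of `process_feed_urls` as reconstructed above (hypothetical feed record; assumes the `append` sits outside the inner `if entity_type:`, which matches how empty entity types are handled downstream in `extract_feeds_data`):

```python
# Hypothetical TransitLand feed record; only the fields the function reads.
feed = {
    "id": "f-example",
    "urls": {
        "static_current": "https://example.com/gtfs.zip",
        "realtime_trip_updates": "https://example.com/rt",
        "realtime_vehicle_positions": "https://example.com/rt",
    },
}

valid_urls, entity_types = process_feed_urls(feed, urls_in_db=[])

# The static URL carries no entity type; the shared realtime URL is
# emitted once, with its two entity types joined by a comma.
assert valid_urls == ["https://example.com/gtfs.zip", "https://example.com/rt"]
assert entity_types == ["", "tu,vp"]
```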
```diff
@@ -115,8 +112,10 @@ def process_sync(
             "Fetched %s operators from TransitLand API",
             len(operators_data["operators"]),
         )
-
-        feeds = self.extract_feeds_data(feeds_data)
+        all_urls = set(
+            [element[0] for element in db_session.query(Feed.producer_url).all()]
+        )
+        feeds = self.extract_feeds_data(feeds_data, all_urls)
         operators = self.extract_operators_data(operators_data)
 
         # Converts operators and feeds to pandas DataFrames
```
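The `all_urls` set fetches every known producer URL in one query so `extract_feeds_data` can skip URLs already in the database. A slightly more idiomatic spelling of the same query, as a sketch:

```python
# Each result row is a one-element tuple; unpacking it in a set
# comprehension avoids building the intermediate list.
all_urls = {url for (url,) in db_session.query(Feed.producer_url).all()}
```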
```diff
@@ -135,16 +134,18 @@ def process_sync(
         # Filtered out rows where 'feed_url' is missing
         combined_df = combined_df[combined_df["feed_url"].notna()]
 
-        # Group by 'feed_id' and concatenate 'operator_name' while keeping first values of other columns
+        # Group by 'stable_id' and concatenate 'operator_name' while keeping first values of other columns
         df_grouped = (
-            combined_df.groupby("feed_id")
+            combined_df.groupby("stable_id")
             .agg(
                 {
                     "operator_name": lambda x: ", ".join(x),
                     "feeds_onestop_id": "first",
+                    "feed_id": "first",
                     "feed_url": "first",
                     "operator_feed_id": "first",
                     "spec": "first",
+                    "entity_types": "first",
                     "country": "first",
                     "state_province": "first",
                     "city_name": "first",
```
```diff
@@ -173,11 +174,6 @@ def process_sync(
         filtered_df = filtered_df.drop_duplicates(
             subset=["feed_url"]
         )  # Drop duplicates
-        filtered_df = filtered_df[filtered_df["feed_url"].apply(self.check_url_status)]
-        logging.info(
-            "Filtered out %s feeds with invalid URLs",
-            len(df_grouped) - len(filtered_df),
-        )
 
         # Convert filtered DataFrame to dictionary format
         combined_data = filtered_df.to_dict(orient="records")
@@ -187,7 +183,7 @@ def process_sync(
         for data in combined_data:
             external_id = data["feeds_onestop_id"]
             feed_url = data["feed_url"]
-            source = "TLD"
+            source = "tld"
 
             if not self.check_external_id(db_session, external_id, source):
                 payload_type = "new"
@@ -201,6 +197,8 @@ def process_sync(
             # prepare payload
             payload = TransitFeedSyncPayload(
                 external_id=external_id,
+                stable_id=data["stable_id"],
+                entity_types=data["entity_types"],
                 feed_id=data["feed_id"],
                 execution_id=execution_id,
                 feed_url=feed_url,
@@ -212,7 +210,7 @@ def process_sync(
                 country=data["country"],
                 state_province=data["state_province"],
                 city_name=data["city_name"],
-                source="TLD",
+                source="tld",
                 payload_type=payload_type,
             )
             payloads.append(FeedSyncPayload(external_id=external_id, payload=payload))
```
```diff
@@ -277,25 +275,39 @@ def get_data(
         logging.info("Finished fetching data.")
         return all_data
 
-    def extract_feeds_data(self, feeds_data: dict) -> List[dict]:
+    def extract_feeds_data(self, feeds_data: dict, urls_in_db: List[str]) -> List[dict]:
         """
         This function extracts relevant data from the Transitland feeds endpoint containing feeds information.
         Returns a list of dictionaries representing each feed.
         """
         feeds = []
-        for feed in feeds_data["feeds"]:
-            feed_url = feed["urls"].get("static_current")
-            feeds.append(
-                {
-                    "feed_id": feed["id"],
-                    "feed_url": feed_url,
-                    "spec": feed["spec"].lower(),
-                    "feeds_onestop_id": feed["onestop_id"],
-                    "auth_info_url": feed["authorization"].get("info_url"),
-                    "auth_param_name": feed["authorization"].get("param_name"),
-                    "type": feed["authorization"].get("type"),
-                }
-            )
+        for feed in feeds_data:
+            feed_urls, entity_types = process_feed_urls(feed, urls_in_db)
+            logging.info("Feed %s has %s valid URL(s)", feed["id"], len(feed_urls))
+            logging.info("Feed %s entity types: %s", feed["id"], entity_types)
+            if len(feed_urls) == 0:
+                logging.warning("Feed URL not found for feed %s", feed["id"])
+                continue
+
+            for feed_url, entity_types in zip(feed_urls, entity_types):
+                if entity_types is not None and len(entity_types) > 0:
+                    stable_id = f"{feed['id']}-{entity_types.replace(',', '_')}"
+                else:
+                    stable_id = feed["id"]
+                logging.info("Stable ID: %s", stable_id)
+                feeds.append(
+                    {
+                        "feed_id": feed["id"],
+                        "stable_id": stable_id,
+                        "feed_url": feed_url,
+                        "entity_types": entity_types if len(entity_types) > 0 else None,
+                        "spec": feed["spec"].lower(),
+                        "feeds_onestop_id": feed["onestop_id"],
+                        "auth_info_url": feed["authorization"].get("info_url"),
+                        "auth_param_name": feed["authorization"].get("param_name"),
+                        "type": feed["authorization"].get("type"),
+                    }
+                )
         return feeds
 
     def extract_operators_data(self, operators_data: dict) -> List[dict]:
```
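Each distinct URL of a TransitLand feed now becomes its own record, and the stable ID encodes that URL's entity types, so one upstream feed can yield several entries. A one-line check of the format (hypothetical values):

```python
feed_id, entity_types = "f-example", "tu,vp"
stable_id = f"{feed_id}-{entity_types.replace(',', '_')}" if entity_types else feed_id
assert stable_id == "f-example-tu_vp"  # a static URL with no types keeps "f-example"
```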
```diff
@@ -309,16 +321,15 @@ def extract_operators_data(self, operators_data: dict) -> List[dict]:
             places = operator["agencies"][0]["places"]
             place = places[1] if len(places) > 1 else places[0]
 
-            operator_data = {
-                "operator_name": operator.get("name"),
-                "operator_feed_id": operator["feeds"][0]["id"]
-                if operator.get("feeds")
-                else None,
-                "country": place.get("adm0_name") if place else None,
-                "state_province": place.get("adm1_name") if place else None,
-                "city_name": place.get("city_name") if place else None,
-            }
-            operators.append(operator_data)
+            for related_feed in operator.get("feeds", []):
+                operator_data = {
+                    "operator_name": operator.get("name"),
+                    "operator_feed_id": related_feed["id"],
+                    "country": place.get("adm0_name") if place else None,
+                    "state_province": place.get("adm1_name") if place else None,
+                    "city_name": place.get("city_name") if place else None,
+                }
+                operators.append(operator_data)
         return operators
 
     def check_external_id(
@@ -328,12 +339,12 @@ def check_external_id(
         Checks if the external_id exists in the public.externalid table for the given source.
         :param db_session: SQLAlchemy session
        :param external_id: The external_id (feeds_onestop_id) to check
-        :param source: The source to filter by (e.g., 'TLD' for TransitLand)
+        :param source: The source to filter by (e.g., 'tld' for TransitLand)
         :return: True if the feed exists, False otherwise
         """
         results = (
-            db_session.query(Gtfsfeed)
-            .filter(Gtfsfeed.externalids.any(associated_id=external_id))
+            db_session.query(Feed)
+            .filter(Feed.externalids.any(associated_id=external_id))
             .all()
         )
         return results is not None and len(results) > 0
@@ -345,12 +356,12 @@ def get_mbd_feed_url(
         Retrieves the feed_url from the public.feed table in the mbd for the given external_id.
         :param db_session: SQLAlchemy session
         :param external_id: The external_id (feeds_onestop_id) from TransitLand
-        :param source: The source to filter by (e.g., 'TLD' for TransitLand)
+        :param source: The source to filter by (e.g., 'tld' for TransitLand)
         :return: feed_url in mbd if exists, otherwise None
         """
         results = (
-            db_session.query(Gtfsfeed)
-            .filter(Gtfsfeed.externalids.any(associated_id=external_id))
+            db_session.query(Feed)
+            .filter(Feed.externalids.any(associated_id=external_id))
             .all()
         )
         return results[0].producer_url if results else None
```
