Skip to content

Commit 67f7715

Browse files
committed
updated populate script
1 parent a4bad58 commit 67f7715

File tree

2 files changed

+73
-18
lines changed

2 files changed

+73
-18
lines changed

api/src/scripts/populate_db_gtfs.py

Lines changed: 71 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pycountry
66
import pytz
7+
from sqlalchemy import func
78

89
from scripts.load_dataset_on_create import publish_all
910
from scripts.populate_db import DatabasePopulateHelper, set_up_configs
@@ -14,6 +15,7 @@
1415
Gtfsrealtimefeed,
1516
Location,
1617
Redirectingid,
18+
Gtfsfeed,
1719
)
1820
from utils.data_utils import set_up_defaults
1921

@@ -116,32 +118,95 @@ def process_entity_types(self, session: "Session", feed: Gtfsrealtimefeed, row,
116118
self.logger.warning(f"Entity types array is empty for feed {stable_id}")
117119
feed.entitytypes.clear()
118120

121+
# def process_feed_references(self, session: "Session"):
122+
# """
123+
# Process the feed references
124+
# """
125+
# self.logger.info("Processing feed references")
126+
# for index, row in self.df.iterrows():
127+
# stable_id = self.get_stable_id(row)
128+
# data_type = self.get_data_type(row)
129+
# if data_type != "gtfs_rt":
130+
# continue
131+
# gtfs_rt_feed = self.query_feed_by_stable_id(session, stable_id, "gtfs_rt")
132+
# static_reference = self.get_safe_value(row, "static_reference", "")
133+
# if static_reference:
134+
# try:
135+
# gtfs_stable_id = f"mdb-{int(float(static_reference))}"
136+
# except ValueError:
137+
# gtfs_stable_id = static_reference
138+
# gtfs_feed = self.query_feed_by_stable_id(session, gtfs_stable_id, "gtfs")
139+
# if not gtfs_feed:
140+
# self.logger.warning(f"Could not find static reference feed {gtfs_stable_id} for feed {stable_id}")
141+
# continue
142+
# already_referenced_ids = {ref.id for ref in gtfs_feed.gtfs_rt_feeds}
143+
# if gtfs_feed and gtfs_rt_feed.id not in already_referenced_ids:
144+
# gtfs_feed.gtfs_rt_feeds.append(gtfs_rt_feed)
145+
# # Flush to avoid FK violation
146+
# session.flush()
147+
119148
def process_feed_references(self, session: "Session"):
    """
    Link every GTFS-RT feed listed in ``self.df`` to a static GTFS feed.

    Rows whose data type is not ``gtfs_rt`` are ignored. For each remaining
    row the static feed is resolved in two steps:

    1. Use the ``static_reference`` column when it holds a value. Values that
       parse as a finite number (``"12"``, ``"12.0"``, ``12.0``) become the
       stable id ``mdb-12``; anything else is used verbatim as the stable id.
    2. If there is no reference, or no static feed has that stable id, fall
       back to a static feed whose provider equals the row's ``provider``
       (trimmed, case-insensitive).

    When a static feed is resolved and the GTFS-RT feed is not yet in its
    ``gtfs_rt_feeds`` collection, the GTFS-RT feed is appended and the
    session is flushed. Missing feeds are logged as warnings and the row is
    skipped.

    :param session: database session handed to ``query_feed_by_stable_id``,
        used for the provider query and flushed after each new link.
    """
    self.logger.info("Processing feed references")

    for _, row in self.df.iterrows():
        stable_id = self.get_stable_id(row)
        data_type = self.get_data_type(row)

        # Only process GTFS-RT feeds
        if data_type != "gtfs_rt":
            continue

        gtfs_rt_feed = self.query_feed_by_stable_id(session, stable_id, "gtfs_rt")
        if not gtfs_rt_feed:
            self.logger.warning(f"Could not find GTFS-RT feed {stable_id}")
            continue

        # Try static_reference column first.
        # NOTE(review): get_safe_value is defined elsewhere; coerce defensively so a
        # None or numeric cell cannot break .strip() (the provider lookup below is
        # guarded in the same spirit).
        raw_reference = self.get_safe_value(row, "static_reference", "")
        static_reference = "" if raw_reference is None else str(raw_reference).strip()
        gtfs_feed = None

        if static_reference:
            # Normalize stable_id: numeric references map to "mdb-<int>".
            try:
                gtfs_stable_id = f"mdb-{int(float(static_reference))}"
            except (ValueError, OverflowError):
                # ValueError: not a number (or NaN); OverflowError: "inf"/"-inf".
                gtfs_stable_id = static_reference

            gtfs_feed = self.query_feed_by_stable_id(session, gtfs_stable_id, "gtfs")
            if not gtfs_feed:
                self.logger.warning(f"Could not find static reference feed {gtfs_stable_id} for feed {stable_id}")

        # Fallback: match by provider if no static_reference or not found
        if not gtfs_feed:
            provider_value = (self.get_safe_value(row, "provider", "") or "").strip().lower()
            if provider_value:
                gtfs_feed = (
                    session.query(Gtfsfeed)
                    .filter(
                        Gtfsfeed.data_type == "gtfs",
                        func.lower(func.trim(Gtfsfeed.provider)) == provider_value,
                        Gtfsfeed.stable_id != stable_id,
                    )
                    # Several static feeds may share a provider; order the rows so
                    # the chosen feed is the same on every run.
                    .order_by(Gtfsfeed.stable_id)
                    .first()
                )
                if not gtfs_feed:
                    self.logger.warning(
                        f"No static GTFS feed found for provider '{provider_value}' for feed {stable_id}"
                    )

        # Link the feeds if we have a valid static GTFS feed
        if not gtfs_feed:
            continue

        already_referenced_ids = {ref.id for ref in gtfs_feed.gtfs_rt_feeds}
        if gtfs_rt_feed.id not in already_referenced_ids:
            gtfs_feed.gtfs_rt_feeds.append(gtfs_rt_feed)
            session.flush()  # Avoid FK violations
            self.logger.info(f"Linked GTFS-RT feed {stable_id} to static feed {gtfs_feed.stable_id}")
145210

146211
def process_redirects(self, session: "Session"):
147212
"""

api/src/shared/db_models/gtfs_rt_feed_impl.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,13 @@ class Config:
2020
from_attributes = True
2121

2222
@classmethod
def from_orm(cls, feed: Gtfsrealtimefeed | None) -> GtfsRTFeed | None:
    """
    Convert a ``Gtfsrealtimefeed`` ORM entity into its ``GtfsRTFeed`` API model.

    The parent ``from_orm`` performs the base conversion; if it yields a falsy
    result, ``None`` is returned. Otherwise the model is completed with:

    - ``locations``: ``LocationImpl.from_orm`` of each item in ``feed.locations``
    - ``entity_types``: the ``name`` of each item in ``feed.entitytypes``
    - ``feed_references``: the ``stable_id`` of each item in ``feed.gtfs_feeds``

    Each list is empty when the corresponding relationship is empty or unset.
    """
    rt_model: GtfsRTFeed = super().from_orm(feed)
    if not rt_model:
        return None

    # Relationship collections may be empty/unset; treat both as "no items".
    location_rows = feed.locations or []
    entity_type_rows = feed.entitytypes or []
    static_feed_rows = feed.gtfs_feeds or []

    rt_model.locations = [LocationImpl.from_orm(location) for location in location_rows]
    rt_model.entity_types = [entity_type.name for entity_type in entity_type_rows]
    rt_model.feed_references = [static_feed.stable_id for static_feed in static_feed_rows]
    return rt_model
4131

4232
@classmethod

0 commit comments

Comments
 (0)