|
4 | 4 |
|
5 | 5 | import pycountry |
6 | 6 | import pytz |
| 7 | +from sqlalchemy import func |
7 | 8 |
|
8 | 9 | from scripts.load_dataset_on_create import publish_all |
9 | 10 | from scripts.populate_db import DatabasePopulateHelper, set_up_configs |
|
14 | 15 | Gtfsrealtimefeed, |
15 | 16 | Location, |
16 | 17 | Redirectingid, |
| 18 | + Gtfsfeed, |
17 | 19 | ) |
18 | 20 | from utils.data_utils import set_up_defaults |
19 | 21 |
|
@@ -116,32 +118,95 @@ def process_entity_types(self, session: "Session", feed: Gtfsrealtimefeed, row, |
116 | 118 | self.logger.warning(f"Entity types array is empty for feed {stable_id}") |
117 | 119 | feed.entitytypes.clear() |
118 | 120 |
|
| 121 | + # def process_feed_references(self, session: "Session"): |
| 122 | + # """ |
| 123 | + # Process the feed references |
| 124 | + # """ |
| 125 | + # self.logger.info("Processing feed references") |
| 126 | + # for index, row in self.df.iterrows(): |
| 127 | + # stable_id = self.get_stable_id(row) |
| 128 | + # data_type = self.get_data_type(row) |
| 129 | + # if data_type != "gtfs_rt": |
| 130 | + # continue |
| 131 | + # gtfs_rt_feed = self.query_feed_by_stable_id(session, stable_id, "gtfs_rt") |
| 132 | + # static_reference = self.get_safe_value(row, "static_reference", "") |
| 133 | + # if static_reference: |
| 134 | + # try: |
| 135 | + # gtfs_stable_id = f"mdb-{int(float(static_reference))}" |
| 136 | + # except ValueError: |
| 137 | + # gtfs_stable_id = static_reference |
| 138 | + # gtfs_feed = self.query_feed_by_stable_id(session, gtfs_stable_id, "gtfs") |
| 139 | + # if not gtfs_feed: |
| 140 | + # self.logger.warning(f"Could not find static reference feed {gtfs_stable_id} for feed {stable_id}") |
| 141 | + # continue |
| 142 | + # already_referenced_ids = {ref.id for ref in gtfs_feed.gtfs_rt_feeds} |
| 143 | + # if gtfs_feed and gtfs_rt_feed.id not in already_referenced_ids: |
| 144 | + # gtfs_feed.gtfs_rt_feeds.append(gtfs_rt_feed) |
| 145 | + # # Flush to avoid FK violation |
| 146 | + # session.flush() |
| 147 | + |
119 | 148 | def process_feed_references(self, session: "Session"): |
120 | 149 | """ |
121 | | - Process the feed references |
| 150 | + Process the feed references for GTFS-RT feeds. |
| 151 | +
|
| 152 | + 1. Uses 'static_reference' column if present. |
| 153 | + 2. Falls back to matching static feeds by provider name. |
122 | 154 | """ |
123 | 155 | self.logger.info("Processing feed references") |
| 156 | + |
124 | 157 | for index, row in self.df.iterrows(): |
125 | 158 | stable_id = self.get_stable_id(row) |
126 | 159 | data_type = self.get_data_type(row) |
| 160 | + |
| 161 | + # Only process GTFS-RT feeds |
127 | 162 | if data_type != "gtfs_rt": |
128 | 163 | continue |
| 164 | + |
129 | 165 | gtfs_rt_feed = self.query_feed_by_stable_id(session, stable_id, "gtfs_rt") |
130 | | - static_reference = self.get_safe_value(row, "static_reference", "") |
| 166 | + if not gtfs_rt_feed: |
| 167 | + self.logger.warning(f"Could not find GTFS-RT feed {stable_id}") |
| 168 | + continue |
| 169 | + |
| 170 | + # Try static_reference column first |
| 171 | + static_reference = self.get_safe_value(row, "static_reference", "").strip() |
| 172 | + gtfs_feed = None |
| 173 | + |
131 | 174 | if static_reference: |
| 175 | + # Normalize stable_id |
132 | 176 | try: |
133 | 177 | gtfs_stable_id = f"mdb-{int(float(static_reference))}" |
134 | 178 | except ValueError: |
135 | 179 | gtfs_stable_id = static_reference |
| 180 | + |
136 | 181 | gtfs_feed = self.query_feed_by_stable_id(session, gtfs_stable_id, "gtfs") |
137 | 182 | if not gtfs_feed: |
138 | 183 | self.logger.warning(f"Could not find static reference feed {gtfs_stable_id} for feed {stable_id}") |
139 | | - continue |
| 184 | + |
| 185 | + # Fallback: match by provider if no static_reference or not found |
| 186 | + if not gtfs_feed: |
| 187 | + provider_value = (self.get_safe_value(row, "provider", "") or "").strip().lower() |
| 188 | + if provider_value: |
| 189 | + gtfs_feed = ( |
| 190 | + session.query(Gtfsfeed) |
| 191 | + .filter( |
| 192 | + Gtfsfeed.data_type == "gtfs", |
| 193 | + func.lower(func.trim(Gtfsfeed.provider)) == provider_value, |
| 194 | + Gtfsfeed.stable_id != stable_id, |
| 195 | + ) |
| 196 | + .first() |
| 197 | + ) |
| 198 | + if not gtfs_feed: |
| 199 | + self.logger.warning( |
| 200 | + f"No static GTFS feed found for provider '{provider_value}' for feed {stable_id}" |
| 201 | + ) |
| 202 | + |
| 203 | + # Link the feeds if we have a valid static GTFS feed |
| 204 | + if gtfs_feed: |
140 | 205 | already_referenced_ids = {ref.id for ref in gtfs_feed.gtfs_rt_feeds} |
141 | | - if gtfs_feed and gtfs_rt_feed.id not in already_referenced_ids: |
| 206 | + if gtfs_rt_feed.id not in already_referenced_ids: |
142 | 207 | gtfs_feed.gtfs_rt_feeds.append(gtfs_rt_feed) |
143 | | - # Flush to avoid FK violation |
144 | | - session.flush() |
| 208 | + session.flush() # Avoid FK violations |
| 209 | + self.logger.info(f"Linked GTFS-RT feed {stable_id} to static feed {gtfs_feed.stable_id}") |
145 | 210 |
|
146 | 211 | def process_redirects(self, session: "Session"): |
147 | 212 | """ |
|
0 commit comments