|
1 | | -from typing import Iterator |
| 1 | +from typing import Iterator, List, Dict |
2 | 2 |
|
3 | 3 | from geoalchemy2 import WKTElement |
4 | 4 | from sqlalchemy import or_ |
5 | 5 | from sqlalchemy import select |
6 | | -from sqlalchemy.orm import joinedload, Session, contains_eager |
| 6 | +from sqlalchemy.orm import joinedload, Session, contains_eager, load_only |
7 | 7 | from sqlalchemy.orm.query import Query |
8 | 8 | from sqlalchemy.orm.strategy_options import _AbstractLoad |
9 | | - |
| 9 | +from sqlalchemy import func |
10 | 10 | from shared.database_gen.sqlacodegen_models import ( |
11 | 11 | Feed, |
12 | 12 | Gtfsdataset, |
|
16 | 16 | Gtfsrealtimefeed, |
17 | 17 | Entitytype, |
18 | 18 | Redirectingid, |
| 19 | + Feedosmlocationgroup, |
| 20 | + Geopolygon, |
19 | 21 | ) |
20 | 22 | from shared.feed_filters.gtfs_feed_filter import GtfsFeedFilter, LocationFilter |
21 | 23 | from shared.feed_filters.gtfs_rt_feed_filter import GtfsRtFeedFilter, EntityTypeFilter |
@@ -86,40 +88,111 @@ def get_gtfs_feeds_query( |
86 | 88 | return feed_query |
87 | 89 |
|
88 | 90 |
|
| 91 | +def apply_most_common_location_filter(query: Query, db_session: Session) -> Query: |
| 92 | + """ |
| 93 | + Apply the most common location filter to the query. |
| 94 | + :param query: The query to apply the filter to. |
| 95 | + :param db_session: The database session. |
| 96 | +
|
| 97 | + :return: The query with the most common location filter applied. |
| 98 | + """ |
| 99 | + most_common_location_subquery = ( |
| 100 | + db_session.query( |
| 101 | + Feedosmlocationgroup.feed_id, func.max(Feedosmlocationgroup.stops_count).label("max_stops_count") |
| 102 | + ) |
| 103 | + .group_by(Feedosmlocationgroup.feed_id) |
| 104 | + .subquery() |
| 105 | + ) |
| 106 | + return query.outerjoin(Feed.feedosmlocationgroups).filter( |
| 107 | + Feedosmlocationgroup.stops_count == most_common_location_subquery.c.max_stops_count, |
| 108 | + Feedosmlocationgroup.feed_id == most_common_location_subquery.c.feed_id, |
| 109 | + ) |
| 110 | + |
| 111 | + |
| 112 | +def get_geopolygons(db_session: Session, feeds: List[Feed], include_geometry: bool = False) -> Dict[str, Geopolygon]: |
| 113 | + """ |
| 114 | + Get the geolocations for the given feeds. |
| 115 | + :param db_session: The database session. |
| 116 | + :param feeds: The feeds to get the geolocations for. |
| 117 | + :param include_geometry: Whether to include the geometry in the result. |
| 118 | +
|
| 119 | + :return: The geolocations for the given location groups. |
| 120 | + """ |
| 121 | + location_groups = [feed.feedosmlocationgroups for feed in feeds] |
| 122 | + location_groups = [item for sublist in location_groups for item in sublist] |
| 123 | + |
| 124 | + if not location_groups: |
| 125 | + return dict() |
| 126 | + geo_polygons_osm_ids = [] |
| 127 | + for location_group in location_groups: |
| 128 | + split_ids = location_group.group_id.split(".") |
| 129 | + if not split_ids: |
| 130 | + continue |
| 131 | + geo_polygons_osm_ids += [int(split_id) for split_id in split_ids if split_id.isdigit()] |
| 132 | + if not geo_polygons_osm_ids: |
| 133 | + return dict() |
| 134 | + geo_polygons_osm_ids = list(set(geo_polygons_osm_ids)) |
| 135 | + query = db_session.query(Geopolygon).filter(Geopolygon.osm_id.in_(geo_polygons_osm_ids)) |
| 136 | + if not include_geometry: |
| 137 | + query = query.options( |
| 138 | + load_only(Geopolygon.osm_id, Geopolygon.name, Geopolygon.iso_3166_2_code, Geopolygon.iso_3166_1_code) |
| 139 | + ) |
| 140 | + query = query.order_by(Geopolygon.admin_level) |
| 141 | + geopolygons = query.all() |
| 142 | + geopolygon_map = {str(geopolygon.osm_id): geopolygon for geopolygon in geopolygons} |
| 143 | + return geopolygon_map |
| 144 | + |
| 145 | + |
89 | 146 | def get_all_gtfs_feeds( |
90 | 147 | db_session: Session, |
91 | 148 | published_only: bool = True, |
92 | 149 | batch_size: int = 250, |
| 150 | + w_extracted_locations_only: bool = False, |
93 | 151 | ) -> Iterator[Gtfsfeed]: |
94 | 152 | """ |
95 | 153 | Fetch all GTFS feeds. |
96 | 154 |
|
97 | | - @param db_session: The database session. |
98 | | - @param published_only: Include only the published feeds. |
99 | | - @param batch_size: The number of feeds to fetch from the database at a time. |
| 155 | + :param db_session: The database session. |
| 156 | + :param published_only: Include only the published feeds. |
| 157 | + :param batch_size: The number of feeds to fetch from the database at a time. |
100 | 158 | A lower value means less memory but more queries. |
| 159 | + :param w_extracted_locations_only: Whether to include only feeds with extracted locations. |
101 | 160 |
|
102 | | - @return: The GTFS feeds in an iterator. |
| 161 | + :return: The GTFS feeds in an iterator. |
103 | 162 | """ |
104 | | - feed_query = db_session.query(Gtfsfeed).order_by(Gtfsfeed.stable_id).yield_per(batch_size) |
| 163 | + batch_query = db_session.query(Gtfsfeed).order_by(Gtfsfeed.stable_id).yield_per(batch_size) |
105 | 164 | if published_only: |
106 | | - feed_query = feed_query.filter(Gtfsfeed.operational_status == "published") |
| 165 | + batch_query = batch_query.filter(Gtfsfeed.operational_status == "published") |
107 | 166 |
|
108 | | - for batch in batched(feed_query, batch_size): |
| 167 | + for batch in batched(batch_query, batch_size): |
109 | 168 | stable_ids = (f.stable_id for f in batch) |
110 | | - yield from ( |
111 | | - db_session.query(Gtfsfeed) |
112 | | - .outerjoin(Gtfsfeed.gtfsdatasets) |
113 | | - .filter(Gtfsfeed.stable_id.in_(stable_ids)) |
114 | | - .filter((Gtfsdataset.latest) | (Gtfsdataset.id == None)) # noqa: E711 |
115 | | - .options( |
116 | | - contains_eager(Gtfsfeed.gtfsdatasets) |
117 | | - .joinedload(Gtfsdataset.validation_reports) |
118 | | - .joinedload(Validationreport.features), |
119 | | - *get_joinedload_options(), |
| 169 | + if w_extracted_locations_only: |
| 170 | + feed_query = apply_most_common_location_filter( |
| 171 | + db_session.query(Gtfsfeed).outerjoin(Gtfsfeed.gtfsdatasets), db_session |
| 172 | + ) |
| 173 | + yield from ( |
| 174 | + feed_query.filter(Gtfsfeed.stable_id.in_(stable_ids)) |
| 175 | + .filter((Gtfsdataset.latest) | (Gtfsdataset.id == None)) # noqa: E711 |
| 176 | + .options( |
| 177 | + contains_eager(Gtfsfeed.gtfsdatasets) |
| 178 | + .joinedload(Gtfsdataset.validation_reports) |
| 179 | + .joinedload(Validationreport.features), |
| 180 | + *get_joinedload_options(include_extracted_location_entities=True), |
| 181 | + ) |
| 182 | + ) |
| 183 | + else: |
| 184 | + yield from ( |
| 185 | + db_session.query(Gtfsfeed) |
| 186 | + .outerjoin(Gtfsfeed.gtfsdatasets) |
| 187 | + .filter(Gtfsfeed.stable_id.in_(stable_ids)) |
| 188 | + .filter((Gtfsdataset.latest) | (Gtfsdataset.id == None)) # noqa: E711 |
| 189 | + .options( |
| 190 | + contains_eager(Gtfsfeed.gtfsdatasets) |
| 191 | + .joinedload(Gtfsdataset.validation_reports) |
| 192 | + .joinedload(Validationreport.features), |
| 193 | + *get_joinedload_options(include_extracted_location_entities=False), |
| 194 | + ) |
120 | 195 | ) |
121 | | - .order_by(Gtfsfeed.stable_id) |
122 | | - ) |
123 | 196 |
|
124 | 197 |
|
125 | 198 | def get_gtfs_rt_feeds_query( |
@@ -196,33 +269,48 @@ def get_all_gtfs_rt_feeds( |
196 | 269 | db_session: Session, |
197 | 270 | published_only: bool = True, |
198 | 271 | batch_size: int = 250, |
| 272 | + w_extracted_locations_only: bool = False, |
199 | 273 | ) -> Iterator[Gtfsrealtimefeed]: |
200 | 274 | """ |
201 | 275 | Fetch all GTFS realtime feeds. |
202 | 276 |
|
203 | | - @param db_session: The database session. |
204 | | - @param published_only: Include only the published feeds. |
205 | | - @param batch_size: The number of feeds to fetch from the database at a time. |
| 277 | + :param db_session: The database session. |
| 278 | + :param published_only: Include only the published feeds. |
| 279 | + :param batch_size: The number of feeds to fetch from the database at a time. |
206 | 280 | A lower value means less memory but more queries. |
| 281 | + :param w_extracted_locations_only: Whether to include only feeds with extracted locations. |
207 | 282 |
|
208 | | - @return: The GTFS realtime feeds in an iterator. |
| 283 | + :return: The GTFS realtime feeds in an iterator. |
209 | 284 | """ |
210 | | - feed_query = db_session.query(Gtfsrealtimefeed.stable_id).order_by(Gtfsrealtimefeed.stable_id).yield_per(batch_size) |
| 285 | + batched_query = ( |
| 286 | + db_session.query(Gtfsrealtimefeed.stable_id).order_by(Gtfsrealtimefeed.stable_id).yield_per(batch_size) |
| 287 | + ) |
211 | 288 | if published_only: |
212 | | - feed_query = feed_query.filter(Gtfsrealtimefeed.operational_status == "published") |
| 289 | + batched_query = batched_query.filter(Gtfsrealtimefeed.operational_status == "published") |
213 | 290 |
|
214 | | - for batch in batched(feed_query, batch_size): |
| 291 | + for batch in batched(batched_query, batch_size): |
215 | 292 | stable_ids = (f.stable_id for f in batch) |
216 | | - yield from ( |
217 | | - db_session.query(Gtfsrealtimefeed) |
218 | | - .filter(Gtfsrealtimefeed.stable_id.in_(stable_ids)) |
219 | | - .options( |
220 | | - joinedload(Gtfsrealtimefeed.entitytypes), |
221 | | - joinedload(Gtfsrealtimefeed.gtfs_feeds), |
222 | | - *get_joinedload_options(), |
| 293 | + if w_extracted_locations_only: |
| 294 | + feed_query = apply_most_common_location_filter(db_session.query(Gtfsrealtimefeed), db_session) |
| 295 | + yield from ( |
| 296 | + feed_query.filter(Gtfsrealtimefeed.stable_id.in_(stable_ids)) |
| 297 | + .options( |
| 298 | + joinedload(Gtfsrealtimefeed.entitytypes), |
| 299 | + joinedload(Gtfsrealtimefeed.gtfs_feeds), |
| 300 | + *get_joinedload_options(include_extracted_location_entities=True), |
| 301 | + ) |
| 302 | + .order_by(Gtfsfeed.stable_id) |
| 303 | + ) |
| 304 | + else: |
| 305 | + yield from ( |
| 306 | + db_session.query(Gtfsrealtimefeed) |
| 307 | + .filter(Gtfsrealtimefeed.stable_id.in_(stable_ids)) |
| 308 | + .options( |
| 309 | + joinedload(Gtfsrealtimefeed.entitytypes), |
| 310 | + joinedload(Gtfsrealtimefeed.gtfs_feeds), |
| 311 | + *get_joinedload_options(include_extracted_location_entities=False), |
| 312 | + ) |
223 | 313 | ) |
224 | | - .order_by(Gtfsfeed.stable_id) |
225 | | - ) |
226 | 314 |
|
227 | 315 |
|
228 | 316 | def apply_bounding_filtering( |
@@ -282,9 +370,17 @@ def apply_bounding_filtering( |
282 | 370 | raise_internal_http_validation_error(invalid_bounding_method.format(bounding_filter_method)) |
283 | 371 |
|
284 | 372 |
|
285 | | -def get_joinedload_options() -> [_AbstractLoad]: |
286 | | - """Returns common joinedload options for feeds queries.""" |
287 | | - return [ |
| 373 | +def get_joinedload_options(include_extracted_location_entities: bool = False) -> [_AbstractLoad]: |
| 374 | + """ |
| 375 | + Returns common joinedload options for feeds queries. |
| 376 | + :param include_extracted_location_entities: Whether to include extracted location entities. |
| 377 | +
|
| 378 | + :return: A list of joinedload options. |
| 379 | + """ |
| 380 | + joinedload_options = [] |
| 381 | + if include_extracted_location_entities: |
| 382 | + joinedload_options = [contains_eager(Feed.feedosmlocationgroups).joinedload(Feedosmlocationgroup.group)] |
| 383 | + return joinedload_options + [ |
288 | 384 | joinedload(Feed.locations), |
289 | 385 | joinedload(Feed.externalids), |
290 | 386 | joinedload(Feed.redirectingids).joinedload(Redirectingid.target), |
|
0 commit comments