|
| 1 | +from typing import Iterator |
| 2 | + |
1 | 3 | from geoalchemy2 import WKTElement |
| 4 | +from sqlalchemy import or_ |
2 | 5 | from sqlalchemy import select |
3 | 6 | from sqlalchemy.orm import joinedload, Session |
4 | 7 | from sqlalchemy.orm.query import Query |
|
14 | 17 | Entitytype, |
15 | 18 | Redirectingid, |
16 | 19 | ) |
17 | | - |
18 | 20 | from shared.feed_filters.gtfs_feed_filter import GtfsFeedFilter, LocationFilter |
19 | 21 | from shared.feed_filters.gtfs_rt_feed_filter import GtfsRtFeedFilter, EntityTypeFilter |
20 | | - |
21 | 22 | from .entity_type_enum import EntityType |
22 | | - |
23 | | -from sqlalchemy import or_ |
24 | | - |
25 | 23 | from .error_handling import raise_internal_http_validation_error, invalid_bounding_coordinates, invalid_bounding_method |
| 24 | +from .iter_utils import batched |
26 | 25 |
|
27 | 26 |
|
28 | 27 | def get_gtfs_feeds_query( |
@@ -75,28 +74,39 @@ def get_gtfs_feeds_query( |
75 | 74 | return feed_query |
76 | 75 |
|
77 | 76 |
|
def get_all_gtfs_feeds(
    db_session: Session,
    include_wip: bool = False,
    batch_size: int = 250,
) -> Iterator[Gtfsfeed]:
    """
    Fetch all GTFS feeds.

    The feeds are retrieved in two steps: the stable IDs are streamed first, then the complete feeds
    (with their eagerly-joined relationships) are loaded one batch of IDs at a time.

    @param db_session: The database session.
    @param include_wip: Whether to include or exclude WIP feeds.
    @param batch_size: The number of feeds to fetch from the database at a time.
        A lower value means less memory but more queries.

    @return: The GTFS feeds in an iterator.
    """
    # Only the stable_id column is selected here; the full entities are loaded per batch below,
    # so there is no point in materializing them twice.
    id_query = db_session.query(Gtfsfeed.stable_id).order_by(Gtfsfeed.stable_id).yield_per(batch_size)
    if not include_wip:
        # IS DISTINCT FROM keeps rows whose operational_status is NULL, unlike a plain != comparison.
        id_query = id_query.filter(Gtfsfeed.operational_status.is_distinct_from("wip"))

    # NOTE(review): batched() comes from .iter_utils; presumably it yields chunks of at most
    # batch_size items - confirm against its implementation.
    for batch in batched(id_query, batch_size):
        # Materialize the IDs: in_() is documented to take a list of literal values.
        stable_ids = [row.stable_id for row in batch]
        yield from (
            db_session.query(Gtfsfeed)
            .filter(Gtfsfeed.stable_id.in_(stable_ids))
            .options(
                joinedload(Gtfsfeed.gtfsdatasets)
                .joinedload(Gtfsdataset.validation_reports)
                .joinedload(Validationreport.features),
                *get_joinedload_options(),
            )
            .order_by(Gtfsfeed.stable_id)
        )
100 | 110 |
|
101 | 111 | def get_gtfs_rt_feeds_query( |
102 | 112 | limit: int | None, |
@@ -161,29 +171,38 @@ def get_gtfs_rt_feeds_query( |
161 | 171 | return feed_query |
162 | 172 |
|
163 | 173 |
|
def get_all_gtfs_rt_feeds(
    db_session: Session,
    include_wip: bool = False,
    batch_size: int = 250,
) -> Iterator[Gtfsrealtimefeed]:
    """
    Fetch all GTFS realtime feeds.

    The feeds are retrieved in two steps: the stable IDs are streamed first, then the complete feeds
    (with their eagerly-joined relationships) are loaded one batch of IDs at a time.

    @param db_session: The database session.
    @param include_wip: Whether to include or exclude WIP feeds.
    @param batch_size: The number of feeds to fetch from the database at a time.
        A lower value means less memory but more queries.

    @return: The GTFS realtime feeds in an iterator.
    """
    # Only the stable_id column is selected here; the full entities are loaded per batch below.
    id_query = (
        db_session.query(Gtfsrealtimefeed.stable_id).order_by(Gtfsrealtimefeed.stable_id).yield_per(batch_size)
    )
    if not include_wip:
        # IS DISTINCT FROM keeps rows whose operational_status is NULL, unlike a plain != comparison.
        id_query = id_query.filter(Gtfsrealtimefeed.operational_status.is_distinct_from("wip"))

    # NOTE(review): batched() comes from .iter_utils; presumably it yields chunks of at most
    # batch_size items - confirm against its implementation.
    for batch in batched(id_query, batch_size):
        # Materialize the IDs: in_() is documented to take a list of literal values.
        stable_ids = [row.stable_id for row in batch]
        yield from (
            db_session.query(Gtfsrealtimefeed)
            .filter(Gtfsrealtimefeed.stable_id.in_(stable_ids))
            .options(
                joinedload(Gtfsrealtimefeed.entitytypes),
                joinedload(Gtfsrealtimefeed.gtfs_feeds),
                *get_joinedload_options(),
            )
            # Order by the queried entity's own attribute (previously Gtfsfeed.stable_id) so the
            # per-batch ordering matches the ordering of the ID query above.
            .order_by(Gtfsrealtimefeed.stable_id)
        )
187 | 206 |
|
188 | 207 | def apply_bounding_filtering( |
189 | 208 | query: Query, |
|
0 commit comments