Skip to content

Commit 33ffb82

Browse files
authored
feat: use the location with the most stops when exporting the csv (#1096)
1 parent 9864c9f commit 33ffb82

File tree

8 files changed

+363
-89
lines changed

8 files changed

+363
-89
lines changed

api/src/shared/common/db_utils.py

Lines changed: 137 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from typing import Iterator
1+
from typing import Iterator, List, Dict
22

33
from geoalchemy2 import WKTElement
44
from sqlalchemy import or_
55
from sqlalchemy import select
6-
from sqlalchemy.orm import joinedload, Session, contains_eager
6+
from sqlalchemy.orm import joinedload, Session, contains_eager, load_only
77
from sqlalchemy.orm.query import Query
88
from sqlalchemy.orm.strategy_options import _AbstractLoad
9-
9+
from sqlalchemy import func
1010
from shared.database_gen.sqlacodegen_models import (
1111
Feed,
1212
Gtfsdataset,
@@ -16,6 +16,8 @@
1616
Gtfsrealtimefeed,
1717
Entitytype,
1818
Redirectingid,
19+
Feedosmlocationgroup,
20+
Geopolygon,
1921
)
2022
from shared.feed_filters.gtfs_feed_filter import GtfsFeedFilter, LocationFilter
2123
from shared.feed_filters.gtfs_rt_feed_filter import GtfsRtFeedFilter, EntityTypeFilter
@@ -86,40 +88,111 @@ def get_gtfs_feeds_query(
8688
return feed_query
8789

8890

91+
def apply_most_common_location_filter(query: Query, db_session: Session) -> Query:
92+
"""
93+
Apply the most common location filter to the query.
94+
:param query: The query to apply the filter to.
95+
:param db_session: The database session.
96+
97+
:return: The query with the most common location filter applied.
98+
"""
99+
most_common_location_subquery = (
100+
db_session.query(
101+
Feedosmlocationgroup.feed_id, func.max(Feedosmlocationgroup.stops_count).label("max_stops_count")
102+
)
103+
.group_by(Feedosmlocationgroup.feed_id)
104+
.subquery()
105+
)
106+
return query.outerjoin(Feed.feedosmlocationgroups).filter(
107+
Feedosmlocationgroup.stops_count == most_common_location_subquery.c.max_stops_count,
108+
Feedosmlocationgroup.feed_id == most_common_location_subquery.c.feed_id,
109+
)
110+
111+
112+
def get_geopolygons(db_session: Session, feeds: List[Feed], include_geometry: bool = False) -> Dict[str, Geopolygon]:
113+
"""
114+
Get the geolocations for the given feeds.
115+
:param db_session: The database session.
116+
:param feeds: The feeds to get the geolocations for.
117+
:param include_geometry: Whether to include the geometry in the result.
118+
119+
:return: The geolocations for the given location groups.
120+
"""
121+
location_groups = [feed.feedosmlocationgroups for feed in feeds]
122+
location_groups = [item for sublist in location_groups for item in sublist]
123+
124+
if not location_groups:
125+
return dict()
126+
geo_polygons_osm_ids = []
127+
for location_group in location_groups:
128+
split_ids = location_group.group_id.split(".")
129+
if not split_ids:
130+
continue
131+
geo_polygons_osm_ids += [int(split_id) for split_id in split_ids if split_id.isdigit()]
132+
if not geo_polygons_osm_ids:
133+
return dict()
134+
geo_polygons_osm_ids = list(set(geo_polygons_osm_ids))
135+
query = db_session.query(Geopolygon).filter(Geopolygon.osm_id.in_(geo_polygons_osm_ids))
136+
if not include_geometry:
137+
query = query.options(
138+
load_only(Geopolygon.osm_id, Geopolygon.name, Geopolygon.iso_3166_2_code, Geopolygon.iso_3166_1_code)
139+
)
140+
query = query.order_by(Geopolygon.admin_level)
141+
geopolygons = query.all()
142+
geopolygon_map = {str(geopolygon.osm_id): geopolygon for geopolygon in geopolygons}
143+
return geopolygon_map
144+
145+
89146
def get_all_gtfs_feeds(
90147
db_session: Session,
91148
published_only: bool = True,
92149
batch_size: int = 250,
150+
w_extracted_locations_only: bool = False,
93151
) -> Iterator[Gtfsfeed]:
94152
"""
95153
Fetch all GTFS feeds.
96154
97-
@param db_session: The database session.
98-
@param published_only: Include only the published feeds.
99-
@param batch_size: The number of feeds to fetch from the database at a time.
155+
:param db_session: The database session.
156+
:param published_only: Include only the published feeds.
157+
:param batch_size: The number of feeds to fetch from the database at a time.
100158
A lower value means less memory but more queries.
159+
:param w_extracted_locations_only: Whether to include only feeds with extracted locations.
101160
102-
@return: The GTFS feeds in an iterator.
161+
:return: The GTFS feeds in an iterator.
103162
"""
104-
feed_query = db_session.query(Gtfsfeed).order_by(Gtfsfeed.stable_id).yield_per(batch_size)
163+
batch_query = db_session.query(Gtfsfeed).order_by(Gtfsfeed.stable_id).yield_per(batch_size)
105164
if published_only:
106-
feed_query = feed_query.filter(Gtfsfeed.operational_status == "published")
165+
batch_query = batch_query.filter(Gtfsfeed.operational_status == "published")
107166

108-
for batch in batched(feed_query, batch_size):
167+
for batch in batched(batch_query, batch_size):
109168
stable_ids = (f.stable_id for f in batch)
110-
yield from (
111-
db_session.query(Gtfsfeed)
112-
.outerjoin(Gtfsfeed.gtfsdatasets)
113-
.filter(Gtfsfeed.stable_id.in_(stable_ids))
114-
.filter((Gtfsdataset.latest) | (Gtfsdataset.id == None)) # noqa: E711
115-
.options(
116-
contains_eager(Gtfsfeed.gtfsdatasets)
117-
.joinedload(Gtfsdataset.validation_reports)
118-
.joinedload(Validationreport.features),
119-
*get_joinedload_options(),
169+
if w_extracted_locations_only:
170+
feed_query = apply_most_common_location_filter(
171+
db_session.query(Gtfsfeed).outerjoin(Gtfsfeed.gtfsdatasets), db_session
172+
)
173+
yield from (
174+
feed_query.filter(Gtfsfeed.stable_id.in_(stable_ids))
175+
.filter((Gtfsdataset.latest) | (Gtfsdataset.id == None)) # noqa: E711
176+
.options(
177+
contains_eager(Gtfsfeed.gtfsdatasets)
178+
.joinedload(Gtfsdataset.validation_reports)
179+
.joinedload(Validationreport.features),
180+
*get_joinedload_options(include_extracted_location_entities=True),
181+
)
182+
)
183+
else:
184+
yield from (
185+
db_session.query(Gtfsfeed)
186+
.outerjoin(Gtfsfeed.gtfsdatasets)
187+
.filter(Gtfsfeed.stable_id.in_(stable_ids))
188+
.filter((Gtfsdataset.latest) | (Gtfsdataset.id == None)) # noqa: E711
189+
.options(
190+
contains_eager(Gtfsfeed.gtfsdatasets)
191+
.joinedload(Gtfsdataset.validation_reports)
192+
.joinedload(Validationreport.features),
193+
*get_joinedload_options(include_extracted_location_entities=False),
194+
)
120195
)
121-
.order_by(Gtfsfeed.stable_id)
122-
)
123196

124197

125198
def get_gtfs_rt_feeds_query(
@@ -196,33 +269,48 @@ def get_all_gtfs_rt_feeds(
196269
db_session: Session,
197270
published_only: bool = True,
198271
batch_size: int = 250,
272+
w_extracted_locations_only: bool = False,
199273
) -> Iterator[Gtfsrealtimefeed]:
200274
"""
201275
Fetch all GTFS realtime feeds.
202276
203-
@param db_session: The database session.
204-
@param published_only: Include only the published feeds.
205-
@param batch_size: The number of feeds to fetch from the database at a time.
277+
:param db_session: The database session.
278+
:param published_only: Include only the published feeds.
279+
:param batch_size: The number of feeds to fetch from the database at a time.
206280
A lower value means less memory but more queries.
281+
:param w_extracted_locations_only: Whether to include only feeds with extracted locations.
207282
208-
@return: The GTFS realtime feeds in an iterator.
283+
:return: The GTFS realtime feeds in an iterator.
209284
"""
210-
feed_query = db_session.query(Gtfsrealtimefeed.stable_id).order_by(Gtfsrealtimefeed.stable_id).yield_per(batch_size)
285+
batched_query = (
286+
db_session.query(Gtfsrealtimefeed.stable_id).order_by(Gtfsrealtimefeed.stable_id).yield_per(batch_size)
287+
)
211288
if published_only:
212-
feed_query = feed_query.filter(Gtfsrealtimefeed.operational_status == "published")
289+
batched_query = batched_query.filter(Gtfsrealtimefeed.operational_status == "published")
213290

214-
for batch in batched(feed_query, batch_size):
291+
for batch in batched(batched_query, batch_size):
215292
stable_ids = (f.stable_id for f in batch)
216-
yield from (
217-
db_session.query(Gtfsrealtimefeed)
218-
.filter(Gtfsrealtimefeed.stable_id.in_(stable_ids))
219-
.options(
220-
joinedload(Gtfsrealtimefeed.entitytypes),
221-
joinedload(Gtfsrealtimefeed.gtfs_feeds),
222-
*get_joinedload_options(),
293+
if w_extracted_locations_only:
294+
feed_query = apply_most_common_location_filter(db_session.query(Gtfsrealtimefeed), db_session)
295+
yield from (
296+
feed_query.filter(Gtfsrealtimefeed.stable_id.in_(stable_ids))
297+
.options(
298+
joinedload(Gtfsrealtimefeed.entitytypes),
299+
joinedload(Gtfsrealtimefeed.gtfs_feeds),
300+
*get_joinedload_options(include_extracted_location_entities=True),
301+
)
302+
.order_by(Gtfsfeed.stable_id)
303+
)
304+
else:
305+
yield from (
306+
db_session.query(Gtfsrealtimefeed)
307+
.filter(Gtfsrealtimefeed.stable_id.in_(stable_ids))
308+
.options(
309+
joinedload(Gtfsrealtimefeed.entitytypes),
310+
joinedload(Gtfsrealtimefeed.gtfs_feeds),
311+
*get_joinedload_options(include_extracted_location_entities=False),
312+
)
223313
)
224-
.order_by(Gtfsfeed.stable_id)
225-
)
226314

227315

228316
def apply_bounding_filtering(
@@ -282,9 +370,17 @@ def apply_bounding_filtering(
282370
raise_internal_http_validation_error(invalid_bounding_method.format(bounding_filter_method))
283371

284372

285-
def get_joinedload_options() -> [_AbstractLoad]:
286-
"""Returns common joinedload options for feeds queries."""
287-
return [
373+
def get_joinedload_options(include_extracted_location_entities: bool = False) -> [_AbstractLoad]:
374+
"""
375+
Returns common joinedload options for feeds queries.
376+
:param include_extracted_location_entities: Whether to include extracted location entities.
377+
378+
:return: A list of joinedload options.
379+
"""
380+
joinedload_options = []
381+
if include_extracted_location_entities:
382+
joinedload_options = [contains_eager(Feed.feedosmlocationgroups).joinedload(Feedosmlocationgroup.group)]
383+
return joinedload_options + [
288384
joinedload(Feed.locations),
289385
joinedload(Feed.externalids),
290386
joinedload(Feed.redirectingids).joinedload(Redirectingid.target),
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Export and Upload Feeds to CSV
2+
This cloud function reads feed data from the database, processes it to extract relevant information, exports it to a CSV file, and uploads the file to a specified Google Cloud Storage bucket.
3+
4+
## Overview
5+
The function performs the following steps:
6+
1. Retrieves GTFS and GTFS-RT feeds from the database.
7+
2. Processes each feed to extract essential details, including location, provider, URLs, and features.
8+
3. Exports the processed data to a local CSV file.
9+
4. Uploads the CSV file to a Google Cloud Storage bucket.
10+
5. Returns an HTTP response indicating the success or failure of the operation.
11+
12+
## Project Structure
13+
14+
- **`main.py`**: The main file containing the cloud function implementation and utility functions.
15+
16+
## Function Configuration
17+
The function requires the following environment variables to be set:
18+
- `FEEDS_DATABASE_URL`: URL to access the feeds database.
19+
- `DATASETS_BUCKET_NAME`: Name of the Google Cloud Storage bucket.
20+
21+
## Local Development
22+
23+
For local development, follow the same steps as for other functions in the project. Please refer to the [README.md](../README.md) file in the parent directory for detailed instructions.

functions-python/export_csv/requirements.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,7 @@ shapely
2020
# Google
2121
google-cloud-storage
2222
functions-framework==3.*
23-
google-cloud-logging
23+
google-cloud-logging
24+
25+
# Other dependencies
26+
natsort

0 commit comments

Comments
 (0)