
Commit d705342

Added testing to export_csv

1 parent 16b8c4c, commit d705342

File tree

9 files changed: +248 additions, -99 deletions

api/src/shared/common/db_utils.py

Lines changed: 47 additions & 0 deletions

@@ -74,6 +74,29 @@ def get_gtfs_feeds_query(
     return feed_query
 
 
+def get_all_gtfs_feeds_query(
+    include_wip: bool = False,
+    db_session: Session = None,
+) -> Query[any]:
+    """Get the DB query to use to retrieve all the GTFS feeds, filtering out the WIP feeds if needed"""
+
+    feed_query = db_session.query(Gtfsfeed)
+
+    if not include_wip:
+        feed_query = feed_query.filter(
+            or_(Gtfsfeed.operational_status == None, Gtfsfeed.operational_status != "wip")  # noqa: E711
+        )
+
+    feed_query = feed_query.options(
+        joinedload(Gtfsfeed.gtfsdatasets)
+        .joinedload(Gtfsdataset.validation_reports)
+        .joinedload(Validationreport.notices),
+        *get_joinedload_options(),
+    ).order_by(Gtfsfeed.stable_id)
+
+    return feed_query
+
+
 def get_gtfs_rt_feeds_query(
     limit: int | None,
     offset: int | None,
@@ -137,6 +160,30 @@ def get_gtfs_rt_feeds_query(
     return feed_query
 
 
+def get_all_gtfs_rt_feeds_query(
+    include_wip: bool = False,
+    db_session: Session = None,
+) -> Query:
+    """Get the DB query to use to retrieve all the GTFS rt feeds, filtering out the WIP feeds if needed"""
+    feed_query = db_session.query(Gtfsrealtimefeed)
+
+    if not include_wip:
+        feed_query = feed_query.filter(
+            or_(
+                Gtfsrealtimefeed.operational_status == None,  # noqa: E711
+                Gtfsrealtimefeed.operational_status != "wip",
+            )
+        )
+
+    feed_query = feed_query.options(
+        joinedload(Gtfsrealtimefeed.entitytypes),
+        joinedload(Gtfsrealtimefeed.gtfs_feeds),
+        *get_joinedload_options(),
+    ).order_by(Gtfsfeed.stable_id)
+
+    return feed_query
+
+
 def apply_bounding_filtering(
     query: Query,
     bounding_latitudes: str,
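
The two new helpers mirror the existing get_gtfs_feeds_query / get_gtfs_rt_feeds_query but drop the pagination and filter parameters. A minimal usage sketch, not part of the commit, reusing the Database session helper that export_csv's main.py (shown further below) already relies on:

    # Hypothetical caller: fetch every non-WIP feed with eager-loaded
    # relations, ordered by stable_id, exactly as the new helpers build it.
    import os

    from shared.common.db_utils import get_all_gtfs_feeds_query, get_all_gtfs_rt_feeds_query
    from shared.helpers.database import Database

    db = Database(database_url=os.getenv("FEEDS_DATABASE_URL"))
    with db.start_db_session() as session:
        gtfs_feeds = get_all_gtfs_feeds_query(include_wip=False, db_session=session).all()
        gtfs_rt_feeds = get_all_gtfs_rt_feeds_query(include_wip=False, db_session=session).all()
        print(f"Retrieved {len(gtfs_feeds)} GTFS and {len(gtfs_rt_feeds)} GTFS-RT feeds.")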

functions-python/export_csv/function_config.json

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
   "description": "Export the DB feed data as a csv file",
   "entry_point": "export_csv",
   "timeout": 20,
-  "memory": "256Mi",
+  "memory": "1Gi",
   "trigger_http": true,
   "include_folders": ["helpers", "dataset_service"],
   "include_api_folders": ["utils", "database", "feed_filters", "common", "database_gen"],

functions-python/export_csv/requirements.txt

Lines changed: 3 additions & 1 deletion

@@ -9,9 +9,10 @@ requests~=2.32.3
 attrs~=23.1.0
 pluggy~=1.3.0
 certifi~=2024.7.4
-pandas
+pandas~=2.2.3
 python-dotenv==1.0.0
 fastapi-filter[sqlalchemy]==1.0.0
+packaging~=24.2
 
 # SQL Alchemy and Geo Alchemy
 SQLAlchemy==2.0.23
@@ -22,3 +23,4 @@ shapely
 google-cloud-pubsub
 google-cloud-datastore
 cloudevents~=1.10.1
+
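
The explicit packaging pin lines up with the Version comparison in main.py below. Assuming that Version is packaging.version.Version, this sketch shows why raw validator version strings need the extract_numeric_version cleanup introduced in this commit (the "-SNAPSHOT" suffix is a hypothetical example, not taken from the database):

    from packaging.version import InvalidVersion, Version

    assert Version("6.0.1") > Version("5.0.1")  # plain X.Y.Z strings compare fine
    try:
        Version("6.0.1-SNAPSHOT")  # hypothetical suffixed validator version
    except InvalidVersion:
        # Non PEP 440 suffixes do not parse, so main.py's extract_numeric_version()
        # strips the string down to its leading X.Y.Z before building a Version.
        pass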

functions-python/export_csv/src/main.py

Lines changed: 34 additions & 32 deletions

@@ -14,8 +14,10 @@
 # limitations under the License.
 #
 import argparse
-import pandas as pd
 import os
+import re
+
+import pandas as pd
 
 from dotenv import load_dotenv
 import functions_framework
@@ -27,7 +29,7 @@
 
 from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gtfsrealtimefeed
 from collections import OrderedDict
-from shared.common.db_utils import get_gtfs_feeds_query, get_gtfs_rt_feeds_query
+from shared.common.db_utils import get_all_gtfs_rt_feeds_query, get_all_gtfs_feeds_query
 
 from shared.helpers.database import Database
 
@@ -57,10 +59,13 @@ def finalize_row(self):
         self.rows.append(self.data.copy())
         self.data = OrderedDict()
 
-    def write_csv(self, csv_file_path):
+    def write_csv_to_file(self, csv_file_path):
         df = pd.DataFrame(self.rows, columns=self.headers)
         df.to_csv(csv_file_path, index=False)
 
+    def get_dataframe(self) -> pd:
+        return pd.DataFrame(self.rows, columns=self.headers)
+
 
 @functions_framework.http
 def export_csv(request=None):
@@ -71,21 +76,20 @@ def export_csv(request=None):
     :param request: HTTP request object
     :return: HTTP response object
     """
+    data_collector = collect_data()
+    data_collector.write_csv_to_file(csv_file_path)
+    return f"Export of database feeds to CSV file {csv_file_path}."
+
+
+def collect_data() -> DataCollector:
+    """
+    Collect data from the DB and write the output to a DataCollector.
+    :return: A filled DataCollector
+    """
     db = Database(database_url=os.getenv("FEEDS_DATABASE_URL"))
     try:
         with db.start_db_session() as session:
-            gtfs_feeds_query = get_gtfs_feeds_query(
-                limit=None,
-                offset=0,
-                provider=None,
-                producer_url=None,
-                country_code=None,
-                subdivision_name=None,
-                municipality=None,
-                dataset_latitudes=None,
-                dataset_longitudes=None,
-                bounding_filter_method=None,
-                is_official=None,
+            gtfs_feeds_query = get_all_gtfs_feeds_query(
                 include_wip=False,
                 db_session=session,
             )
@@ -94,16 +98,7 @@ def export_csv(request=None):
 
             print(f"Retrieved {len(gtfs_feeds)} GTFS feeds.")
 
-            gtfs_rt_feeds_query = get_gtfs_rt_feeds_query(
-                limit=None,
-                offset=0,
-                provider=None,
-                producer_url=None,
-                entity_types=None,
-                country_code=None,
-                subdivision_name=None,
-                municipality=None,
-                is_official=None,
+            gtfs_rt_feeds_query = get_all_gtfs_rt_feeds_query(
                 include_wip=False,
                 db_session=session,
             )
@@ -134,11 +129,13 @@ def export_csv(request=None):
     except Exception as error:
         print(f"Error retrieving feeds: {error}")
         raise Exception(f"Error retrieving feeds: {error}")
+    data_collector.write_csv_to_file(csv_file_path)
+    return data_collector
 
-    data_collector.write_csv(csv_file_path)
 
-    print(f"Wrote {len(gtfs_feeds)} feeds to {csv_file_path}.")
-    return f"Wrote {len(gtfs_feeds)} feeds to {csv_file_path}."
+def extract_numeric_version(version):
+    match = re.match(r"(\d+\.\d+\.\d+)", version)
+    return match.group(1) if match else version
 
 
 def get_feed_csv_data(feed: Gtfsfeed):
@@ -162,15 +159,19 @@ def get_feed_csv_data(feed: Gtfsfeed):
         # Keep the report from the more recent validator version
         latest_report = reduce(
             lambda a, b: a
-            if Version(a.validator_version) > Version(b.validator_version)
+            if Version(extract_numeric_version(a.validator_version))
+            > Version(extract_numeric_version(b.validator_version))
             else b,
             latest_dataset.validation_reports,
         )
+
         if latest_report:
             if latest_report.features:
                 features = latest_report.features
                 joined_features = (
-                    "|".join(feature.name for feature in features if feature.name)
+                    "|".join(
+                        sorted(feature.name for feature in features if feature.name)
+                    )
                     if features
                     else ""
                 )
@@ -185,7 +186,7 @@ def get_feed_csv_data(feed: Gtfsfeed):
         maximum_longitude = shape.bounds[2]
 
     data = {
-        "mdb_source_id": feed.stable_id,
+        "id": feed.stable_id,
         "data_type": feed.data_type,
         "entity_type": None,
         "location.country_code": ""
@@ -262,6 +263,7 @@ def get_gtfs_rt_feed_csv_data(feed: Gtfsrealtimefeed):
             for entity_type in feed.entitytypes
             if entity_type and entity_type.name
         ]
+        valid_entity_types = sorted(valid_entity_types)
        entity_types = "|".join(valid_entity_types)
 
     static_references = ""
@@ -274,7 +276,7 @@ def get_gtfs_rt_feed_csv_data(feed: Gtfsrealtimefeed):
         static_references = "|".join(valid_feed_references)
 
     data = {
-        "mdb_source_id": feed.stable_id,
+        "id": feed.stable_id,
         "data_type": feed.data_type,
         "entity_type": entity_types,
         "location.country_code": ""
