
Commit 1035eed

fix: reduce memory usage for export_csv (#900)
1 parent c568ce9 commit 1035eed
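
In short, the commit drops the pandas-backed DataCollector, which held every feed row in memory before a single DataFrame.to_csv() call, and instead streams rows from a fetch_feeds() generator straight into csv.DictWriter against a fixed header list. Below is a minimal sketch of the pattern, not the project's code: fetch_rows and the two-column header list are illustrative stand-ins.

    import csv
    from typing import Dict, Iterator

    headers = ["id", "provider"]  # fixed column order, declared up front

    def fetch_rows() -> Iterator[Dict]:
        # Stand-in for the real fetch_feeds(): yield one dict per feed,
        # so no row is retained after it has been written.
        for i in range(3):
            yield {"id": f"feed-{i}", "provider": f"provider-{i}"}

    def export(path: str) -> None:
        with open(path, "w") as out:
            writer = csv.DictWriter(out, fieldnames=headers)
            writer.writeheader()
            for row in fetch_rows():  # only one row in memory at a time
                writer.writerow(row)

    export("./output.csv")

With this shape, peak memory scales with the size of a single row rather than with the whole table.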

File tree

4 files changed (+82, -90 lines)

functions-python/export_csv/src/main.py

Lines changed: 74 additions & 68 deletions
@@ -14,69 +14,63 @@
 # limitations under the License.
 #
 import argparse
+import csv
 import logging
 import os
 import re
-
-import pandas as pd
+from typing import Dict, Iterator
 
 from dotenv import load_dotenv
 import functions_framework
 
 from packaging.version import Version
-from functools import reduce
 from google.cloud import storage
 from geoalchemy2.shape import to_shape
 
 from shared.helpers.logger import Logger
 from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gtfsrealtimefeed
-from collections import OrderedDict
 from shared.common.db_utils import get_all_gtfs_rt_feeds_query, get_all_gtfs_feeds_query
 
 from shared.helpers.database import Database
 
 load_dotenv()
 csv_default_file_path = "./output.csv"
-csv_file_path = csv_default_file_path
-
-
-class DataCollector:
-    """
-    A class used to collect and organize data into rows and headers for CSV output.
-    One particularity of this class is that it uses an OrderedDict to store the data, so that the order of the columns
-    is preserved when writing to CSV.
-    """
-
-    def __init__(self):
-        self.data = OrderedDict()
-        self.rows = []
-        self.headers = []
-
-    def add_data(self, key, value):
-        if key not in self.headers:
-            self.headers.append(key)
-        self.data[key] = value
 
-    def finalize_row(self):
-        self.rows.append(self.data.copy())
-        self.data = OrderedDict()
-
-    def write_csv_to_file(self, csv_file_path):
-        df = pd.DataFrame(self.rows, columns=self.headers)
-        df.to_csv(csv_file_path, index=False)
-
-    def get_dataframe(self) -> pd:
-        return pd.DataFrame(self.rows, columns=self.headers)
+# This needs to be updated if we add fields to either `get_feed_csv_data` or
+# `get_gtfs_rt_feed_csv_data`, otherwise the extra field(s) will be excluded from
+# the generated CSV file.
+headers = [
+    "id",
+    "data_type",
+    "entity_type",
+    "location.country_code",
+    "location.subdivision_name",
+    "location.municipality",
+    "provider",
+    "name",
+    "note",
+    "feed_contact_email",
+    "static_reference",
+    "urls.direct_download",
+    "urls.authentication_type",
+    "urls.authentication_info",
+    "urls.api_key_parameter_name",
+    "urls.latest",
+    "urls.license",
+    "location.bounding_box.minimum_latitude",
+    "location.bounding_box.maximum_latitude",
+    "location.bounding_box.minimum_longitude",
+    "location.bounding_box.maximum_longitude",
+    "location.bounding_box.extracted_on",
+    "status",
+    "features",
+    "redirect.id",
+    "redirect.comment",
+]
 
 
 @functions_framework.http
 def export_and_upload_csv(request=None):
-    response = export_csv()
-    upload_file_to_storage(csv_file_path, "sources_v2.csv")
-    return response
-
-
-def export_csv():
     """
     HTTP Function entry point Reads the DB and outputs a csv file with feeds data.
     This function requires the following environment variables to be set:
@@ -85,16 +79,36 @@ def export_csv():
     :return: HTTP response object
     """
     Logger.init_logger()
-    logging.info("Function Started")
-    data_collector = collect_data()
-    data_collector.write_csv_to_file(csv_file_path)
-    return f"Exported {len(data_collector.rows)} feeds to CSV file {csv_file_path}."
+    logging.info("Export started")
+
+    csv_file_path = csv_default_file_path
+    export_csv(csv_file_path)
+    upload_file_to_storage(csv_file_path, "sources_v2.csv")
+
+    logging.info("Export successful")
+    return "Export successful"
 
 
-def collect_data() -> DataCollector:
+def export_csv(csv_file_path: str):
     """
-    Collect data from the DB and write the output to a DataCollector.
-    :return: A filled DataCollector
+    Write feed data to a local CSV file.
+    """
+    with open(csv_file_path, "w") as out:
+        writer = csv.DictWriter(out, fieldnames=headers)
+        writer.writeheader()
+
+        count = 0
+        for feed in fetch_feeds():
+            writer.writerow(feed)
+            count += 1
+
+    logging.info(f"Exported {count} feeds to CSV file {csv_file_path}.")
+
+
+def fetch_feeds() -> Iterator[Dict]:
+    """
+    Fetch and return feed data from the DB.
+    :return: Data to write to the output CSV file.
     """
     db = Database(database_url=os.getenv("FEEDS_DATABASE_URL"))
     logging.info(f"Using database {db.database_url}")
@@ -118,28 +132,19 @@ def collect_data() -> DataCollector:
 
         logging.info(f"Retrieved {len(gtfs_rt_feeds)} GTFS realtime feeds.")
 
-        data_collector = DataCollector()
-
         for feed in gtfs_feeds:
-            data = get_feed_csv_data(feed)
+            yield get_feed_csv_data(feed)
 
-            for key, value in data.items():
-                data_collector.add_data(key, value)
-            data_collector.finalize_row()
         logging.info(f"Processed {len(gtfs_feeds)} GTFS feeds.")
 
         for feed in gtfs_rt_feeds:
-            data = get_gtfs_rt_feed_csv_data(feed)
-            for key, value in data.items():
-                data_collector.add_data(key, value)
-            data_collector.finalize_row()
+            yield get_gtfs_rt_feed_csv_data(feed)
+
         logging.info(f"Processed {len(gtfs_rt_feeds)} GTFS realtime feeds.")
 
     except Exception as error:
         logging.error(f"Error retrieving feeds: {error}")
        raise Exception(f"Error retrieving feeds: {error}")
-    data_collector.write_csv_to_file(csv_file_path)
-    return data_collector
 
 
 def extract_numeric_version(version):
@@ -166,12 +171,9 @@ def get_feed_csv_data(feed: Gtfsfeed):
 
     if latest_dataset and latest_dataset.validation_reports:
         # Keep the report from the more recent validator version
-        latest_report = reduce(
-            lambda a, b: a
-            if Version(extract_numeric_version(a.validator_version))
-            > Version(extract_numeric_version(b.validator_version))
-            else b,
+        latest_report = max(
             latest_dataset.validation_reports,
+            key=lambda r: Version(extract_numeric_version(r.validator_version)),
         )
 
     if latest_report:
@@ -234,8 +236,8 @@ def get_feed_csv_data(feed: Gtfsfeed):
         "location.bounding_box.maximum_latitude": maximum_latitude,
         "location.bounding_box.minimum_longitude": minimum_longitude,
         "location.bounding_box.maximum_longitude": maximum_longitude,
-        "location.bounding_box.extracted_on": validated_at,
         # We use the report validated_at date as the extracted_on date
+        "location.bounding_box.extracted_on": validated_at,
         "status": feed.status,
         "features": joined_features,
     }
@@ -348,13 +350,17 @@ def upload_file_to_storage(source_file_path, target_path):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Export DB feed contents to csv.")
+    parser = argparse.ArgumentParser(
+        description="Export DB feed contents to csv.",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
     parser.add_argument(
-        "--outpath", help="Path to the output csv file. Default is ./output.csv"
+        "--outpath",
+        default=csv_default_file_path,
+        help="Path to the output csv file.",
    )
     os.environ[
         "FEEDS_DATABASE_URL"
     ] = "postgresql://postgres:postgres@localhost:54320/MobilityDatabaseTest"
     args = parser.parse_args()
-    csv_file_path = args.outpath if args.outpath else csv_default_file_path
-    export_csv()
+    export_csv(args.outpath)
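
Alongside the streaming change, the reduce() that picked the latest validation report is rewritten as max() with a key function. A small self-contained sketch of the equivalence follows; Report is a hypothetical stand-in for the real validation report model, and the real code additionally normalizes version strings via extract_numeric_version:

    from dataclasses import dataclass
    from functools import reduce
    from packaging.version import Version

    @dataclass
    class Report:  # hypothetical stand-in for the validation report model
        validator_version: str

    reports = [Report("4.2.0"), Report("5.0.1"), Report("4.9.3")]

    # Before: pairwise reduce with a conditional lambda.
    latest_old = reduce(
        lambda a, b: a
        if Version(a.validator_version) > Version(b.validator_version)
        else b,
        reports,
    )

    # After: max() with a key, parsing one Version per report.
    latest_new = max(reports, key=lambda r: Version(r.validator_version))

    assert latest_old is latest_new  # both select the 5.0.1 report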

functions-python/export_csv/tests/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def populate_database():
     feeds = []
     # We create 3 feeds. The first one is active. The third one is inactive and redirected to the first one.
     # The second one is active but not redirected.
-    # First fill the generic paramaters
+    # First fill the generic parameters
     for i in range(3):
         feed = Gtfsfeed(
             data_type="gtfs",

functions-python/export_csv/tests/test_export_csv_main.py

Lines changed: 6 additions & 19 deletions
@@ -39,27 +39,14 @@ def test_export_csv():
     os.environ[
         "FEEDS_DATABASE_URL"
     ] = "postgresql://postgres:postgres@localhost:54320/MobilityDatabaseTest"
-    data_collector = main.collect_data()
-    print(f"Collected data for {len(data_collector.rows)} feeds.")
 
-    df_extracted = data_collector.get_dataframe()
+    csv_file_path = "./output.csv"
+    main.export_csv(csv_file_path)
+    df_actual = pd.read_csv(csv_file_path)
+    print(f"Collected data for {len(df_actual)} feeds.")
 
-    csv_buffer = io.StringIO(expected_csv)
-    df_from_expected_csv = pd.read_csv(csv_buffer)
-    df_from_expected_csv.fillna("", inplace=True)
-
-    df_extracted.fillna("", inplace=True)
-
-    df_extracted["urls.authentication_type"] = df_extracted[
-        "urls.authentication_type"
-    ].astype(str)
-    df_from_expected_csv["urls.authentication_type"] = df_from_expected_csv[
-        "urls.authentication_type"
-    ].astype(str)
-    df_from_expected_csv["location.bounding_box.extracted_on"] = pd.to_datetime(
-        df_from_expected_csv["location.bounding_box.extracted_on"], utc=True
-    )
+    df_expected = pd.read_csv(io.StringIO(expected_csv))
 
     # try:
-    pdt.assert_frame_equal(df_extracted, df_from_expected_csv)
+    pdt.assert_frame_equal(df_actual, df_expected)
     print("DataFrames are equal.")

scripts/api-tests.sh

Lines changed: 1 addition & 2 deletions
@@ -109,8 +109,7 @@ execute_tests() {
 
     # Generate coverage report
     current_dir_name=$(basename "$(pwd)")
-    mkdir $ABS_SCRIPTPATH/coverage_reports
-    mkdir $ABS_SCRIPTPATH/coverage_reports/$current_dir_name
+    mkdir -p $ABS_SCRIPTPATH/coverage_reports/$current_dir_name
     venv/bin/coverage report > $ABS_SCRIPTPATH/coverage_reports/$current_dir_name/report.txt
     printf "\n${YELLOW}COVERAGE REPORT FOR $1:${NC}\n"
     cat $ABS_SCRIPTPATH/coverage_reports/$current_dir_name/report.txt
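
Note on the script change: mkdir -p creates any missing parent directories and exits successfully when the target already exists, so the two separate mkdir calls (which would error on a rerun once coverage_reports existed) collapse into a single idempotent call.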
