142 changes: 74 additions & 68 deletions functions-python/export_csv/src/main.py
@@ -14,69 +14,63 @@
# limitations under the License.
#
import argparse
import csv
import logging
import os
import re

import pandas as pd
from typing import Dict, Iterator

from dotenv import load_dotenv
import functions_framework

from packaging.version import Version
from functools import reduce
from google.cloud import storage
from geoalchemy2.shape import to_shape

from shared.helpers.logger import Logger
from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gtfsrealtimefeed
from collections import OrderedDict
from shared.common.db_utils import get_all_gtfs_rt_feeds_query, get_all_gtfs_feeds_query

from shared.helpers.database import Database

load_dotenv()
csv_default_file_path = "./output.csv"
csv_file_path = csv_default_file_path


class DataCollector:
"""
A class used to collect and organize data into rows and headers for CSV output.
One particularity of this class is that it uses an OrderedDict to store the data, so that the order of the columns
is preserved when writing to CSV.
"""

def __init__(self):
self.data = OrderedDict()
self.rows = []
self.headers = []

def add_data(self, key, value):
if key not in self.headers:
self.headers.append(key)
self.data[key] = value

def finalize_row(self):
self.rows.append(self.data.copy())
self.data = OrderedDict()

def write_csv_to_file(self, csv_file_path):
df = pd.DataFrame(self.rows, columns=self.headers)
df.to_csv(csv_file_path, index=False)

def get_dataframe(self) -> pd:
return pd.DataFrame(self.rows, columns=self.headers)
# This needs to be updated if we add fields to either `get_feed_csv_data` or
# `get_gtfs_rt_feed_csv_data`, otherwise csv.DictWriter will raise a ValueError
# for the unexpected field(s) when writing the CSV file.
headers = [
"id",
"data_type",
"entity_type",
"location.country_code",
"location.subdivision_name",
"location.municipality",
"provider",
"name",
"note",
"feed_contact_email",
"static_reference",
"urls.direct_download",
"urls.authentication_type",
"urls.authentication_info",
"urls.api_key_parameter_name",
"urls.latest",
"urls.license",
"location.bounding_box.minimum_latitude",
"location.bounding_box.maximum_latitude",
"location.bounding_box.minimum_longitude",
"location.bounding_box.maximum_longitude",
"location.bounding_box.extracted_on",
"status",
Contributor Author:

@jcpitre I noticed we don't set the status for realtime feeds, is that intentional?

Member:

This was not intentional; the realtime feeds should also have the status column populated.

Member:

I added a specific issue to address this, as we are missing a few other column values. #906

A hedged sketch of one way to populate status for realtime rows follows the headers list below.

"features",
"redirect.id",
"redirect.comment",
]
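
Following up on the review thread above: a minimal sketch of how the realtime rows could carry the missing status value, assuming Gtfsrealtimefeed exposes a `status` attribute the same way Gtfsfeed does (to be confirmed as part of #906). This is illustrative only and not part of this PR.

# Sketch only (not in this PR): populate "status" for realtime rows as well,
# assuming Gtfsrealtimefeed has a `status` attribute like Gtfsfeed.
def get_gtfs_rt_feed_csv_data_with_status(feed: Gtfsrealtimefeed) -> Dict:
    data = get_gtfs_rt_feed_csv_data(feed)
    data["status"] = getattr(feed, "status", None)  # None if the attribute is absent
    return data

fetch_feeds() could then yield get_gtfs_rt_feed_csv_data_with_status(feed) instead of the current call.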


@functions_framework.http
def export_and_upload_csv(request=None):
response = export_csv()
upload_file_to_storage(csv_file_path, "sources_v2.csv")
return response


def export_csv():
"""
HTTP function entry point. Reads the DB and outputs a CSV file with feed data.
This function requires the following environment variables to be set:
@@ -85,16 +79,36 @@ def export_csv():
:return: HTTP response object
"""
Logger.init_logger()
logging.info("Function Started")
data_collector = collect_data()
data_collector.write_csv_to_file(csv_file_path)
return f"Exported {len(data_collector.rows)} feeds to CSV file {csv_file_path}."
logging.info("Export started")

csv_file_path = csv_default_file_path
export_csv(csv_file_path)
upload_file_to_storage(csv_file_path, "sources_v2.csv")

logging.info("Export successful")
return "Export successful"


def collect_data() -> DataCollector:
def export_csv(csv_file_path: str):
"""
Collect data from the DB and write the output to a DataCollector.
:return: A filled DataCollector
Write feed data to a local CSV file.
"""
with open(csv_file_path, "w") as out:
writer = csv.DictWriter(out, fieldnames=headers)
writer.writeheader()

count = 0
for feed in fetch_feeds():
writer.writerow(feed)
count += 1

logging.info(f"Exported {count} feeds to CSV file {csv_file_path}.")


def fetch_feeds() -> Iterator[Dict]:
"""
Fetch and return feed data from the DB.
:return: Data to write to the output CSV file.
"""
db = Database(database_url=os.getenv("FEEDS_DATABASE_URL"))
logging.info(f"Using database {db.database_url}")
@@ -118,28 +132,19 @@ def collect_data() -> DataCollector:

logging.info(f"Retrieved {len(gtfs_rt_feeds)} GTFS realtime feeds.")

data_collector = DataCollector()

for feed in gtfs_feeds:
data = get_feed_csv_data(feed)
yield get_feed_csv_data(feed)

for key, value in data.items():
data_collector.add_data(key, value)
data_collector.finalize_row()
logging.info(f"Processed {len(gtfs_feeds)} GTFS feeds.")

for feed in gtfs_rt_feeds:
data = get_gtfs_rt_feed_csv_data(feed)
for key, value in data.items():
data_collector.add_data(key, value)
data_collector.finalize_row()
yield get_gtfs_rt_feed_csv_data(feed)

logging.info(f"Processed {len(gtfs_rt_feeds)} GTFS realtime feeds.")

except Exception as error:
logging.error(f"Error retrieving feeds: {error}")
raise Exception(f"Error retrieving feeds: {error}")
data_collector.write_csv_to_file(csv_file_path)
return data_collector


def extract_numeric_version(version):
@@ -166,12 +171,9 @@ def get_feed_csv_data(feed: Gtfsfeed):

if latest_dataset and latest_dataset.validation_reports:
# Keep the report from the more recent validator version
latest_report = reduce(
lambda a, b: a
if Version(extract_numeric_version(a.validator_version))
> Version(extract_numeric_version(b.validator_version))
else b,
latest_report = max(
latest_dataset.validation_reports,
key=lambda r: Version(extract_numeric_version(r.validator_version)),
)

if latest_report:
@@ -234,8 +236,8 @@ def get_feed_csv_data(feed: Gtfsfeed):
"location.bounding_box.maximum_latitude": maximum_latitude,
"location.bounding_box.minimum_longitude": minimum_longitude,
"location.bounding_box.maximum_longitude": maximum_longitude,
"location.bounding_box.extracted_on": validated_at,
# We use the report validated_at date as the extracted_on date
"location.bounding_box.extracted_on": validated_at,
"status": feed.status,
"features": joined_features,
}
@@ -348,13 +350,17 @@ def upload_file_to_storage(source_file_path, target_path):


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Export DB feed contents to csv.")
parser = argparse.ArgumentParser(
description="Export DB feed contents to csv.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--outpath", help="Path to the output csv file. Default is ./output.csv"
"--outpath",
default=csv_default_file_path,
help="Path to the output csv file.",
)
os.environ[
"FEEDS_DATABASE_URL"
] = "postgresql://postgres:postgres@localhost:54320/MobilityDatabaseTest"
args = parser.parse_args()
csv_file_path = args.outpath if args.outpath else csv_default_file_path
export_csv()
export_csv(args.outpath)
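
A brief usage sketch for running the export locally: the __main__ block above hard-codes FEEDS_DATABASE_URL to the local test database, so only the output path needs to be supplied (the path below is an example, not a requirement).

# Example local invocation (output path is arbitrary):
#
#   python functions-python/export_csv/src/main.py --outpath ./sources_v2.csv
#
# The script sets FEEDS_DATABASE_URL to the local test database before parsing
# arguments, then calls export_csv(args.outpath) to write the CSV.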
2 changes: 1 addition & 1 deletion functions-python/export_csv/tests/conftest.py
@@ -51,7 +51,7 @@ def populate_database():
feeds = []
# We create 3 feeds. The first one is active. The third one is inactive and redirected to the first one.
# The second one is active but not redirected.
# First fill the generic paramaters
# First fill the generic parameters
for i in range(3):
feed = Gtfsfeed(
data_type="gtfs",
25 changes: 6 additions & 19 deletions functions-python/export_csv/tests/test_export_csv_main.py
@@ -39,27 +39,14 @@ def test_export_csv():
os.environ[
"FEEDS_DATABASE_URL"
] = "postgresql://postgres:postgres@localhost:54320/MobilityDatabaseTest"
data_collector = main.collect_data()
print(f"Collected data for {len(data_collector.rows)} feeds.")

df_extracted = data_collector.get_dataframe()
csv_file_path = "./output.csv"
main.export_csv(csv_file_path)
df_actual = pd.read_csv(csv_file_path)
print(f"Collected data for {len(df_actual)} feeds.")

csv_buffer = io.StringIO(expected_csv)
df_from_expected_csv = pd.read_csv(csv_buffer)
df_from_expected_csv.fillna("", inplace=True)

df_extracted.fillna("", inplace=True)

df_extracted["urls.authentication_type"] = df_extracted[
"urls.authentication_type"
].astype(str)
df_from_expected_csv["urls.authentication_type"] = df_from_expected_csv[
"urls.authentication_type"
].astype(str)
df_from_expected_csv["location.bounding_box.extracted_on"] = pd.to_datetime(
df_from_expected_csv["location.bounding_box.extracted_on"], utc=True
)
df_expected = pd.read_csv(io.StringIO(expected_csv))

pdt.assert_frame_equal(df_extracted, df_from_expected_csv)
pdt.assert_frame_equal(df_actual, df_expected)
print("DataFrames are equal.")
3 changes: 1 addition & 2 deletions scripts/api-tests.sh
@@ -109,8 +109,7 @@ execute_tests() {

# Generate coverage report
current_dir_name=$(basename "$(pwd)")
mkdir $ABS_SCRIPTPATH/coverage_reports
mkdir $ABS_SCRIPTPATH/coverage_reports/$current_dir_name
mkdir -p $ABS_SCRIPTPATH/coverage_reports/$current_dir_name
venv/bin/coverage report > $ABS_SCRIPTPATH/coverage_reports/$current_dir_name/report.txt
printf "\n${YELLOW}COVERAGE REPORT FOR $1:${NC}\n"
cat $ABS_SCRIPTPATH/coverage_reports/$current_dir_name/report.txt