 # limitations under the License.
 #
 import argparse
+import csv
 import logging
 import os
 import re
-
-import pandas as pd
+from typing import Dict, Iterator
 
 from dotenv import load_dotenv
 import functions_framework
 
 from packaging.version import Version
-from functools import reduce
 from google.cloud import storage
 from geoalchemy2.shape import to_shape
 
 from shared.helpers.logger import Logger
 from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gtfsrealtimefeed
-from collections import OrderedDict
 from shared.common.db_utils import get_all_gtfs_rt_feeds_query, get_all_gtfs_feeds_query
 
 from shared.helpers.database import Database
 
 load_dotenv()
 csv_default_file_path = "./output.csv"
-csv_file_path = csv_default_file_path
-
-
-class DataCollector:
-    """
-    A class used to collect and organize data into rows and headers for CSV output.
-    One particularity of this class is that it uses an OrderedDict to store the data, so that the order of the columns
-    is preserved when writing to CSV.
-    """
-
-    def __init__(self):
-        self.data = OrderedDict()
-        self.rows = []
-        self.headers = []
-
-    def add_data(self, key, value):
-        if key not in self.headers:
-            self.headers.append(key)
-        self.data[key] = value
 
-    def finalize_row(self):
-        self.rows.append(self.data.copy())
-        self.data = OrderedDict()
-
-    def write_csv_to_file(self, csv_file_path):
-        df = pd.DataFrame(self.rows, columns=self.headers)
-        df.to_csv(csv_file_path, index=False)
-
-    def get_dataframe(self) -> pd:
-        return pd.DataFrame(self.rows, columns=self.headers)
+# This needs to be updated if we add fields to either `get_feed_csv_data` or
+# `get_gtfs_rt_feed_csv_data`, otherwise csv.DictWriter will raise a ValueError
+# when it encounters the extra key(s).
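+# The order of this list is the column order of the generated CSV.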
+headers = [
+    "id",
+    "data_type",
+    "entity_type",
+    "location.country_code",
+    "location.subdivision_name",
+    "location.municipality",
+    "provider",
+    "name",
+    "note",
+    "feed_contact_email",
+    "static_reference",
+    "urls.direct_download",
+    "urls.authentication_type",
+    "urls.authentication_info",
+    "urls.api_key_parameter_name",
+    "urls.latest",
+    "urls.license",
+    "location.bounding_box.minimum_latitude",
+    "location.bounding_box.maximum_latitude",
+    "location.bounding_box.minimum_longitude",
+    "location.bounding_box.maximum_longitude",
+    "location.bounding_box.extracted_on",
+    "status",
+    "features",
+    "redirect.id",
+    "redirect.comment",
+]
 
 
 @functions_framework.http
 def export_and_upload_csv(request=None):
-    response = export_csv()
-    upload_file_to_storage(csv_file_path, "sources_v2.csv")
-    return response
-
-
-def export_csv():
     """
     HTTP Function entry point. Reads the DB and outputs a csv file with feeds data.
     This function requires the following environment variables to be set:
@@ -85,16 +79,36 @@ def export_csv():
     :return: HTTP response object
     """
     Logger.init_logger()
-    logging.info("Function Started")
-    data_collector = collect_data()
-    data_collector.write_csv_to_file(csv_file_path)
-    return f"Exported {len(data_collector.rows)} feeds to CSV file {csv_file_path}."
+    logging.info("Export started")
+
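+    # Write the CSV to a local path first, then upload the finished file to Cloud Storage.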
+    csv_file_path = csv_default_file_path
+    export_csv(csv_file_path)
+    upload_file_to_storage(csv_file_path, "sources_v2.csv")
+
+    logging.info("Export successful")
+    return "Export successful"
 
 
-def collect_data() -> DataCollector:
+def export_csv(csv_file_path: str):
     """
-    Collect data from the DB and write the output to a DataCollector.
-    :return: A filled DataCollector
+    Write feed data to a local CSV file.
+    """
+    with open(csv_file_path, "w", newline="") as out:
+        writer = csv.DictWriter(out, fieldnames=headers)
+        writer.writeheader()
+
+        count = 0
+        for feed in fetch_feeds():
+            writer.writerow(feed)
+            count += 1
+
+    logging.info(f"Exported {count} feeds to CSV file {csv_file_path}.")
+
+
+def fetch_feeds() -> Iterator[Dict]:
+    """
+    Fetch and return feed data from the DB.
+    :return: Data to write to the output CSV file.
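+    Yields GTFS feeds first, then GTFS realtime feeds, one dict per feed.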
98112 """
99113 db = Database (database_url = os .getenv ("FEEDS_DATABASE_URL" ))
100114 logging .info (f"Using database { db .database_url } " )
@@ -118,28 +132,19 @@ def collect_data() -> DataCollector:
118132
119133 logging .info (f"Retrieved { len (gtfs_rt_feeds )} GTFS realtime feeds." )
120134
121- data_collector = DataCollector ()
122-
123135 for feed in gtfs_feeds :
124- data = get_feed_csv_data (feed )
136+ yield get_feed_csv_data (feed )
125137
126- for key , value in data .items ():
127- data_collector .add_data (key , value )
128- data_collector .finalize_row ()
129138 logging .info (f"Processed { len (gtfs_feeds )} GTFS feeds." )
130139
131140 for feed in gtfs_rt_feeds :
132- data = get_gtfs_rt_feed_csv_data (feed )
133- for key , value in data .items ():
134- data_collector .add_data (key , value )
135- data_collector .finalize_row ()
141+ yield get_gtfs_rt_feed_csv_data (feed )
142+
136143 logging .info (f"Processed { len (gtfs_rt_feeds )} GTFS realtime feeds." )
137144
138145 except Exception as error :
139146 logging .error (f"Error retrieving feeds: { error } " )
140147 raise Exception (f"Error retrieving feeds: { error } " )
141- data_collector .write_csv_to_file (csv_file_path )
142- return data_collector
143148
144149
145150def extract_numeric_version (version ):
@@ -166,12 +171,9 @@ def get_feed_csv_data(feed: Gtfsfeed):
166171
167172 if latest_dataset and latest_dataset .validation_reports :
         # Keep the report from the most recent validator version
-        latest_report = reduce(
-            lambda a, b: a
-            if Version(extract_numeric_version(a.validator_version))
-            > Version(extract_numeric_version(b.validator_version))
-            else b,
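+        # Version() compares release numbers numerically, so e.g. "1.10.0" ranks
+        # above "1.9.0", which a plain string comparison would get wrong.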
+        latest_report = max(
             latest_dataset.validation_reports,
+            key=lambda r: Version(extract_numeric_version(r.validator_version)),
         )
 
         if latest_report:
@@ -234,8 +236,8 @@ def get_feed_csv_data(feed: Gtfsfeed):
         "location.bounding_box.maximum_latitude": maximum_latitude,
         "location.bounding_box.minimum_longitude": minimum_longitude,
         "location.bounding_box.maximum_longitude": maximum_longitude,
-        "location.bounding_box.extracted_on": validated_at,
         # We use the report validated_at date as the extracted_on date
+        "location.bounding_box.extracted_on": validated_at,
         "status": feed.status,
         "features": joined_features,
     }
@@ -348,13 +350,17 @@ def upload_file_to_storage(source_file_path, target_path):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Export DB feed contents to csv.")
+    parser = argparse.ArgumentParser(
+        description="Export DB feed contents to csv.",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
     parser.add_argument(
-        "--outpath", help="Path to the output csv file. Default is ./output.csv"
+        "--outpath",
+        default=csv_default_file_path,
+        help="Path to the output csv file.",
    )
     os.environ[
         "FEEDS_DATABASE_URL"
     ] = "postgresql://postgres:postgres@localhost:54320/MobilityDatabaseTest"
     args = parser.parse_args()
-    csv_file_path = args.outpath if args.outpath else csv_default_file_path
-    export_csv()
+    export_csv(args.outpath)