 # limitations under the License.
 #
 import argparse
-import pandas as pd
 import os
+import re
+
+import pandas as pd
 
 from dotenv import load_dotenv
 import functions_framework
 
 from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gtfsrealtimefeed
 from collections import OrderedDict
-from shared.common.db_utils import get_gtfs_feeds_query, get_gtfs_rt_feeds_query
+from shared.common.db_utils import get_all_gtfs_rt_feeds_query, get_all_gtfs_feeds_query
 
 from shared.helpers.database import Database
 
@@ -57,10 +59,13 @@ def finalize_row(self):
         self.rows.append(self.data.copy())
         self.data = OrderedDict()
 
-    def write_csv(self, csv_file_path):
+    def write_csv_to_file(self, csv_file_path):
         df = pd.DataFrame(self.rows, columns=self.headers)
         df.to_csv(csv_file_path, index=False)
 
+    def get_dataframe(self) -> pd.DataFrame:
+        return pd.DataFrame(self.rows, columns=self.headers)
+
 
 @functions_framework.http
 def export_csv(request=None):
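The new `get_dataframe` accessor makes the collector usable in memory as well as on disk. A minimal usage sketch, assuming a no-argument constructor and that `self.headers` is populated elsewhere in the class (neither is visible in this diff):

```python
# Usage sketch; the "id" key and its value are illustrative only.
collector = DataCollector()
collector.data["id"] = "mdb-1"   # stage one field of the current row
collector.finalize_row()         # snapshot the row and reset the buffer
collector.write_csv_to_file("/tmp/feeds.csv")  # persist to disk
df = collector.get_dataframe()   # or keep the same rows as a DataFrame
```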
@@ -71,21 +76,20 @@ def export_csv(request=None):
     :param request: HTTP request object
     :return: HTTP response object
     """
+    data_collector = collect_data()
+    data_collector.write_csv_to_file(csv_file_path)
+    return f"Exported database feeds to CSV file {csv_file_path}."
+
+
+def collect_data() -> DataCollector:
+    """
+    Collect feed data from the database into a DataCollector.
+    :return: A filled DataCollector
+    """
     db = Database(database_url=os.getenv("FEEDS_DATABASE_URL"))
     try:
         with db.start_db_session() as session:
-            gtfs_feeds_query = get_gtfs_feeds_query(
-                limit=None,
-                offset=0,
-                provider=None,
-                producer_url=None,
-                country_code=None,
-                subdivision_name=None,
-                municipality=None,
-                dataset_latitudes=None,
-                dataset_longitudes=None,
-                bounding_filter_method=None,
-                is_official=None,
+            gtfs_feeds_query = get_all_gtfs_feeds_query(
                 include_wip=False,
                 db_session=session,
             )
@@ -94,16 +98,7 @@ def export_csv(request=None):
 
             print(f"Retrieved {len(gtfs_feeds)} GTFS feeds.")
 
-            gtfs_rt_feeds_query = get_gtfs_rt_feeds_query(
-                limit=None,
-                offset=0,
-                provider=None,
-                producer_url=None,
-                entity_types=None,
-                country_code=None,
-                subdivision_name=None,
-                municipality=None,
-                is_official=None,
+            gtfs_rt_feeds_query = get_all_gtfs_rt_feeds_query(
                 include_wip=False,
                 db_session=session,
             )
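The removed call sites show every filter argument being passed as `None`, so the new `get_all_*` helpers in `shared/common/db_utils.py` presumably just package up that unfiltered case. A sketch of the likely shape, grounded only in the call site removed above; the real signature may differ:

```python
# Hypothetical wrapper, not the actual db_utils implementation.
def get_all_gtfs_feeds_query(include_wip: bool, db_session):
    return get_gtfs_feeds_query(
        limit=None,
        offset=0,
        provider=None,
        producer_url=None,
        country_code=None,
        subdivision_name=None,
        municipality=None,
        dataset_latitudes=None,
        dataset_longitudes=None,
        bounding_filter_method=None,
        is_official=None,
        include_wip=include_wip,
        db_session=db_session,
    )
```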
@@ -134,11 +129,13 @@ def export_csv(request=None):
     except Exception as error:
         print(f"Error retrieving feeds: {error}")
         raise Exception(f"Error retrieving feeds: {error}")
+    return data_collector
 
-    data_collector.write_csv(csv_file_path)
 
-    print(f"Wrote {len(gtfs_feeds)} feeds to {csv_file_path}.")
-    return f"Wrote {len(gtfs_feeds)} feeds to {csv_file_path}."
+def extract_numeric_version(version):
+    match = re.match(r"(\d+\.\d+\.\d+)", version)
+    return match.group(1) if match else version
 
 
 def get_feed_csv_data(feed: Gtfsfeed):
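`extract_numeric_version` exists because PEP 440 version parsers only accept well-formed version strings; a validator version carrying a trailing qualifier would otherwise fail to parse (recent releases of the `packaging` library raise `InvalidVersion`). A quick behavior sketch, assuming `Version` here is `packaging.version.Version`; the `-SNAPSHOT` suffix is an assumed example, not taken from real data:

```python
from packaging.version import Version

# Assumed inputs for illustration only.
assert extract_numeric_version("5.0.1-SNAPSHOT") == "5.0.1"  # suffix stripped
assert extract_numeric_version("4.2.0") == "4.2.0"           # already numeric
assert extract_numeric_version("beta") == "beta"             # no match: returned as-is

# The stripped form parses cleanly, so comparisons like the reduce() below work.
assert Version("5.0.1") > Version("4.2.0")
```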
@@ -162,15 +159,19 @@ def get_feed_csv_data(feed: Gtfsfeed):
         # Keep the report from the most recent validator version
         latest_report = reduce(
             lambda a, b: a
-            if Version(a.validator_version) > Version(b.validator_version)
+            if Version(extract_numeric_version(a.validator_version))
+            > Version(extract_numeric_version(b.validator_version))
             else b,
             latest_dataset.validation_reports,
         )
+
         if latest_report:
             if latest_report.features:
                 features = latest_report.features
                 joined_features = (
-                    "|".join(feature.name for feature in features if feature.name)
+                    "|".join(
+                        sorted(feature.name for feature in features if feature.name)
+                    )
                     if features
                     else ""
                 )
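Two readability notes on this hunk. The `reduce` is equivalent, up to tie-breaking on equal versions, to a `max` with a key, and `sorted` makes the exported feature list deterministic regardless of database row order. A sketch, again assuming `Version` is `packaging.version.Version`:

```python
from packaging.version import Version

# Equivalent selection of the latest report; max() keeps the first of any
# tied versions, whereas the reduce() above keeps the last.
latest_report = max(
    latest_dataset.validation_reports,
    key=lambda r: Version(extract_numeric_version(r.validator_version)),
)

# sorted() pins the column value so repeated exports diff cleanly.
features = ["Route Colors", "Fares V2"]  # illustrative names only
assert "|".join(sorted(features)) == "Fares V2|Route Colors"
```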
@@ -185,7 +186,7 @@ def get_feed_csv_data(feed: Gtfsfeed):
         maximum_longitude = shape.bounds[2]
 
     data = {
-        "mdb_source_id": feed.stable_id,
+        "id": feed.stable_id,
         "data_type": feed.data_type,
         "entity_type": None,
         "location.country_code": ""
@@ -262,6 +263,7 @@ def get_gtfs_rt_feed_csv_data(feed: Gtfsrealtimefeed):
         for entity_type in feed.entitytypes
         if entity_type and entity_type.name
     ]
+    valid_entity_types = sorted(valid_entity_types)
     entity_types = "|".join(valid_entity_types)
 
     static_references = ""
@@ -274,7 +276,7 @@ def get_gtfs_rt_feed_csv_data(feed: Gtfsrealtimefeed):
         static_references = "|".join(valid_feed_references)
 
     data = {
-        "mdb_source_id": feed.stable_id,
+        "id": feed.stable_id,
         "data_type": feed.data_type,
         "entity_type": entity_types,
         "location.country_code": ""