Skip to content

Commit 9003687

Browse files
committed
test: added gbfs data and gbfs_feeds endpoint test
1 parent f6a41ae commit 9003687

File tree

14 files changed

+285
-219
lines changed

14 files changed

+285
-219
lines changed

api/src/feeds/impl/models/search_feed_item_result_impl.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,29 @@ def from_orm_gtfs(cls, feed_search_row):
5252
else None,
5353
)
5454

55+
@classmethod
def from_orm_gbfs(cls, feed_search_row):
    """Create a model instance from a SQLAlchemy GBFS feed-search row object.

    Maps the flat columns of the search row onto the model, nesting the
    authentication/licensing columns inside a SourceInfo object.
    """
    return cls(
        id=feed_search_row.feed_stable_id,
        data_type=feed_search_row.data_type,
        status=feed_search_row.status,
        external_ids=feed_search_row.external_ids,
        provider=feed_search_row.provider,
        feed_contact_email=feed_search_row.feed_contact_email,
        source_info=SourceInfo(
            producer_url=feed_search_row.producer_url,
            # NOTE(review): a falsy authentication_type (None, "", 0) maps to
            # None instead of being int-converted — confirm "0" vs 0 semantics.
            authentication_type=int(feed_search_row.authentication_type)
            if feed_search_row.authentication_type
            else None,
            authentication_info_url=feed_search_row.authentication_info_url,
            api_key_parameter_name=feed_search_row.api_key_parameter_name,
            license_url=feed_search_row.license_url,
        ),
        redirects=feed_search_row.redirect_ids,
        locations=cls.resolve_locations(feed_search_row.locations),
    )
77+
5578
@classmethod
5679
def from_orm_gtfs_rt(cls, feed_search_row):
5780
"""Create a model instance from a SQLAlchemy a GTFS-RT row object."""
@@ -115,5 +138,7 @@ def from_orm(cls, feed_search_row):
115138
return cls.from_orm_gtfs(feed_search_row)
116139
case "gtfs_rt":
117140
return cls.from_orm_gtfs_rt(feed_search_row)
141+
case "gbfs":
142+
return cls.from_orm_gbfs(feed_search_row)
118143
case _:
119144
raise ValueError(f"Unknown data type: {feed_search_row.data_type}")

api/src/feeds/impl/search_api_impl.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def add_search_query_filters(query, search_query, data_type, feed_id, status, is
4646
if feed_id:
4747
query = query.where(t_feedsearch.c.feed_stable_id == feed_id.strip().lower())
4848
if data_type:
49-
query = query.where(t_feedsearch.c.data_type == data_type.strip().lower())
49+
data_types = [dt.strip().lower() for dt in data_type.split(",")]
50+
if data_types:
51+
query = query.where(t_feedsearch.c.data_type.in_(data_types))
5052
if status:
5153
status_list = [s.strip().lower() for s in status[0].split(",") if s]
5254
if status_list:

api/src/scripts/gbfs_utils/comparison.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ def generate_system_csv_from_db(df, db_session):
2929
"Supported Versions": " ; ".join(supported_versions),
3030
}
3131
)
32+
if not data:
33+
# Return an empty DataFrame with the same columns
34+
return pd.DataFrame(columns=df.columns)
3235
return pd.DataFrame(data)
3336

3437

api/src/scripts/populate_db.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
import argparse
22
import logging
33
import os
4+
import traceback
45
from pathlib import Path
56
from typing import Type, TYPE_CHECKING
67

78
import pandas
89
from dotenv import load_dotenv
10+
from sqlalchemy import text
911

12+
from shared.common.logging_utils import Logger
1013
from shared.database.database import Database
14+
from shared.database.database import configure_polymorphic_mappers
1115
from shared.database_gen.sqlacodegen_models import Feed, Gtfsrealtimefeed, Gtfsfeed, Gbfsfeed
12-
from shared.common.logging_utils import Logger
16+
from shared.database_gen.sqlacodegen_models import (
17+
t_feedsearch,
18+
)
1319

1420
if TYPE_CHECKING:
1521
from sqlalchemy.orm import Session
@@ -117,3 +123,35 @@ def filter_data(self):
117123
Filter the data to only include the necessary columns
118124
"""
119125
pass # Should be implemented in the child class
126+
127+
def populate_db(self, session: "Session", fetch_url: bool = True):
    """Load the source data into the database.

    Base-class hook with no behavior of its own; concrete subclasses
    supply the real implementation.
    """
    return None  # Overridden by child classes.
132+
133+
def trigger_downstream_tasks(self):
    """Kick off follow-up work after the database has been populated.

    Base-class hook with no behavior of its own; subclasses override it
    when there is downstream work to schedule.
    """
    return None  # Overridden by child classes.
138+
139+
# Extracted the following code from main, so it can be executed as a library function
def initialize(self, trigger_downstream_tasks: bool = True, fetch_url: bool = True):
    """Run the full database-population workflow.

    Calls populate_db() inside a session, commits, refreshes the feed
    search materialized view, and optionally triggers downstream tasks.
    Any failure is logged with a traceback and the process exits with
    status 1 (this is intended to be run as a script/job entry point).

    :param trigger_downstream_tasks: when True, call trigger_downstream_tasks()
        after a successful load.
    :param fetch_url: forwarded to populate_db() to enable/disable network fetches.
    """
    try:
        configure_polymorphic_mappers()
        with self.db.start_db_session() as session:
            self.populate_db(session, fetch_url=fetch_url)
            session.commit()

            self.logger.info("Refreshing MATERIALIZED FEED SEARCH VIEW - Started")
            # CONCURRENTLY lets readers keep using the view during the refresh.
            session.execute(text(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {t_feedsearch.name}"))
            self.logger.info("Refreshing MATERIALIZED FEED SEARCH VIEW - Completed")
            session.commit()
            self.logger.info("\n----- Database populated with sources.csv data. -----")
            if trigger_downstream_tasks:
                self.trigger_downstream_tasks()
    except Exception as e:
        self.logger.error(f"\n------ Failed to populate the database with sources.csv: {e} -----\n")
        traceback.print_exc()
        exit(1)

api/src/scripts/populate_db_gbfs.py

Lines changed: 67 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313

1414

1515
class GBFSDatabasePopulateHelper(DatabasePopulateHelper):
16-
def __init__(self, file_path):
17-
super().__init__(file_path)
16+
def __init__(self, filepaths):
    """Initialize the helper with the GBFS systems CSV filepath(s).

    NOTE(review): this override only delegates to the base class with the
    same argument; it could likely be removed — confirm against the base
    class signature.
    """
    super().__init__(filepaths)
1818

1919
def filter_data(self):
    """Filter out rows that have an Authentication Info URL or a duplicate System ID."""
    # Keep only feeds without an authentication info URL (i.e. openly accessible).
    self.df = self.df[pd.isna(self.df["Authentication Info URL"])]
    # keep=False drops ALL occurrences of a duplicated System ID, not just the extras.
    self.df = self.df[~self.df.duplicated(subset="System ID", keep=False)]
    self.logger.info(f"Data = {self.df}")
2424

@@ -45,77 +45,80 @@ def deprecate_feeds(self, deprecated_feeds):
4545
self.logger.info(f"Deprecating feed with stable_id={stable_id}")
4646
gbfs_feed.status = "deprecated"
4747

48-
def populate_db(self):
48+
def populate_db(self, session, fetch_url=True):
    """Populate the database with the GBFS feeds from the systems CSV.

    Diffs the current database contents against the CSV (self.df),
    deprecates feeds that disappeared from the CSV, and inserts or
    updates the remaining ones inside the given session.

    :param session: SQLAlchemy session used for all reads and writes.
    :param fetch_url: when True, fetch each feed's auto-discovery URL to
        obtain its system_information; when False, no network fetch is
        performed and the fetched data is treated as empty.
    """
    start_time = datetime.now()
    configure_polymorphic_mappers()

    try:
        # Compare the database to the CSV file to determine the delta.
        df_from_db = generate_system_csv_from_db(self.df, session)
        added_or_updated_feeds, deprecated_feeds = compare_db_to_csv(df_from_db, self.df, self.logger)

        self.deprecate_feeds(deprecated_feeds)
        # A None delta means no comparison result — process every CSV row.
        if added_or_updated_feeds is None:
            added_or_updated_feeds = self.df
        for index, row in added_or_updated_feeds.iterrows():
            self.logger.info(f"Processing row {index + 1} of {len(added_or_updated_feeds)}")
            stable_id = self.get_stable_id(row)
            gbfs_feed = self.query_feed_by_stable_id(session, stable_id, "gbfs")
            if fetch_url:
                fetched_data = fetch_data(row["Auto-Discovery URL"], self.logger, ["system_information"])
            else:
                fetched_data = dict()
            # If the feed already exists, update it. Otherwise, create a new feed.
            if gbfs_feed:
                self.logger.info(f"Updating feed {stable_id} - {row['Name']}")
            else:
                feed_id = generate_unique_id()
                self.logger.info(f"Creating new feed for {stable_id} - {row['Name']}")
                gbfs_feed = Gbfsfeed(
                    id=feed_id,
                    data_type="gbfs",
                    stable_id=stable_id,
                    created_at=datetime.now(pytz.utc),
                    operational_status="published",
                )
                gbfs_feed.externalids = [self.get_external_id(feed_id, row["System ID"])]
                session.add(gbfs_feed)

            # NOTE(review): when fetch_url is False, fetched_data is empty, so
            # license_url and feed_contact_email below are derived from None —
            # confirm existing DB values are not meant to be preserved then.
            system_information_content = get_data_content(fetched_data.get("system_information"), self.logger)
            gbfs_feed.license_url = get_license_url(system_information_content, self.logger)
            gbfs_feed.feed_contact_email = (
                system_information_content.get("feed_contact_email") if system_information_content else None
            )
            gbfs_feed.system_id = str(row["System ID"]).strip()
            gbfs_feed.operator = row["Name"]
            gbfs_feed.provider = row["Name"]
            gbfs_feed.operator_url = row["URL"]
            gbfs_feed.producer_url = row["URL"]
            gbfs_feed.auto_discovery_url = row["Auto-Discovery URL"]
            gbfs_feed.updated_at = datetime.now(pytz.utc)

            if not gbfs_feed.locations:  # If locations are empty, create a new location (no overwrite)
                country_code = self.get_safe_value(row, "Country Code", "")
                municipality = self.get_safe_value(row, "Location", "")
                location_id = self.get_location_id(country_code, None, municipality)
                country = pycountry.countries.get(alpha_2=country_code) if country_code else None
                # Reuse an existing Location row when one matches the id.
                location = session.get(Location, location_id) or Location(
                    id=location_id,
                    country_code=country_code,
                    country=country.name if country else None,
                    municipality=municipality,
                )
                gbfs_feed.locations.clear()
                gbfs_feed.locations = [location]

            # Flush this row's changes within the open transaction so
            # constraint violations surface per feed, not at commit time.
            session.flush()
            self.logger.info(80 * "-")

        end_time = datetime.now()
        self.logger.info(f"Time taken: {end_time - start_time} seconds")
    except Exception as e:
        self.logger.error(f"Error populating the database: {e}")
        raise e
118121

119122

120123
if __name__ == "__main__":
    # Script entry point: populate the GBFS feeds and refresh the feed-search
    # view, deliberately skipping downstream tasks.
    GBFSDatabasePopulateHelper(set_up_configs()).initialize(trigger_downstream_tasks=False)

api/src/scripts/populate_db_gtfs.py

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,19 @@
11
import os
2-
import traceback
32
from datetime import datetime
43
from typing import TYPE_CHECKING
54

65
import pycountry
76
import pytz
8-
from sqlalchemy import text
97

108
from scripts.load_dataset_on_create import publish_all
119
from scripts.populate_db import DatabasePopulateHelper, set_up_configs
12-
from shared.database.database import generate_unique_id, configure_polymorphic_mappers
10+
from shared.database.database import generate_unique_id
1311
from shared.database_gen.sqlacodegen_models import (
1412
Entitytype,
1513
Externalid,
1614
Gtfsrealtimefeed,
1715
Location,
1816
Redirectingid,
19-
t_feedsearch,
2017
)
2118
from utils.data_utils import set_up_defaults
2219

@@ -191,7 +188,7 @@ def process_redirects(self, session: "Session"):
191188
# Flush to avoid FK violation
192189
session.flush()
193190

194-
def populate_db(self, session: "Session"):
191+
def populate_db(self, session: "Session", fetch_url: bool = True):
195192
"""
196193
Populate the database with the sources.csv data
197194
"""
@@ -286,26 +283,6 @@ def post_process_locations(self, session: "Session"):
286283
session.commit()
287284
self.logger.info(f"Had to set the country for {set_country_count} locations")
288285

289-
# Extracted the following code from main, so it can be executed as a library function
290-
def initialize(self, trigger_downstream_tasks: bool = True):
291-
try:
292-
configure_polymorphic_mappers()
293-
with self.db.start_db_session() as session:
294-
self.populate_db(session)
295-
session.commit()
296-
297-
self.logger.info("Refreshing MATERIALIZED FEED SEARCH VIEW - Started")
298-
session.execute(text(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {t_feedsearch.name}"))
299-
self.logger.info("Refreshing MATERIALIZED FEED SEARCH VIEW - Completed")
300-
session.commit()
301-
self.logger.info("\n----- Database populated with sources.csv data. -----")
302-
if trigger_downstream_tasks:
303-
self.trigger_downstream_tasks()
304-
except Exception as e:
305-
self.logger.error(f"\n------ Failed to populate the database with sources.csv: {e} -----\n")
306-
traceback.print_exc()
307-
exit(1)
308-
309286

310287
if __name__ == "__main__":
311288
db_helper = GTFSDatabasePopulateHelper(set_up_configs())

api/src/scripts/populate_db_test_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def populate_test_datasets(self, filepath, db_session: "Session"):
4747
"""
4848
Populate the database with the test datasets
4949
"""
50+
# TODO: parse GBFS versions
5051
# Load the JSON file
5152
with open(filepath) as f:
5253
data = json.load(f)

0 commit comments

Comments
 (0)