Skip to content

Commit 92f4547

Browse files
authored
Merge pull request #862 from MobilityData/psycopg2-fix
fix: Psycopg2 fix
2 parents 2aec616 + cf2a50a commit 92f4547

File tree

2 files changed

+83
-75
lines changed

2 files changed

+83
-75
lines changed

api/src/scripts/populate_db_gbfs.py

Lines changed: 82 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -45,88 +45,96 @@ def deprecate_feeds(self, deprecated_feeds):
4545
if gbfs_feed:
4646
self.logger.info(f"Deprecating feed with stable_id={stable_id}")
4747
gbfs_feed.status = "deprecated"
48-
session.flush()
48+
# session.flush()
4949

5050
def populate_db(self):
5151
"""Populate the database with the GBFS feeds"""
5252
start_time = datetime.now()
5353
configure_polymorphic_mappers()
5454

55-
# Compare the database to the CSV file
56-
df_from_db = generate_system_csv_from_db(self.df, self.db.session)
57-
added_or_updated_feeds, deprecated_feeds = compare_db_to_csv(df_from_db, self.df, self.logger)
58-
59-
self.deprecate_feeds(deprecated_feeds)
60-
if added_or_updated_feeds is None:
61-
added_or_updated_feeds = self.df
62-
for index, row in added_or_updated_feeds.iterrows():
63-
self.logger.info(f"Processing row {index + 1} of {len(added_or_updated_feeds)}")
64-
stable_id = self.get_stable_id(row)
65-
gbfs_feed = self.query_feed_by_stable_id(stable_id, "gbfs")
66-
fetched_data = fetch_data(
67-
row["Auto-Discovery URL"], self.logger, ["system_information", "gbfs_versions"], ["version"]
68-
)
69-
# If the feed already exists, update it. Otherwise, create a new feed.
70-
if gbfs_feed:
71-
feed_id = gbfs_feed.id
72-
self.logger.info(f"Updating feed {stable_id} - {row['Name']}")
73-
else:
74-
feed_id = generate_unique_id()
75-
self.logger.info(f"Creating new feed for {stable_id} - {row['Name']}")
76-
gbfs_feed = Gbfsfeed(
77-
id=feed_id,
78-
data_type="gbfs",
79-
stable_id=stable_id,
80-
created_at=datetime.now(pytz.utc),
81-
)
82-
gbfs_feed.externalids = [self.get_external_id(feed_id, row["System ID"])]
83-
self.db.session.add(gbfs_feed)
84-
85-
system_information_content = get_data_content(fetched_data.get("system_information"), self.logger)
86-
gbfs_feed.license_url = get_license_url(system_information_content, self.logger)
87-
gbfs_feed.feed_contact_email = (
88-
system_information_content.get("feed_contact_email") if system_information_content else None
89-
)
90-
gbfs_feed.operator = row["Name"]
91-
gbfs_feed.operator_url = row["URL"]
92-
gbfs_feed.auto_discovery_url = row["Auto-Discovery URL"]
93-
gbfs_feed.updated_at = datetime.now(pytz.utc)
94-
95-
country_code = self.get_safe_value(row, "Country Code", "")
96-
municipality = self.get_safe_value(row, "Location", "")
97-
location_id = self.get_location_id(country_code, None, municipality)
98-
country = pycountry.countries.get(alpha_2=country_code) if country_code else None
99-
location = self.db.session.get(Location, location_id) or Location(
100-
id=location_id,
101-
country_code=country_code,
102-
country=country.name if country else None,
103-
municipality=municipality,
104-
)
105-
gbfs_feed.locations.clear()
106-
gbfs_feed.locations = [location]
107-
108-
# Add the GBFS versions
109-
versions = get_gbfs_versions(
110-
fetched_data.get("gbfs_versions"), row["Auto-Discovery URL"], fetched_data.get("version"), self.logger
111-
)
112-
existing_versions = [version.version for version in gbfs_feed.gbfsversions]
113-
for version in versions:
114-
version_value = version.get("version")
115-
if version_value.upper() in OFFICIAL_VERSIONS and version_value not in existing_versions:
116-
gbfs_feed.gbfsversions.append(
117-
Gbfsversion(
118-
feed_id=feed_id,
119-
url=version.get("url"),
120-
version=version_value,
121-
)
55+
try:
56+
with self.db.start_db_session() as session:
57+
# Compare the database to the CSV file
58+
df_from_db = generate_system_csv_from_db(self.df, session)
59+
added_or_updated_feeds, deprecated_feeds = compare_db_to_csv(df_from_db, self.df, self.logger)
60+
61+
self.deprecate_feeds(deprecated_feeds)
62+
if added_or_updated_feeds is None:
63+
added_or_updated_feeds = self.df
64+
for index, row in added_or_updated_feeds.iterrows():
65+
self.logger.info(f"Processing row {index + 1} of {len(added_or_updated_feeds)}")
66+
stable_id = self.get_stable_id(row)
67+
gbfs_feed = self.query_feed_by_stable_id(session, stable_id, "gbfs")
68+
fetched_data = fetch_data(
69+
row["Auto-Discovery URL"], self.logger, ["system_information", "gbfs_versions"], ["version"]
12270
)
71+
# If the feed already exists, update it. Otherwise, create a new feed.
72+
if gbfs_feed:
73+
feed_id = gbfs_feed.id
74+
self.logger.info(f"Updating feed {stable_id} - {row['Name']}")
75+
else:
76+
feed_id = generate_unique_id()
77+
self.logger.info(f"Creating new feed for {stable_id} - {row['Name']}")
78+
gbfs_feed = Gbfsfeed(
79+
id=feed_id,
80+
data_type="gbfs",
81+
stable_id=stable_id,
82+
created_at=datetime.now(pytz.utc),
83+
)
84+
gbfs_feed.externalids = [self.get_external_id(feed_id, row["System ID"])]
85+
session.add(gbfs_feed)
12386

124-
self.db.session.flush()
125-
self.logger.info(80 * "-")
126-
127-
self.db.session.commit()
128-
end_time = datetime.now()
129-
self.logger.info(f"Time taken: {end_time - start_time} seconds")
87+
system_information_content = get_data_content(fetched_data.get("system_information"), self.logger)
88+
gbfs_feed.license_url = get_license_url(system_information_content, self.logger)
89+
gbfs_feed.feed_contact_email = (
90+
system_information_content.get("feed_contact_email") if system_information_content else None
91+
)
92+
gbfs_feed.operator = row["Name"]
93+
gbfs_feed.operator_url = row["URL"]
94+
gbfs_feed.auto_discovery_url = row["Auto-Discovery URL"]
95+
gbfs_feed.updated_at = datetime.now(pytz.utc)
96+
97+
country_code = self.get_safe_value(row, "Country Code", "")
98+
municipality = self.get_safe_value(row, "Location", "")
99+
location_id = self.get_location_id(country_code, None, municipality)
100+
country = pycountry.countries.get(alpha_2=country_code) if country_code else None
101+
location = session.get(Location, location_id) or Location(
102+
id=location_id,
103+
country_code=country_code,
104+
country=country.name if country else None,
105+
municipality=municipality,
106+
)
107+
gbfs_feed.locations.clear()
108+
gbfs_feed.locations = [location]
109+
110+
# Add the GBFS versions
111+
versions = get_gbfs_versions(
112+
fetched_data.get("gbfs_versions"),
113+
row["Auto-Discovery URL"],
114+
fetched_data.get("version"),
115+
self.logger,
116+
)
117+
existing_versions = [version.version for version in gbfs_feed.gbfsversions]
118+
for version in versions:
119+
version_value = version.get("version")
120+
if version_value.upper() in OFFICIAL_VERSIONS and version_value not in existing_versions:
121+
gbfs_feed.gbfsversions.append(
122+
Gbfsversion(
123+
feed_id=feed_id,
124+
url=version.get("url"),
125+
version=version_value,
126+
)
127+
)
128+
129+
# self.db.session.flush()
130+
self.logger.info(80 * "-")
131+
132+
# self.db.session.commit()
133+
end_time = datetime.now()
134+
self.logger.info(f"Time taken: {end_time - start_time} seconds")
135+
except Exception as e:
136+
self.logger.error(f"Error populating the database: {e}")
137+
raise e
130138

131139

132140
if __name__ == "__main__":

infra/postgresql/main.tf

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ resource "google_secret_manager_secret" "secret_db_url" {
106106

107107
resource "google_secret_manager_secret_version" "secret_version" {
108108
secret = google_secret_manager_secret.secret_db_url.id
109-
secret_data = "postgresql:+psycopg2//${var.postgresql_user_name}:${var.postgresql_user_password}@${google_sql_database_instance.db.private_ip_address}/${var.postgresql_database_name}"
109+
secret_data = "postgresql+psycopg2://${var.postgresql_user_name}:${var.postgresql_user_password}@${google_sql_database_instance.db.private_ip_address}/${var.postgresql_database_name}"
110110
}
111111

112112
output "instance_address" {

Commit comments: 0