Skip to content

Commit cf2a50a

Browse files
committed
provided session using with statement
1 parent d982f3d commit cf2a50a

File tree

1 file changed

+76
-73
lines changed

1 file changed

+76
-73
lines changed

api/src/scripts/populate_db_gbfs.py

Lines changed: 76 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def deprecate_feeds(self, deprecated_feeds):
4545
if gbfs_feed:
4646
self.logger.info(f"Deprecating feed with stable_id={stable_id}")
4747
gbfs_feed.status = "deprecated"
48-
session.flush()
48+
# session.flush()
4949

5050
def populate_db(self):
5151
"""Populate the database with the GBFS feeds"""
@@ -57,81 +57,84 @@ def populate_db(self):
5757
# Compare the database to the CSV file
5858
df_from_db = generate_system_csv_from_db(self.df, session)
5959
added_or_updated_feeds, deprecated_feeds = compare_db_to_csv(df_from_db, self.df, self.logger)
60-
except Exception as e:
61-
self.logger.error(f"Failed to compare the database to the CSV file. Error: {e}")
62-
return
6360

64-
self.deprecate_feeds(deprecated_feeds)
65-
if added_or_updated_feeds is None:
66-
added_or_updated_feeds = self.df
67-
for index, row in added_or_updated_feeds.iterrows():
68-
self.logger.info(f"Processing row {index + 1} of {len(added_or_updated_feeds)}")
69-
stable_id = self.get_stable_id(row)
70-
gbfs_feed = self.query_feed_by_stable_id(stable_id, "gbfs")
71-
fetched_data = fetch_data(
72-
row["Auto-Discovery URL"], self.logger, ["system_information", "gbfs_versions"], ["version"]
73-
)
74-
# If the feed already exists, update it. Otherwise, create a new feed.
75-
if gbfs_feed:
76-
feed_id = gbfs_feed.id
77-
self.logger.info(f"Updating feed {stable_id} - {row['Name']}")
78-
else:
79-
feed_id = generate_unique_id()
80-
self.logger.info(f"Creating new feed for {stable_id} - {row['Name']}")
81-
gbfs_feed = Gbfsfeed(
82-
id=feed_id,
83-
data_type="gbfs",
84-
stable_id=stable_id,
85-
created_at=datetime.now(pytz.utc),
86-
)
87-
gbfs_feed.externalids = [self.get_external_id(feed_id, row["System ID"])]
88-
self.db.session.add(gbfs_feed)
89-
90-
system_information_content = get_data_content(fetched_data.get("system_information"), self.logger)
91-
gbfs_feed.license_url = get_license_url(system_information_content, self.logger)
92-
gbfs_feed.feed_contact_email = (
93-
system_information_content.get("feed_contact_email") if system_information_content else None
94-
)
95-
gbfs_feed.operator = row["Name"]
96-
gbfs_feed.operator_url = row["URL"]
97-
gbfs_feed.auto_discovery_url = row["Auto-Discovery URL"]
98-
gbfs_feed.updated_at = datetime.now(pytz.utc)
99-
100-
country_code = self.get_safe_value(row, "Country Code", "")
101-
municipality = self.get_safe_value(row, "Location", "")
102-
location_id = self.get_location_id(country_code, None, municipality)
103-
country = pycountry.countries.get(alpha_2=country_code) if country_code else None
104-
location = self.db.session.get(Location, location_id) or Location(
105-
id=location_id,
106-
country_code=country_code,
107-
country=country.name if country else None,
108-
municipality=municipality,
109-
)
110-
gbfs_feed.locations.clear()
111-
gbfs_feed.locations = [location]
112-
113-
# Add the GBFS versions
114-
versions = get_gbfs_versions(
115-
fetched_data.get("gbfs_versions"), row["Auto-Discovery URL"], fetched_data.get("version"), self.logger
116-
)
117-
existing_versions = [version.version for version in gbfs_feed.gbfsversions]
118-
for version in versions:
119-
version_value = version.get("version")
120-
if version_value.upper() in OFFICIAL_VERSIONS and version_value not in existing_versions:
121-
gbfs_feed.gbfsversions.append(
122-
Gbfsversion(
123-
feed_id=feed_id,
124-
url=version.get("url"),
125-
version=version_value,
126-
)
61+
self.deprecate_feeds(deprecated_feeds)
62+
if added_or_updated_feeds is None:
63+
added_or_updated_feeds = self.df
64+
for index, row in added_or_updated_feeds.iterrows():
65+
self.logger.info(f"Processing row {index + 1} of {len(added_or_updated_feeds)}")
66+
stable_id = self.get_stable_id(row)
67+
gbfs_feed = self.query_feed_by_stable_id(session, stable_id, "gbfs")
68+
fetched_data = fetch_data(
69+
row["Auto-Discovery URL"], self.logger, ["system_information", "gbfs_versions"], ["version"]
12770
)
71+
# If the feed already exists, update it. Otherwise, create a new feed.
72+
if gbfs_feed:
73+
feed_id = gbfs_feed.id
74+
self.logger.info(f"Updating feed {stable_id} - {row['Name']}")
75+
else:
76+
feed_id = generate_unique_id()
77+
self.logger.info(f"Creating new feed for {stable_id} - {row['Name']}")
78+
gbfs_feed = Gbfsfeed(
79+
id=feed_id,
80+
data_type="gbfs",
81+
stable_id=stable_id,
82+
created_at=datetime.now(pytz.utc),
83+
)
84+
gbfs_feed.externalids = [self.get_external_id(feed_id, row["System ID"])]
85+
session.add(gbfs_feed)
12886

129-
self.db.session.flush()
130-
self.logger.info(80 * "-")
131-
132-
self.db.session.commit()
133-
end_time = datetime.now()
134-
self.logger.info(f"Time taken: {end_time - start_time} seconds")
87+
system_information_content = get_data_content(fetched_data.get("system_information"), self.logger)
88+
gbfs_feed.license_url = get_license_url(system_information_content, self.logger)
89+
gbfs_feed.feed_contact_email = (
90+
system_information_content.get("feed_contact_email") if system_information_content else None
91+
)
92+
gbfs_feed.operator = row["Name"]
93+
gbfs_feed.operator_url = row["URL"]
94+
gbfs_feed.auto_discovery_url = row["Auto-Discovery URL"]
95+
gbfs_feed.updated_at = datetime.now(pytz.utc)
96+
97+
country_code = self.get_safe_value(row, "Country Code", "")
98+
municipality = self.get_safe_value(row, "Location", "")
99+
location_id = self.get_location_id(country_code, None, municipality)
100+
country = pycountry.countries.get(alpha_2=country_code) if country_code else None
101+
location = session.get(Location, location_id) or Location(
102+
id=location_id,
103+
country_code=country_code,
104+
country=country.name if country else None,
105+
municipality=municipality,
106+
)
107+
gbfs_feed.locations.clear()
108+
gbfs_feed.locations = [location]
109+
110+
# Add the GBFS versions
111+
versions = get_gbfs_versions(
112+
fetched_data.get("gbfs_versions"),
113+
row["Auto-Discovery URL"],
114+
fetched_data.get("version"),
115+
self.logger,
116+
)
117+
existing_versions = [version.version for version in gbfs_feed.gbfsversions]
118+
for version in versions:
119+
version_value = version.get("version")
120+
if version_value.upper() in OFFICIAL_VERSIONS and version_value not in existing_versions:
121+
gbfs_feed.gbfsversions.append(
122+
Gbfsversion(
123+
feed_id=feed_id,
124+
url=version.get("url"),
125+
version=version_value,
126+
)
127+
)
128+
129+
# self.db.session.flush()
130+
self.logger.info(80 * "-")
131+
132+
# self.db.session.commit()
133+
end_time = datetime.now()
134+
self.logger.info(f"Time taken: {end_time - start_time} seconds")
135+
except Exception as e:
136+
self.logger.error(f"Error populating the database: {e}")
137+
raise e
135138

136139

137140
if __name__ == "__main__":

0 commit comments

Comments
 (0)