Commit c4d1fe4

Refactor python functions
1 parent b09e130 commit c4d1fe4

File tree: 23 files changed (+153 / -108 lines)


api/src/shared/common/db_utils.py

Lines changed: 2 additions & 2 deletions
@@ -170,7 +170,7 @@ def get_all_gtfs_feeds(
     feed_query = apply_most_common_location_filter(db_session.query(Gtfsfeed), db_session)
     yield from (
         feed_query.filter(Gtfsfeed.stable_id.in_(stable_ids)).options(
-            contains_eager(Gtfsfeed.latest_dataset)
+            joinedload(Gtfsfeed.latest_dataset)
             .joinedload(Gtfsdataset.validation_reports)
             .joinedload(Validationreport.features),
             *get_joinedload_options(include_extracted_location_entities=True),
@@ -182,7 +182,7 @@ def get_all_gtfs_feeds(
         .outerjoin(Gtfsfeed.gtfsdatasets)
         .filter(Gtfsfeed.stable_id.in_(stable_ids))
         .options(
-            contains_eager(Gtfsfeed.latest_dataset)
+            joinedload(Gtfsfeed.latest_dataset)
             .joinedload(Gtfsdataset.validation_reports)
             .joinedload(Validationreport.features),
             *get_joinedload_options(include_extracted_location_entities=False),
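
Note on the swap above: `contains_eager()` only populates a relationship from a JOIN that the query itself already spells out, while `joinedload()` emits its own LEFT OUTER JOIN for the relationship. A minimal, self-contained sketch of the difference, using toy models rather than the repository's generated schema:

```python
from typing import Optional

from sqlalchemy import ForeignKey, create_engine
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    contains_eager,
    joinedload,
    mapped_column,
    relationship,
)


class Base(DeclarativeBase):
    pass


class Dataset(Base):
    __tablename__ = "dataset"
    id: Mapped[int] = mapped_column(primary_key=True)


class Feed(Base):
    __tablename__ = "feed"
    id: Mapped[int] = mapped_column(primary_key=True)
    latest_dataset_id: Mapped[Optional[int]] = mapped_column(ForeignKey("dataset.id"))
    latest_dataset: Mapped[Optional[Dataset]] = relationship()


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # joinedload: the ORM adds its own LEFT OUTER JOIN, so feeds without a
    # latest dataset still come back, with latest_dataset preloaded as None.
    session.query(Feed).options(joinedload(Feed.latest_dataset)).all()

    # contains_eager: only valid when the query itself joins the relationship;
    # the eager load reuses whatever join (inner or outer) was written here.
    session.query(Feed).outerjoin(Feed.latest_dataset).options(
        contains_eager(Feed.latest_dataset)
    ).all()
```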

functions-python/batch_datasets/src/main.py

Lines changed: 1 addition & 3 deletions
@@ -25,7 +25,6 @@
 from google.cloud import pubsub_v1
 from google.cloud.pubsub_v1 import PublisherClient
 from google.cloud.pubsub_v1.futures import Future
-from sqlalchemy import or_
 from sqlalchemy.orm import Session

 from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gtfsdataset
@@ -87,9 +86,8 @@ def get_non_deprecated_feeds(
             Gtfsdataset.hash.label("dataset_hash"),
         )
         .select_from(Gtfsfeed)
-        .outerjoin(Gtfsdataset, (Gtfsdataset.feed_id == Gtfsfeed.id))
+        .outerjoin(Gtfsdataset, (Gtfsfeed.latest_dataset_id == Gtfsdataset.id))
         .filter(Gtfsfeed.status != "deprecated")
-        .filter(or_(Gtfsdataset.id.is_(None), Gtfsdataset.latest.is_(True)))
     )

     if feed_stable_ids:
         # If feed_stable_ids are provided, filter the query by stable IDs
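
The rewritten query joins datasets through the feed's latest_dataset_id foreign key, so the old `or_()` guard for feeds without datasets is gone. A self-contained sketch (toy tables, not the project's schema) of why that guard is no longer needed:

```python
# A LEFT OUTER JOIN through latest_dataset_id already yields NULL dataset
# columns for feeds with no latest dataset, so no extra IS NULL / latest
# filter is required to keep those feeds in the result.
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")
with engine.begin() as conn:
    conn.execute(text("CREATE TABLE dataset (id INTEGER PRIMARY KEY, hash TEXT)"))
    conn.execute(text(
        "CREATE TABLE feed (id INTEGER PRIMARY KEY, status TEXT, "
        "latest_dataset_id INTEGER REFERENCES dataset(id))"
    ))
    conn.execute(text("INSERT INTO dataset VALUES (1, 'abc123')"))
    conn.execute(text("INSERT INTO feed VALUES (1, 'active', 1), (2, 'active', NULL)"))

    rows = conn.execute(text(
        "SELECT feed.id, dataset.hash FROM feed "
        "LEFT OUTER JOIN dataset ON feed.latest_dataset_id = dataset.id "
        "WHERE feed.status != 'deprecated'"
    )).all()
    # Both feeds are returned; feed 2 simply has hash = None.
    print(rows)
```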

functions-python/batch_datasets/tests/conftest.py

Lines changed: 4 additions & 2 deletions
@@ -84,10 +84,10 @@ def populate_database(db_session: Session | None = None):
     # GTFS datasets leaving one active feed without a dataset
     active_gtfs_feeds = db_session.query(Gtfsfeed).all()
     for i in range(1, 9):
+        id = fake.uuid4()
         gtfs_dataset = Gtfsdataset(
-            id=fake.uuid4(),
+            id=id,
             feed_id=active_gtfs_feeds[i].id,
-            latest=True,
             bounding_box="POLYGON((-180 -90, -180 90, 180 90, 180 -90, -180 -90))",
             hosted_url=fake.url(),
             note=fake.sentence(),
@@ -96,6 +96,8 @@ def populate_database(db_session: Session | None = None):
             stable_id=fake.uuid4(),
         )
         db_session.add(gtfs_dataset)
+        db_session.flush()
+        active_gtfs_feeds[i].latest_dataset_id = id

     db_session.flush()
     # GTFS Realtime feeds
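
Several of the test fixtures in this commit adopt the same add / flush / assign pattern: the dataset row is flushed first so it exists before the feed's latest_dataset_id foreign key is pointed at it. A self-contained sketch of that pattern with toy models (not the generated schema):

```python
from typing import Optional

from sqlalchemy import ForeignKey, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Dataset(Base):
    __tablename__ = "dataset"
    id: Mapped[str] = mapped_column(primary_key=True)


class Feed(Base):
    __tablename__ = "feed"
    id: Mapped[str] = mapped_column(primary_key=True)
    latest_dataset_id: Mapped[Optional[str]] = mapped_column(ForeignKey("dataset.id"))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    feed = Feed(id="feed-1")
    dataset = Dataset(id="dataset-1")
    session.add_all([feed, dataset])
    session.flush()                       # emits the INSERTs within the transaction
    feed.latest_dataset_id = dataset.id   # the FK target row now exists
    session.commit()
```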

functions-python/batch_process_dataset/src/main.py

Lines changed: 32 additions & 26 deletions
@@ -33,7 +33,7 @@

 from shared.common.gcp_utils import create_refresh_materialized_view_task
 from shared.database.database import with_db_session
-from shared.database_gen.sqlacodegen_models import Gtfsdataset, Gtfsfile
+from shared.database_gen.sqlacodegen_models import Gtfsdataset, Gtfsfile, Gtfsfeed
 from shared.dataset_service.main import DatasetTraceService, DatasetTrace, Status
 from shared.helpers.logger import init_logger, get_logger
 from shared.helpers.utils import (
@@ -84,7 +84,9 @@ def __init__(
         self.api_key_parameter_name = api_key_parameter_name
         self.date = datetime.now().strftime("%Y%m%d%H%M")
         if self.authentication_type != 0:
-            self.logger.info(f"Getting feed credentials for feed {self.feed_stable_id}")
+            self.logger.info(
+                "Getting feed credentials for feed %s", self.feed_stable_id
+            )
             self.feed_credentials = self.get_feed_credentials(self.feed_stable_id)
             if self.feed_credentials is None:
                 raise Exception(
@@ -135,7 +137,7 @@ def download_content(self, temporary_file_path, feed_id):
             credentials=self.feed_credentials,
             logger=self.logger,
         )
-        self.logger.info(f"hash is: {file_hash}")
+        self.logger.info("hash is: %s", file_hash)
         is_zip = zipfile.is_zipfile(temporary_file_path)
         return file_hash, is_zip

@@ -168,7 +170,7 @@ def upload_files_to_storage(
         extracted_files: List[Gtfsfile] = []
         if not extracted_files_path or not os.path.exists(extracted_files_path):
             self.logger.warning(
-                f"Extracted files path {extracted_files_path} does not exist."
+                "Extracted files path %s does not exist.", extracted_files_path
             )
             return blob, extracted_files
         self.logger.info("Processing extracted files from %s", extracted_files_path)
@@ -182,7 +184,7 @@ def upload_files_to_storage(
             if public:
                 file_blob.make_public()
             self.logger.info(
-                f"Uploaded extracted file {file_name} to {file_blob.public_url}"
+                "Uploaded extracted file %s to %s", file_name, file_blob.public_url
             )
             extracted_files.append(
                 Gtfsfile(
@@ -209,7 +211,8 @@ def upload_dataset(self, feed_id, public=True) -> DatasetFile or None:
         file_sha256_hash, is_zip = self.download_content(temp_file_path, feed_id)
         if not is_zip:
             self.logger.error(
-                f"[{self.feed_stable_id}] The downloaded file from {self.producer_url} is not a valid ZIP file."
+                "The downloaded file from %s is not a valid ZIP file.",
+                self.producer_url,
             )
             return None

@@ -299,17 +302,18 @@ def process_from_bucket(self, db_session, public=True) -> Optional[DatasetFile]:
                 else None
             ),
         )
-        dataset = self.create_dataset_entities(
+        dataset, latest = self.create_dataset_entities(
            dataset_file, skip_dataset_creation=True, db_session=db_session
        )
-        if dataset and dataset.latest:
+        if dataset and latest:
            self.logger.info(
-                f"Creating pipeline tasks for latest dataset {dataset.stable_id}"
+                "Creating pipeline tasks for latest dataset %s", dataset.stable_id
            )
            create_pipeline_tasks(dataset)
        elif dataset:
            self.logger.info(
-                f"Dataset {dataset.stable_id} is not the latest, skipping pipeline tasks creation."
+                "Dataset %s is not the latest, skipping pipeline tasks creation.",
+                dataset.stable_id,
            )
        else:
            raise ValueError("Dataset update failed, dataset is None.")
@@ -352,26 +356,24 @@ def create_dataset_entities(
         """
         try:
             # Check latest version of the dataset
-            latest_dataset = (
-                db_session.query(Gtfsdataset)
-                .filter_by(latest=True, feed_id=self.feed_id)
-                .one_or_none()
+            gtfs_feed: Gtfsfeed | None = (
+                db_session.query(Gtfsfeed).filter_by(id=self.feed_id).one_or_none()
             )
+            latest_dataset = gtfs_feed.latest_dataset
             if not latest_dataset:
-                self.logger.info(
-                    f"[{self.feed_stable_id}] No latest dataset found for feed."
-                )
+                self.logger.info("No latest dataset found for feed.")

             dataset = None
+            latest = True if latest_dataset is not None else False
             if not skip_dataset_creation:
                 self.logger.info(
-                    f"[{self.feed_stable_id}] Creating new dataset for feed with stable id {dataset_file.stable_id}."
+                    "Creating new dataset for feed with stable id %s.",
+                    dataset_file.stable_id,
                 )
                 dataset = Gtfsdataset(
                     id=str(uuid.uuid4()),
                     feed_id=self.feed_id,
                     stable_id=dataset_file.stable_id,
-                    latest=True,
                     bounding_box=None,
                     note=None,
                     hash=dataset_file.file_sha256_hash,
@@ -386,10 +388,14 @@ def create_dataset_entities(
                     unzipped_size_bytes=self._get_unzipped_size(dataset_file),
                 )
                 db_session.add(dataset)
+                # update the latest dataset relationship in the feed
+                db_session.flush()
+                gtfs_feed.latest_dataset = dataset
+                latest = True
             elif skip_dataset_creation and latest_dataset:
                 self.logger.info(
-                    f"[{self.feed_stable_id}] Updating latest dataset for feed with stable id "
-                    f"{latest_dataset.stable_id}."
+                    "Updating latest dataset for feed with stable id %s",
+                    latest_dataset.stable_id,
                 )
                 latest_dataset.gtfsfiles = (
                     dataset_file.extracted_files if dataset_file.extracted_files else []
@@ -400,13 +406,12 @@ def create_dataset_entities(
                 )

             if latest_dataset and not skip_dataset_creation:
-                latest_dataset.latest = False
                 db_session.add(latest_dataset)
             db_session.commit()
-            self.logger.info(f"[{self.feed_stable_id}] Dataset created successfully.")
+            self.logger.info("Dataset created successfully.")

             create_refresh_materialized_view_task()
-            return latest_dataset if skip_dataset_creation else dataset
+            return latest_dataset if skip_dataset_creation else dataset, latest
         except Exception as e:
             raise Exception(f"Error creating dataset: {e}")

@@ -431,7 +436,7 @@ def process_from_producer_url(
         if dataset_file is None:
             self.logger.info(f"[{self.feed_stable_id}] No database update required.")
             return None
-        dataset = self.create_dataset_entities(dataset_file, db_session=db_session)
+        dataset, _ = self.create_dataset_entities(dataset_file, db_session=db_session)
         create_pipeline_tasks(dataset)
         return dataset_file

@@ -577,7 +582,8 @@ def process_dataset(cloud_event: CloudEvent):
         )
         return f"Function completed with errors, missing stable={stable_id} or execution_id={execution_id}"
     logger.info(
-        f"Function %s in execution: [{execution_id}]",
+        "Function %s in execution: %s",
+        execution_id,
         "successfully completed" if not error_message else "Failed",
     )
     return "Completed." if error_message is None else error_message

functions-python/batch_process_dataset/tests/conftest.py

Lines changed: 2 additions & 1 deletion
@@ -64,7 +64,6 @@ def populate_database(db_session):
         gtfs_dataset = Gtfsdataset(
             id=fake.uuid4(),
             feed_id=active_gtfs_feeds[i].id,
-            latest=True,
             bounding_box="POLYGON((-180 -90, -180 90, 180 90, 180 -90, -180 -90))",
             hosted_url=fake.url(),
             note=fake.sentence(),
@@ -73,6 +72,8 @@ def populate_database(db_session):
             stable_id=fake.uuid4(),
         )
         db_session.add(gtfs_dataset)
+        db_session.flush()
+        active_gtfs_feeds[i].latest_gtfsdataset_id = gtfs_dataset.id

         db_session.flush()
         # GTFS Realtime feeds

functions-python/batch_process_dataset/tests/test_batch_process_dataset_main.py

Lines changed: 1 addition & 0 deletions
@@ -487,6 +487,7 @@ def test_process_from_bucket_latest_happy_path(
         dataset_stable_id="dataset-stable-id-123",  # REQUIRED for bucket-latest path
     )

+    mock_create_dataset_entities.return_value = Mock(), True
     # Act
     result = processor.process_from_bucket(public=True)

functions-python/export_csv/src/main.py

Lines changed: 1 addition & 8 deletions
@@ -214,14 +214,7 @@ def get_gtfs_feed_csv_data(
     data = get_feed_csv_data(feed, geopolygon_map)

     # Then supplement with the GTFS specific data
-    latest_dataset = next(
-        (
-            dataset
-            for dataset in (feed.gtfsdatasets or [])
-            if dataset and dataset.latest
-        ),
-        None,
-    )
+    latest_dataset = feed.latest_dataset
     if latest_dataset and latest_dataset.validation_reports:
         # Keep the report from the more recent validator version
         latest_report = max(

functions-python/export_csv/tests/conftest.py

Lines changed: 6 additions & 3 deletions
@@ -69,7 +69,7 @@ def populate_database(db_session):
             official=True,
         )
         feeds.append(feed)
-
+    db_session.flush()
     # Then fill the specific parameters for each feed
     target_feed = feeds[0]
     target_feed.id = "e3155a30-81d8-40bb-9e10-013a60436d86"  # Just an invented uuid
@@ -148,8 +148,7 @@ def populate_database(db_session):
         feed_stable_id = active_gtfs_feeds[feed_index].stable_id
         gtfs_dataset = Gtfsdataset(
             id=fake.uuid4(),
-            feed_id=feed_stable_id,
-            latest=True if i != 2 else False,
+            feed_id=active_gtfs_feeds[feed_index].id,
             bounding_box=wkt_element,
             # Use a url containing the stable id. The program should replace all the is after the feed stable id
             # by latest.zip
@@ -159,6 +158,8 @@ def populate_database(db_session):
             downloaded_at=datetime(2025, 1, 12),
             stable_id=f"dataset-{i}",
         )
+        db_session.add(gtfs_dataset)
+        db_session.flush()
         validation_report = Validationreport(
             id=fake.uuid4(),
             validator_version="6.0.1",
@@ -175,6 +176,8 @@ def populate_database(db_session):
         gtfs_dataset.locations = locations

         active_gtfs_feeds[feed_index].gtfsdatasets.append(gtfs_dataset)
+        if i != 2:
+            active_gtfs_feeds[feed_index].latest_dataset_id = gtfs_dataset.id
         db_session.flush()
         active_gtfs_feeds[feed_index].bounding_box = gtfs_dataset.bounding_box
         active_gtfs_feeds[feed_index].bounding_box_dataset_id = gtfs_dataset.id

functions-python/helpers/feed_status.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 import logging
 from datetime import datetime, timezone
 from sqlalchemy import text
-from shared.database_gen.sqlacodegen_models import Gtfsdataset, Feed
+from shared.database_gen.sqlacodegen_models import Gtfsdataset, Feed, Gtfsfeed
 from typing import TYPE_CHECKING

 if TYPE_CHECKING:
@@ -19,8 +19,8 @@ def update_feed_statuses_query(session: "Session", stable_feed_ids: list[str]):
             Gtfsdataset.service_date_range_start,
             Gtfsdataset.service_date_range_end,
         )
+        .join(Gtfsfeed, Gtfsfeed.latest_dataset_id == Gtfsdataset.id)
         .filter(
-            Gtfsdataset.latest.is_(True),
             Gtfsdataset.service_date_range_start.isnot(None),
             Gtfsdataset.service_date_range_end.isnot(None),
         )

functions-python/helpers/query_helper.py

Lines changed: 3 additions & 5 deletions
@@ -190,11 +190,9 @@ def get_feeds_with_missing_bounding_boxes_query(
     """
     query = (
         db_session.query(Gtfsfeed)
-        .join(Gtfsdataset, Gtfsdataset.feed_id == Gtfsfeed.id)
-        .filter(Gtfsdataset.latest.is_(True))
-        .filter(Gtfsdataset.bounding_box.is_(None))
+        .filter(Gtfsfeed.bounding_box.is_(None))
         .filter(~Gtfsfeed.feedlocationgrouppoints.any())
-        .distinct(Gtfsfeed.stable_id, Gtfsdataset.stable_id)
-        .order_by(Gtfsdataset.stable_id, Gtfsfeed.stable_id)
+        .distinct(Gtfsfeed.stable_id)
+        .order_by(Gtfsfeed.stable_id)
     )
     return query
