Commit 749284b

feat: refactor Pmtiles build using visualization_dataset_id to check existence (#1388)

1 parent: ee2303d

5 files changed: +85 −128 lines

functions-python/pmtiles_builder/src/main.py (1 addition, 0 deletions)

@@ -306,6 +306,7 @@ def build_pmtiles(self):
         self._upload_files_to_gcs(files_to_upload)
         self._update_database()
 
+        self.logger.info("Completed PMTiles build")
         return self.OperationStatus.SUCCESS, "success"
 
     def _download_files_from_gcs(self, unzipped_files_path):
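
For context, build_pmtiles signals its outcome through the (status, message) tuple visible above. A hypothetical caller sketch; the builder instance and its construction are assumptions, only the return shape comes from this diff:

    # Hypothetical usage; only the (OperationStatus, str) return shape is from the diff.
    status, message = builder.build_pmtiles()
    if status == builder.OperationStatus.SUCCESS:
        print("PMTiles build finished:", message)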

functions-python/tasks_executor/src/tasks/visualization_files/rebuild_missing_visualization_files.py (24 additions, 50 deletions)

@@ -15,10 +15,8 @@
 #
 
 import logging
-import os
 from typing import List, Final, Optional
 
-from google.cloud import storage
 from sqlalchemy import func, distinct
 from sqlalchemy.orm import Session, selectinload
 
@@ -32,11 +30,6 @@
     "trips.txt",
     "stop_times.txt",
 ]
-PMTILES_FILES: Final[List[str]] = [
-    "pmtiles/stops.pmtiles",
-    "pmtiles/routes.pmtiles",
-    "pmtiles/routes.json",
-]
 
 
 def rebuild_missing_visualization_files_handler(payload) -> dict:
@@ -49,6 +42,8 @@ def rebuild_missing_visualization_files_handler(payload) -> dict:
         "check_existing": bool, # [optional] If True, check if visualization files already exist before creating tasks
         "latest_only": bool, # [optional] If True, include only latest datasets
         "include_deprecated_feeds": bool, # [optional] If True, include datasets from deprecated feeds
+        "include_feed_op_status": list[str], # [optional] List of feed operational statuses to include
+        # e.g., ["published", "wip"]. Default is ["published"].
         "limit": int, # [optional] Limit the number of datasets to process
     }
     Args:
@@ -58,41 +53,41 @@ def rebuild_missing_visualization_files_handler(payload) -> dict:
     """
     (
         dry_run,
-        bucket_name,
         check_existing,
         latest_only,
         include_deprecated_feeds,
+        include_feed_op_status,
         limit,
     ) = get_parameters(payload)
 
     return rebuild_missing_visualization_files(
         dry_run=dry_run,
-        bucket_name=bucket_name,
         check_existing=check_existing,
         latest_only=latest_only,
+        include_feed_op_status=include_feed_op_status,
         include_deprecated_feeds=include_deprecated_feeds,
         limit=limit,
    )
 
 
 @with_db_session
 def rebuild_missing_visualization_files(
-    bucket_name: str,
     dry_run: bool = True,
     check_existing: bool = True,
     latest_only: bool = True,
     include_deprecated_feeds: bool = False,
+    include_feed_op_status: list[str] = ["published"],
     limit: Optional[int] = None,
     db_session: Session | None = None,
 ) -> dict:
     """
     Rebuilds missing visualization files for GTFS datasets.
     Args:
-        bucket_name (str): The name of the bucket containing the GTFS data.
         dry_run (bool): dry run flag. If True, do not execute the workflow. Default: True
         check_existing (bool): If True, check if visualization files already exist before creating tasks. Default: True
         latest_only (bool): If True, include only latest datasets. Default: True
         include_deprecated_feeds (bool): If True, include datasets from deprecated feeds. Default: False
+        include_feed_op_status (list[str]): List of feed operational statuses to include. Default: ['published']
         limit (Optional[int]): Limit the number of datasets to process. Default: None (no limit)
         db_session: DB session
 
@@ -107,7 +102,16 @@ def rebuild_missing_visualization_files(
         datasets_query = datasets_query.filter(
             Gtfsdataset.feed.has(Gtfsfeed.status != "deprecated")
         )
-
+    if include_feed_op_status:
+        datasets_query = datasets_query.filter(
+            Gtfsdataset.feed.has(
+                Gtfsfeed.operational_status.in_(include_feed_op_status)
+            )
+        )
+    if check_existing:
+        datasets_query = datasets_query.join(
+            Gtfsfeed, Gtfsdataset.feed_id == Gtfsfeed.id
+        ).filter(Gtfsfeed.visualization_dataset_id.is_(None))
     datasets_query = (
         datasets_query.join(Gtfsdataset.gtfsfiles)
         .filter(Gtfsfile.file_name.in_(REQUIRED_FILES))
@@ -122,41 +126,14 @@ def rebuild_missing_visualization_files(
     datasets = datasets_query.all()
     logging.info(f"Found {len(datasets)} latest datasets with all required files.")
 
-    # Validate visualization files existence in the storage bucket
-    client = storage.Client()
-    bucket = client.get_bucket(bucket_name)
     tasks_to_create = []
     for dataset in datasets:
-        if not check_existing:
-            tasks_to_create.append(
-                {
-                    "feed_stable_id": dataset.feed.stable_id,
-                    "dataset_stable_id": dataset.stable_id,
-                }
-            )
-        else:
-            # Check if visualization files already exist
-            all_files_exist = True
-            for file_suffix in PMTILES_FILES:
-                file_path = (
-                    f"{dataset.feed.stable_id}/{dataset.stable_id}/{file_suffix}"
-                )
-                blob = bucket.blob(file_path)
-                if not blob.exists():
-                    all_files_exist = False
-                    logging.info(f"Missing visualization file: {file_path}")
-                    break
-            if not all_files_exist:
-                tasks_to_create.append(
-                    {
-                        "feed_stable_id": dataset.feed.stable_id,
-                        "dataset_stable_id": dataset.stable_id,
-                    }
-                )
-            else:
-                logging.info(
-                    f"All visualization files exist for dataset {dataset.stable_id}. Skipping."
-                )
+        tasks_to_create.append(
+            {
+                "feed_stable_id": dataset.feed.stable_id,
+                "dataset_stable_id": dataset.stable_id,
+            }
+        )
     total_processed = len(tasks_to_create)
     logging.info(f"Total datasets to process: {total_processed}")
 
@@ -177,7 +154,6 @@ def rebuild_missing_visualization_files(
         "total_processed": total_processed,
         "params": {
             "dry_run": dry_run,
-            "bucket_name": bucket_name,
             "check_existing": check_existing,
             "latest_only": latest_only,
             "include_deprecated_feeds": include_deprecated_feeds,
@@ -199,9 +175,6 @@ def get_parameters(payload):
     """
     dry_run = payload.get("dry_run", True)
     dry_run = dry_run if isinstance(dry_run, bool) else str(dry_run).lower() == "true"
-    bucket_name = os.getenv("DATASETS_BUCKET_NAME")
-    if not bucket_name:
-        raise EnvironmentError("DATASETS_BUCKET_NAME environment variable is not set.")
     check_existing = payload.get("check_existing", True)
     check_existing = (
         check_existing
@@ -220,13 +193,14 @@ def get_parameters(payload):
         if isinstance(include_deprecated_feeds, bool)
         else str(include_deprecated_feeds).lower() == "true"
    )
+    include_feed_op_status = payload.get("include_feed_op_status", ["published"])
     limit = payload.get("limit", None)
     limit = limit if isinstance(limit, int) and limit > 0 else None
     return (
         dry_run,
-        bucket_name,
         check_existing,
         latest_only,
         include_deprecated_feeds,
+        include_feed_op_status,
         limit,
     )
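
With this refactor, existence checking moves entirely into the database query (Gtfsfeed.visualization_dataset_id) instead of probing GCS blobs, which is why the storage client, the PMTILES_FILES list, and the DATASETS_BUCKET_NAME lookup could all be deleted. A minimal sketch of calling the handler with the new payload shape; the keys match the docstring above, but the values are illustrative, not repo defaults:

    # Illustrative payload; keys are from the handler docstring, values are made up.
    payload = {
        "dry_run": True,                        # report only, create no tasks
        "check_existing": True,                 # skip feeds whose visualization_dataset_id is set
        "latest_only": True,
        "include_deprecated_feeds": False,
        "include_feed_op_status": ["published", "wip"],  # the filter added by this commit
        "limit": 10,
    }
    result = rebuild_missing_visualization_files_handler(payload)
    print(result["total_processed"])  # count of datasets that would be queued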

functions-python/tasks_executor/tests/conftest.py (37 additions, 0 deletions)

@@ -47,6 +47,24 @@ def populate_database(db_session: Session | None = None):
         )
         db_session.add(feed)
         feeds.append(feed)
+    wip_feed = Gtfsfeed(
+        id="feed_wip",
+        stable_id="stable_feed_wip_feed",
+        data_type="gtfs",
+        status="active",
+        created_at=now,
+        operational_status="wip",
+    )
+    with_visualization_feed = Gtfsfeed(
+        id="feed_visualization",
+        stable_id="stable_feed_visualization",
+        data_type="gtfs",
+        status="active",
+        created_at=now,
+        operational_status="wip",
+    )
+    db_session.add(wip_feed)
+    db_session.add(with_visualization_feed)
     gbfs_feed = Gbfsfeed(
         id=f"feed_{uuid.uuid4()}",
         stable_id=f"stable_feed_gbfs_{uuid.uuid4()}",
@@ -70,6 +88,25 @@ def populate_database(db_session: Session | None = None):
         db_session.add(dataset)
         datasets.append(dataset)
 
+    wip_dataset = Gtfsdataset(
+        id="dataset_wip",
+        feed=wip_feed,
+        stable_id="dataset_stable_wip",
+        downloaded_at=now - timedelta(days=i),
+        bounding_box="POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))",
+    )
+    with_visualization_dataset = Gtfsdataset(
+        id="dataset_visualization",
+        feed=with_visualization_feed,
+        stable_id="dataset_stable_visualization",
+        downloaded_at=now - timedelta(days=i),
+        bounding_box="POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))",
+    )
+    db_session.add(with_visualization_dataset)
+    db_session.add(wip_dataset)
+    db_session.flush()
+    with_visualization_feed.visualization_dataset_id = with_visualization_dataset.id
+
     db_session.commit()
 
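
These two fixtures exercise both new filters: stable_feed_wip_feed has operational_status "wip" and no visualization dataset, while stable_feed_visualization gets its visualization_dataset_id set after the flush. A rough sketch, assuming these models and an open session, of the check_existing filter added in the task above; it would keep dataset_wip and exclude dataset_visualization:

    # Sketch only; mirrors the filter added in rebuild_missing_visualization_files.
    pending = (
        db_session.query(Gtfsdataset)
        .join(Gtfsfeed, Gtfsdataset.feed_id == Gtfsfeed.id)
        .filter(Gtfsfeed.visualization_dataset_id.is_(None))  # no build recorded yet
        .all()
    )
    # pending includes dataset_wip but not dataset_visualization.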

functions-python/tasks_executor/tests/tasks/validation_reports/test_rebuild_missing_validation_reports.py (4 additions, 4 deletions)

@@ -108,7 +108,7 @@ def test_rebuild_missing_validation_reports_one_page(
 
         # Assert the expected behavior
         self.assertIsNotNone(response)
-        self.assertEquals(response["total_processed"], 7)
+        self.assertEquals(response["total_processed"], 9)
         self.assertEquals(
             response["message"],
             "Rebuild missing validation reports task executed successfully.",
@@ -137,12 +137,12 @@ def test_rebuild_missing_validation_reports_two_pages(
 
         # Assert the expected behavior
         self.assertIsNotNone(response)
-        self.assertEquals(response["total_processed"], 7)
+        self.assertEquals(response["total_processed"], 9)
         self.assertEquals(
             response["message"],
             "Rebuild missing validation reports task executed successfully.",
         )
-        self.assertEquals(execute_workflows_mock.call_count, 4)
+        self.assertEquals(execute_workflows_mock.call_count, 5)
 
     @with_db_session(db_url=default_db_url)
     @patch(
@@ -166,7 +166,7 @@ def test_rebuild_missing_validation_reports_dryrun(
 
         # Assert the expected behavior
         self.assertIsNotNone(response)
-        self.assertEquals(response["total_processed"], 7)
+        self.assertEquals(response["total_processed"], 9)
         self.assertEquals(response["message"], "Dry run: no datasets processed.")
         execute_workflows_mock.assert_not_called()
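
The expected totals rise from 7 to 9, and the workflow batches from 4 to 5, presumably because conftest.py now seeds two extra GTFS datasets (dataset_wip and dataset_visualization) that these validation-report tests also pick up.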
