Commit be75b18

feat: create pmtiles for new datasets + update location extraction (#1322)
1 parent 6f9d0e9 commit be75b18

File tree

23 files changed: +986 -471 lines changed

api/src/shared/common/gcp_utils.py

Lines changed: 4 additions & 0 deletions
@@ -79,8 +79,11 @@ def create_http_task_with_name(
     task_name: str,
     task_time: Timestamp,
     http_method: "tasks_v2.HttpMethod",
+    timeout_s: int = 1800,  # 30 minutes
 ):
     """Creates a GCP Cloud Task."""
+    from google.protobuf import duration_pb2
+
     token = tasks_v2.OidcToken(service_account_email=os.getenv("SERVICE_ACCOUNT_EMAIL"))

     parent = client.queue_path(project_id, gcp_region, queue_name)
@@ -98,6 +101,7 @@ def create_http_task_with_name(
             body=body,
             headers={"Content-Type": "application/json"},
         ),
+        dispatch_deadline=duration_pb2.Duration(seconds=timeout_s),
     )
     try:
         response = client.create_task(parent=parent, task=task)
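
The new `timeout_s` parameter is converted into a protobuf `Duration` and passed as the task's `dispatch_deadline`, so long-running handlers are not cut off at the queue's default deadline. A minimal caller-side sketch of the same idea (project, queue, and URL values below are placeholders, not taken from this repository):

# Sketch only: shows how a timeout in seconds becomes a Cloud Tasks dispatch deadline.
from google.cloud import tasks_v2
from google.protobuf import duration_pb2

client = tasks_v2.CloudTasksClient()
task = tasks_v2.Task(
    http_request=tasks_v2.HttpRequest(
        http_method=tasks_v2.HttpMethod.POST,
        url="https://example.com/handler",  # placeholder endpoint
        body=b"{}",
        headers={"Content-Type": "application/json"},
    ),
    # 45-minute deadline instead of the 30-minute default introduced above
    dispatch_deadline=duration_pb2.Duration(seconds=2700),
)
parent = client.queue_path("my-project", "us-central1", "my-queue")  # placeholders
response = client.create_task(parent=parent, task=task)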

functions-python/batch_process_dataset/src/main.py

Lines changed: 15 additions & 4 deletions
@@ -41,6 +41,7 @@
     get_hash_from_file,
     download_from_gcs,
 )
+from pipeline_tasks import create_pipeline_tasks

 init_logger()

@@ -260,7 +261,10 @@ def upload_dataset(self, public=True) -> DatasetFile or None:
             os.remove(temp_file_path)
         return None

-    def process_from_bucket_latest(self, public=True) -> DatasetFile or None:
+    @with_db_session
+    def process_from_bucket_latest(
+        self, db_session, public=True
+    ) -> DatasetFile or None:
         """
         Uploads a dataset to a GCP bucket as <feed_stable_id>/latest.zip and
         <feed_stable_id>/<feed_stable_id>-<upload_datetime>.zip
@@ -300,7 +304,10 @@ def process_from_bucket_latest(self, public=True) -> DatasetFile or None:
                     else None
                 ),
             )
-            self.create_dataset_entities(dataset_file, skip_dataset_creation=True)
+            dataset = self.create_dataset_entities(
+                dataset_file, skip_dataset_creation=True, db_session=db_session
+            )
+            create_pipeline_tasks(dataset)
         finally:
             if temp_file_path and os.path.exists(temp_file_path):
                 os.remove(temp_file_path)
@@ -352,6 +359,7 @@ def create_dataset_entities(
             self.logger.info(
                 f"[{self.feed_stable_id}] Creating new dataset for feed with stable id {dataset_file.stable_id}."
             )
+            dataset = None
             if not skip_dataset_creation:
                 dataset = Gtfsdataset(
                     id=str(uuid.uuid4()),
@@ -394,10 +402,12 @@ def create_dataset_entities(
             self.logger.info(f"[{self.feed_stable_id}] Dataset created successfully.")

             create_refresh_materialized_view_task()
+            return latest_dataset if skip_dataset_creation else dataset
         except Exception as e:
             raise Exception(f"Error creating dataset: {e}")

-    def process_from_producer_url(self) -> DatasetFile or None:
+    @with_db_session
+    def process_from_producer_url(self, db_session) -> DatasetFile or None:
         """
         Process the dataset and store new version in GCP bucket if any changes are detected
         :return: the file hash and the hosted url as a tuple or None if no upload is required
@@ -407,7 +417,8 @@ def process_from_producer_url(self) -> DatasetFile or None:
         if dataset_file is None:
             self.logger.info(f"[{self.feed_stable_id}] No database update required.")
             return None
-        self.create_dataset_entities(dataset_file)
+        dataset = self.create_dataset_entities(dataset_file, db_session=db_session)
+        create_pipeline_tasks(dataset)
         return dataset_file

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
+import json
+import logging
+import os
+from typing import Iterable, List
+
+from google.cloud import tasks_v2
+from sqlalchemy.orm import Session
+
+from shared.database.database import with_db_session
+from shared.database_gen.sqlacodegen_models import Gtfsdataset
+from shared.helpers.utils import create_http_task
+
+
+def create_http_reverse_geolocation_processor_task(
+    stable_id: str,
+    dataset_stable_id: str,
+    stops_url: str,
+) -> None:
+    """
+    Create a task to process reverse geolocation for a dataset.
+    """
+    client = tasks_v2.CloudTasksClient()
+    body = json.dumps(
+        {
+            "stable_id": stable_id,
+            "stops_url": stops_url,
+            "dataset_id": dataset_stable_id,
+        }
+    ).encode()
+    queue_name = os.getenv("REVERSE_GEOLOCATION_QUEUE")
+    project_id = os.getenv("PROJECT_ID")
+    gcp_region = os.getenv("GCP_REGION")
+
+    create_http_task(
+        client,
+        body,
+        f"https://{gcp_region}-{project_id}.cloudfunctions.net/reverse-geolocation-processor",
+        project_id,
+        gcp_region,
+        queue_name,
+    )
+
+
+def create_http_pmtiles_builder_task(
+    stable_id: str,
+    dataset_stable_id: str,
+) -> None:
+    """
+    Create a task to generate PMTiles for a dataset.
+    """
+    client = tasks_v2.CloudTasksClient()
+    body = json.dumps(
+        {"feed_stable_id": stable_id, "dataset_stable_id": dataset_stable_id}
+    ).encode()
+    queue_name = os.getenv("PMTILES_BUILDER_QUEUE")
+    project_id = os.getenv("PROJECT_ID")
+    gcp_region = os.getenv("GCP_REGION")
+    gcp_env = os.getenv("ENVIRONMENT")
+
+    create_http_task(
+        client,
+        body,
+        f"https://{gcp_region}-{project_id}.cloudfunctions.net/pmtiles-builder-{gcp_env}",
+        project_id,
+        gcp_region,
+        queue_name,
+    )
+
+
+@with_db_session
+def get_changed_files(
+    dataset: Gtfsdataset,
+    db_session: Session,
+) -> List[str]:
+    """
+    Return the subset of `file_names` whose content hash changed compared to the
+    previous dataset for the same feed.
+    - If there is no previous dataset → any file that exists in the new dataset is considered "changed".
+    - If the file existed before and now is missing → NOT considered changed.
+    - If the file did not exist before but exists now → considered changed.
+    - If hashes differ → considered changed.
+    """
+    previous_dataset = (
+        db_session.query(Gtfsdataset)
+        .filter(
+            Gtfsdataset.feed_id == dataset.feed_id,
+            Gtfsdataset.id != dataset.id,
+        )
+        .order_by(Gtfsdataset.downloaded_at.desc())
+        .first()
+    )
+
+    new_files = list(dataset.gtfsfiles)
+
+    # No previous dataset -> everything that exists now is "changed"
+    if not previous_dataset:
+        return [f.file_name for f in new_files]
+
+    prev_map = {
+        f.file_name: getattr(f, "hash", None) for f in previous_dataset.gtfsfiles
+    }
+
+    changed_files = []
+    for f in new_files:
+        new_hash = getattr(f, "hash", None)
+        old_hash = prev_map.get(f.file_name)
+
+        if old_hash is None or old_hash != new_hash:
+            changed_files.append(f)
+            logging.info(f"Changed file {f.file_name} from {old_hash} to {new_hash}")
+
+    return [f.file_name for f in changed_files]
+
+
+@with_db_session
+def create_pipeline_tasks(dataset: Gtfsdataset, db_session: Session) -> None:
+    """
+    Create pipeline tasks for a dataset.
+    """
+    changed_files = get_changed_files(dataset, db_session=db_session)
+
+    stable_id = dataset.feed.stable_id
+    dataset_stable_id = dataset.stable_id
+    gtfs_files = dataset.gtfsfiles
+    stops_file = next(
+        (file for file in gtfs_files if file.file_name == "stops.txt"), None
+    )
+    stops_url = stops_file.hosted_url if stops_file else None
+
+    # Create reverse geolocation task
+    if stops_url and "stops.txt" in changed_files:
+        create_http_reverse_geolocation_processor_task(
+            stable_id, dataset_stable_id, stops_url
+        )
+
+    routes_file = next(
+        (file for file in gtfs_files if file.file_name == "routes.txt"), None
+    )
+    # Create PMTiles builder task
+    required_files = {"stops.txt", "routes.txt", "trips.txt", "stop_times.txt"}
+    if not required_files.issubset(set(f.file_name for f in gtfs_files)):
+        logging.info(
+            f"Skipping PMTiles task for dataset {dataset_stable_id} due to missing required files. Required files: "
+            f"{required_files}, available files: {[f.file_name for f in gtfs_files]}"
+        )
+    expected_file_change: Iterable[str] = {
+        "stops.txt",
+        "trips.txt",
+        "routes.txt",
+        "stop_times.txt",
+        "shapes.txt",
+    }
+    if (
+        routes_file
+        and 0 < routes_file.file_size_bytes < 1_000_000
+        and not set(changed_files).isdisjoint(expected_file_change)
+    ):
+        create_http_pmtiles_builder_task(stable_id, dataset_stable_id)
+    elif routes_file:
+        logging.info(
+            f"Skipping PMTiles task for dataset {dataset_stable_id} due to constraints --> "
+            f"routes.txt file size : {routes_file.file_size_bytes} bytes"
+            f" and changed files: {changed_files}"
+        )
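
The hash comparison in `get_changed_files` means a file that disappeared never triggers downstream tasks, while a new or modified file does. A small standalone illustration of that rule, with plain dicts standing in for the Gtfsfile rows (the names and hashes below are made up):

# Standalone illustration of the changed-file rule; not project code.
previous = {"stops.txt": "aaa", "routes.txt": "bbb", "fare_rules.txt": "ccc"}
current = {"stops.txt": "aaa", "routes.txt": "ddd", "shapes.txt": "eee"}

# fare_rules.txt disappeared -> not "changed"; routes.txt hash differs and
# shapes.txt is new -> both "changed"; stops.txt is identical -> unchanged.
changed = [
    name
    for name, new_hash in current.items()
    if previous.get(name) is None or previous.get(name) != new_hash
]
print(changed)  # ['routes.txt', 'shapes.txt']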

functions-python/batch_process_dataset/tests/test_batch_process_dataset_main.py

Lines changed: 2 additions & 0 deletions
@@ -451,6 +451,7 @@ def test_process_dataset_missing_stable_id(self, mock_dataset_trace):
         )

     @patch.dict(os.environ, {"DATASETS_BUCKET_NAME": "test-bucket"})
+    @patch("main.create_pipeline_tasks")
     @patch("main.DatasetProcessor.create_dataset_entities")
     @patch("main.DatasetProcessor.upload_files_to_storage")
     @patch("main.DatasetProcessor.unzip_files")
@@ -461,6 +462,7 @@ def test_process_from_bucket_latest_happy_path(
         mock_unzip_files,
         mock_upload_files_to_storage,
         mock_create_dataset_entities,
+        _,
     ):
         # Arrange
         mock_blob = MagicMock()
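
The new patch is added above the existing ones, so its mock is passed after theirs and is received as the unused `_` parameter. A standalone illustration of that `unittest.mock` ordering rule (not the project's test case):

# Decorators are applied bottom-up, so the bottom-most @patch is passed first
# and the top-most patch arrives last, where it can be ignored with `_`.
import os
from unittest import TestCase, mock


class PatchOrderExample(TestCase):
    @mock.patch("os.getcwd")   # top-most patch -> last mock argument
    @mock.patch("os.listdir")  # bottom-most patch -> first mock argument
    def test_order(self, mock_listdir, _):
        mock_listdir.return_value = []
        self.assertEqual(os.listdir("."), [])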
