
Commit a904893: "testing changes"
1 parent ff45cd3

7 files changed (+162 / -60 lines)

functions-python/batch_datasets/README.md

Lines changed: 8 additions & 0 deletions
@@ -1,5 +1,12 @@
 # Batch Datasets
 This directory contains the GCP serverless function that enqueues all active feeds to download datasets.
+The function accepts an optional request body to limit the feeds to process; otherwise it processes all active feeds:
+```json
+{
+  "feed_stable_ids": ["feed_id_1", "feed_id_2"]
+}
+```
+
 The function publishes one Pub/Sub message per active feed with the following format:
 ```json
 {
@@ -19,6 +26,7 @@ The function publishes one Pub/Sub message per active feed with the following format:
 }
 }
 ```
+# TODO - Update with current behavior

 # Function configuration
 The function is configured using the following environment variables:
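For reference, a minimal sketch of how the optional request body documented above could be sent to the deployed HTTP function. The URL, function name, and any required authentication are placeholders and depend on the deployment, not something this diff specifies.

```python
import requests

# Placeholder endpoint; real URL and auth depend on how the function is deployed.
FUNCTION_URL = "https://<region>-<project>.cloudfunctions.net/batch-datasets"

# Limit the run to two feeds; omit the body (or send {}) to process all active feeds.
response = requests.post(
    FUNCTION_URL,
    json={"feed_stable_ids": ["feed_id_1", "feed_id_2"]},
    timeout=60,
)
print(response.status_code, response.text)
```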

functions-python/batch_datasets/src/main.py

Lines changed: 15 additions & 4 deletions
@@ -19,6 +19,7 @@
 import os
 import uuid
 from datetime import datetime
+from typing import Optional

 import functions_framework
 from google.cloud import pubsub_v1
@@ -64,7 +65,7 @@ def publish(publisher: PublisherClient, topic_path: str, data_bytes: bytes) -> F
     return publisher.publish(topic_path, data=data_bytes)


-def get_non_deprecated_feeds(session: Session):
+def get_non_deprecated_feeds(session: Session, feed_stable_ids: Optional[list[str]] = None):
    """
    Returns a list of non deprecated feeds
    :return: list of feeds
@@ -79,14 +80,17 @@ def get_non_deprecated_feeds(session: Session):
             Gtfsfeed.authentication_info_url,
             Gtfsfeed.api_key_parameter_name,
             Gtfsfeed.status,
-            Gtfsdataset.id.label("dataset_id"),
+            Gtfsdataset.stable_id.label("dataset_stable_id"),
             Gtfsdataset.hash.label("dataset_hash"),
         )
         .select_from(Gtfsfeed)
         .outerjoin(Gtfsdataset, (Gtfsdataset.feed_id == Gtfsfeed.id))
         .filter(Gtfsfeed.status != "deprecated")
         .filter(or_(Gtfsdataset.id.is_(None), Gtfsdataset.latest.is_(True)))
     )
+    if feed_stable_ids:
+        # If feed_stable_ids are provided, filter the query by stable IDs
+        query = query.filter(Gtfsfeed.stable_id.in_(feed_stable_ids))
     # Limit the query to 10 feeds (or FEEDS_LIMIT param) for testing purposes and lower environments
     if os.getenv("ENVIRONMENT", "").lower() in ("dev", "test", "qa"):
         limit = os.getenv("FEEDS_LIMIT")
@@ -108,8 +112,15 @@ def batch_datasets(request, db_session: Session):
    :param db_session: database session object
    :return: HTTP response object
    """
+    feed_stable_ids = None
    try:
-        feeds = get_non_deprecated_feeds(db_session)
+        request_json = request.get_json()
+        feed_stable_ids = request_json.get("feed_stable_ids") if request_json else None
+    except Exception:
+        logging.info(f"No feed_stable_ids provided in the request, processing all feeds.")
+
+    try:
+        feeds = get_non_deprecated_feeds(db_session, feed_stable_ids=feed_stable_ids)
    except Exception as error:
        logging.error(f"Error retrieving feeds: {error}")
        raise Exception(f"Error retrieving feeds: {error}")
@@ -130,7 +141,7 @@ def batch_datasets(request, db_session: Session):
            "producer_url": feed.producer_url,
            "feed_stable_id": feed.stable_id,
            "feed_id": feed.feed_id,
-            "dataset_id": feed.dataset_id,
+            "dataset_stable_id": feed.dataset_stable_id,
            "dataset_hash": feed.dataset_hash,
            "authentication_type": feed.authentication_type,
            "authentication_info_url": feed.authentication_info_url,

functions-python/batch_process_dataset/src/main.py

Lines changed: 110 additions & 50 deletions
@@ -16,6 +16,7 @@

 import base64
 import json
+import logging
 import os
 import random
 import uuid
@@ -28,17 +29,14 @@
 from cloudevents.http import CloudEvent
 from google.cloud import storage
 from sqlalchemy import func
+from sqlalchemy.orm import Session
+
 from shared.common.gcp_utils import create_refresh_materialized_view_task
+from shared.database.database import with_db_session
 from shared.database_gen.sqlacodegen_models import Gtfsdataset, Gtfsfile
-
 from shared.dataset_service.main import DatasetTraceService, DatasetTrace, Status
-from shared.database.database import with_db_session
-import logging
-
 from shared.helpers.logger import init_logger, get_logger
-from shared.helpers.utils import download_and_get_hash, get_hash_from_file
-from sqlalchemy.orm import Session
-
+from shared.helpers.utils import download_and_get_hash, get_hash_from_file, download_from_gcs

 init_logger()

@@ -68,6 +66,7 @@ def __init__(
         authentication_type,
         api_key_parameter_name,
         public_hosted_datasets_url,
+        dataset_stable_id
     ):
         self.logger = get_logger(DatasetProcessor.__name__, feed_stable_id)
         self.producer_url = producer_url
@@ -92,6 +91,7 @@ def __init__(

         self.init_status = None
         self.init_status_additional_data = None
+        self.dataset_stable_id = dataset_stable_id

     @staticmethod
     def get_feed_credentials(feed_stable_id) -> str | None:
@@ -132,28 +132,30 @@ def download_content(self, temporary_file_path):
         is_zip = zipfile.is_zipfile(temporary_file_path)
         return file_hash, is_zip

-    def upload_file_to_storage(
+    def upload_files_to_storage(
         self,
         source_file_path,
         dataset_stable_id,
         extracted_files_path,
         public=True,
+        skip_dataset_upload=False
     ):
         """
-        Uploads a file to the GCP bucket
+        Uploads the dataset file and extracted files to GCP storage
         """
         bucket = storage.Client().get_bucket(self.bucket_name)
         target_paths = [
             f"{self.feed_stable_id}/latest.zip",
             f"{self.feed_stable_id}/{dataset_stable_id}/{dataset_stable_id}.zip",
         ]
         blob = None
-        for target_path in target_paths:
-            blob = bucket.blob(target_path)
-            with open(source_file_path, "rb") as file:
-                blob.upload_from_file(file)
-            if public:
-                blob.make_public()
+        if not skip_dataset_upload:
+            for target_path in target_paths:
+                blob = bucket.blob(target_path)
+                blob.upload_from_filename(source_file_path)
+                if public:
+                    blob.make_public()
+                self.logger.info(f"Uploaded {blob.public_url}")

         base_path, _ = os.path.splitext(source_file_path)
         extracted_files: List[Gtfsfile] = []
@@ -162,6 +164,7 @@ def upload_file_to_storage(
                 f"Extracted files path {extracted_files_path} does not exist."
             )
             return blob, extracted_files
+        self.logger.info('Processing extracted files from %s', extracted_files_path)
         for file_name in os.listdir(extracted_files_path):
             file_path = os.path.join(extracted_files_path, file_name)
             if os.path.isfile(file_path):
@@ -192,6 +195,7 @@ def upload_dataset(self, public=True) -> DatasetFile or None:
         if the dataset hash is different from the latest dataset stored
         :return: the file hash and the hosted url as a tuple or None if no upload is required
         """
+        temp_file_path = None
         try:
             self.logger.info("Accessing URL %s", self.producer_url)
             temp_file_path = self.generate_temp_filename()
@@ -221,11 +225,8 @@ def upload_dataset(self, public=True) -> DatasetFile or None:
             dataset_full_path = (
                 f"{self.feed_stable_id}/{dataset_stable_id}/{dataset_stable_id}.zip"
             )
-            self.logger.info(
-                f"Creating file: {dataset_full_path}"
-                f" in bucket {self.bucket_name}"
-            )
-            _, extracted_files = self.upload_file_to_storage(
+            self.logger.info(f"Creating file {dataset_full_path} in bucket {self.bucket_name}")
+            _, extracted_files = self.upload_files_to_storage(
                 temp_file_path,
                 dataset_stable_id,
                 extracted_files_path,
@@ -249,10 +250,55 @@ def upload_dataset(self, public=True) -> DatasetFile or None:
                     f"-> {file_sha256_hash}). Not uploading it."
                 )
         finally:
-            if os.path.exists(temp_file_path):
+            if temp_file_path and os.path.exists(temp_file_path):
                 os.remove(temp_file_path)
         return None

+    def process2(self, public=True) -> DatasetFile or None:
+        """
+        Uploads a dataset to a GCP bucket as <feed_stable_id>/latest.zip and
+        <feed_stable_id>/<feed_stable_id>-<upload_datetime>.zip
+        if the dataset hash is different from the latest dataset stored
+        :return: the file hash and the hosted url as a tuple or None if no upload is required
+        """
+        temp_file_path = None
+        try:
+            self.logger.info("Accessing URL %s", self.producer_url)
+            temp_file_path = self.generate_temp_filename()
+            blob_file_path = f"{self.feed_stable_id}/latest.zip"
+            download_from_gcs(os.getenv('DATASETS_BUCKET_NAME'), blob_file_path, temp_file_path)
+
+            extracted_files_path = self.unzip_files(temp_file_path)
+            dataset_full_path = (
+                f"{self.feed_stable_id}/{self.dataset_stable_id}/{self.dataset_stable_id}.zip"
+            )
+            self.logger.info(f"Creating file {dataset_full_path} in bucket {self.bucket_name}")
+            _, extracted_files = self.upload_files_to_storage(
+                temp_file_path,
+                self.dataset_stable_id,
+                extracted_files_path,
+                public=public,
+                skip_dataset_upload=True,  # Skip the upload of the dataset file
+            )
+
+            dataset_file = DatasetFile(
+                stable_id=self.dataset_stable_id,
+                file_sha256_hash=self.latest_hash,
+                hosted_url=f"{self.public_hosted_datasets_url}/{dataset_full_path}",
+                extracted_files=extracted_files,
+                zipped_size=(
+                    os.path.getsize(temp_file_path)
+                    if os.path.exists(temp_file_path)
+                    else None
+                ),
+            )
+            self.create_dataset_entities(dataset_file, skip_dataset_creation=True)
+        finally:
+            if temp_file_path and os.path.exists(temp_file_path):
+                os.remove(temp_file_path)
+        return None
+
+
     def unzip_files(self, temp_file_path):
         extracted_files_path = os.path.join(temp_file_path.split(".")[0], "extracted")
         self.logger.info(f"Unzipping files to {extracted_files_path}")
@@ -270,14 +316,14 @@ def generate_temp_filename(self):
         Generates a temporary filename
         """
         temporary_file_path = (
-            f"/tmp/{self.feed_stable_id}-{random.randint(0, 1000000)}.zip"
+            f"/in-memory/{self.feed_stable_id}-{random.randint(0, 1000000)}.zip"
         )
         return temporary_file_path

     @with_db_session
-    def create_dataset(self, dataset_file: DatasetFile, db_session: Session):
+    def create_dataset_entities(self, dataset_file: DatasetFile, db_session: Session, skip_dataset_creation=False):
         """
-        Creates a new dataset in the database
+        Creates dataset entities in the database
         """
         try:
             # Check latest version of the dataset
@@ -294,30 +340,40 @@ def create_dataset(self, dataset_file: DatasetFile, db_session: Session):
             self.logger.info(
                 f"[{self.feed_stable_id}] Creating new dataset for feed with stable id {dataset_file.stable_id}."
             )
-            new_dataset = Gtfsdataset(
-                id=str(uuid.uuid4()),
-                feed_id=self.feed_id,
-                stable_id=dataset_file.stable_id,
-                latest=True,
-                bounding_box=None,
-                note=None,
-                hash=dataset_file.file_sha256_hash,
-                downloaded_at=func.now(),
-                hosted_url=dataset_file.hosted_url,
-                gtfsfiles=(
-                    dataset_file.extracted_files if dataset_file.extracted_files else []
-                ),
-                zipped_size_bytes=dataset_file.zipped_size,
-                unzipped_size_bytes=(
-                    sum([ex.file_size_bytes for ex in dataset_file.extracted_files])
-                    if dataset_file.extracted_files
-                    else None
-                ),
-            )
-            if latest_dataset:
+            if not skip_dataset_creation:
+                dataset = Gtfsdataset(
+                    id=str(uuid.uuid4()),
+                    feed_id=self.feed_id,
+                    stable_id=dataset_file.stable_id,
+                    latest=True,
+                    bounding_box=None,
+                    note=None,
+                    hash=dataset_file.file_sha256_hash,
+                    downloaded_at=func.now(),
+                    hosted_url=dataset_file.hosted_url,
+                    gtfsfiles=(
+                        dataset_file.extracted_files if dataset_file.extracted_files else []
+                    ),
+                    zipped_size_bytes=dataset_file.zipped_size,
+                    unzipped_size_bytes=(
+                        sum([ex.file_size_bytes for ex in dataset_file.extracted_files])
+                        if dataset_file.extracted_files
+                        else None
+                    ),
+                )
+                db_session.add(dataset)
+            elif skip_dataset_creation and latest_dataset:
+                latest_dataset.gtfsfiles = dataset_file.extracted_files if dataset_file.extracted_files else []
+                latest_dataset.zipped_size_bytes = dataset_file.zipped_size
+                latest_dataset.unzipped_size_bytes = (
+                    sum([ex.file_size_bytes for ex in dataset_file.extracted_files])
+                    if dataset_file.extracted_files
+                    else None
+                )
+
+            if latest_dataset and not skip_dataset_creation:
                 latest_dataset.latest = False
                 db_session.add(latest_dataset)
-            db_session.add(new_dataset)
             db_session.commit()
             self.logger.info(f"[{self.feed_stable_id}] Dataset created successfully.")

@@ -335,7 +391,7 @@ def process(self) -> DatasetFile or None:
         if dataset_file is None:
             self.logger.info(f"[{self.feed_stable_id}] No database update required.")
             return None
-        self.create_dataset(dataset_file)
+        self.create_dataset_entities(dataset_file)
         return dataset_file


@@ -374,7 +430,7 @@ def process_dataset(cloud_event: CloudEvent):
         producer_url,
         feed_stable_id,
         feed_id,
-        dataset_id,
+        dataset_stable_id,
         dataset_hash,
         authentication_type,
         authentication_info_url,
@@ -409,7 +465,7 @@ def process_dataset(cloud_event: CloudEvent):
     trace_service = None
     dataset_file: DatasetFile = None
     error_message = None
-    # Extract data from message
+    # Extract data from message
     data = base64.b64decode(cloud_event.data["message"]["data"]).decode()
     json_payload = json.loads(data)
     stable_id = json_payload["feed_stable_id"]
@@ -445,8 +501,12 @@ def process_dataset(cloud_event: CloudEvent):
             int(json_payload["authentication_type"]),
             json_payload["api_key_parameter_name"],
             public_hosted_datasets_url,
+            json_payload.get('dataset_stable_id')
         )
-        dataset_file = processor.process()
+        if json_payload.get("process_files_only", False):
+            dataset_file = processor.process2()
+        else:
+            dataset_file = processor.process()
     except Exception as e:
         # This makes sure the logger is initialized
         logger = get_logger("process_dataset", stable_id if stable_id else "UNKNOWN")
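Putting the routing change together: a hedged sketch of a Pub/Sub message that would make `process_dataset` call `process2()` instead of `process()`. Only fields visible in this diff are included; the real payload (see the batch_datasets README) may carry additional fields, and the project/topic IDs and values below are placeholders.

```python
import json

from google.cloud import pubsub_v1

publisher = pubsub_v1.PublisherClient()
# Placeholder project and topic IDs.
topic_path = publisher.topic_path("my-gcp-project", "dataset-batch-topic")

message = {
    "producer_url": "https://example.com/gtfs.zip",
    "feed_stable_id": "feed_id_1",
    "feed_id": "some-feed-uuid",
    "dataset_stable_id": "dataset_id_1",
    "dataset_hash": None,
    "authentication_type": 0,
    "authentication_info_url": None,
    "api_key_parameter_name": None,
    "process_files_only": True,  # routes to process2(): reuse latest.zip and only (re)extract files
}
future = publisher.publish(topic_path, data=json.dumps(message).encode("utf-8"))
print(future.result())  # message ID once the publish completes
```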

functions-python/batch_process_dataset/tests/test_batch_process_dataset_main.py

Lines changed: 3 additions & 3 deletions
@@ -207,7 +207,7 @@ def test_upload_file_to_storage(self):
             test_hosted_public_url,
         )
         dataset_id = faker.Faker().uuid4()
-        result, _ = processor.upload_file_to_storage(
+        result, _ = processor.upload_files_to_storage(
            source_file_path, dataset_id, extracted_file_path
        )
        self.assertEqual(result.public_url, public_url)
@@ -358,11 +358,11 @@ def test_process_no_change(self):
        )

        processor.upload_dataset = MagicMock(return_value=None)
-        processor.create_dataset = MagicMock()
+        processor.create_dataset_entities = MagicMock()
        result = processor.process()

        self.assertIsNone(result)
-        processor.create_dataset.assert_not_called()
+        processor.create_dataset_entities.assert_not_called()

    @patch("main.DatasetTraceService")
    @patch("main.DatasetProcessor")
