Skip to content

Commit ef2d1f4

Browse files
authored
fix: batch process dataset function raises out of memory with large feeds (#1320)
1 parent 8a9ed8c commit ef2d1f4

File tree

11 files changed

+444
-271
lines changed

11 files changed

+444
-271
lines changed

functions-python/batch_datasets/README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# Batch Datasets
22
This directory contains the GCP serverless function that enqueues all active feeds to download datasets.
3+
The function accepts an optional request body to limit the feeds to process; otherwise it processes all active feeds:
4+
```json
5+
{
6+
"feed_stable_ids": ["feed_id_1", "feed_id_2"]
7+
}
8+
```
9+
310
The function publishes one Pub/Sub message per active feed with the following format:
411
```json
512
{

functions-python/batch_datasets/src/main.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
import uuid
2121
from datetime import datetime
22+
from typing import Optional
2223

2324
import functions_framework
2425
from google.cloud import pubsub_v1
@@ -64,7 +65,9 @@ def publish(publisher: PublisherClient, topic_path: str, data_bytes: bytes) -> F
6465
return publisher.publish(topic_path, data=data_bytes)
6566

6667

67-
def get_non_deprecated_feeds(session: Session):
68+
def get_non_deprecated_feeds(
69+
session: Session, feed_stable_ids: Optional[list[str]] = None
70+
):
6871
"""
6972
Returns a list of non deprecated feeds
7073
:return: list of feeds
@@ -79,14 +82,17 @@ def get_non_deprecated_feeds(session: Session):
7982
Gtfsfeed.authentication_info_url,
8083
Gtfsfeed.api_key_parameter_name,
8184
Gtfsfeed.status,
82-
Gtfsdataset.id.label("dataset_id"),
85+
Gtfsdataset.stable_id.label("dataset_stable_id"),
8386
Gtfsdataset.hash.label("dataset_hash"),
8487
)
8588
.select_from(Gtfsfeed)
8689
.outerjoin(Gtfsdataset, (Gtfsdataset.feed_id == Gtfsfeed.id))
8790
.filter(Gtfsfeed.status != "deprecated")
8891
.filter(or_(Gtfsdataset.id.is_(None), Gtfsdataset.latest.is_(True)))
8992
)
93+
if feed_stable_ids:
94+
# If feed_stable_ids are provided, filter the query by stable IDs
95+
query = query.filter(Gtfsfeed.stable_id.in_(feed_stable_ids))
9096
# Limit the query to 10 feeds (or FEEDS_LIMIT param) for testing purposes and lower environments
9197
if os.getenv("ENVIRONMENT", "").lower() in ("dev", "test", "qa"):
9298
limit = os.getenv("FEEDS_LIMIT")
@@ -108,8 +114,17 @@ def batch_datasets(request, db_session: Session):
108114
:param db_session: database session object
109115
:return: HTTP response object
110116
"""
117+
feed_stable_ids = None
111118
try:
112-
feeds = get_non_deprecated_feeds(db_session)
119+
request_json = request.get_json()
120+
feed_stable_ids = request_json.get("feed_stable_ids") if request_json else None
121+
except Exception:
122+
logging.info(
123+
"No feed_stable_ids provided in the request, processing all feeds."
124+
)
125+
126+
try:
127+
feeds = get_non_deprecated_feeds(db_session, feed_stable_ids=feed_stable_ids)
113128
except Exception as error:
114129
logging.error(f"Error retrieving feeds: {error}")
115130
raise Exception(f"Error retrieving feeds: {error}")
@@ -130,7 +145,7 @@ def batch_datasets(request, db_session: Session):
130145
"producer_url": feed.producer_url,
131146
"feed_stable_id": feed.stable_id,
132147
"feed_id": feed.feed_id,
133-
"dataset_id": feed.dataset_id,
148+
"dataset_stable_id": feed.dataset_stable_id,
134149
"dataset_hash": feed.dataset_hash,
135150
"authentication_type": feed.authentication_type,
136151
"authentication_info_url": feed.authentication_info_url,

functions-python/batch_datasets/tests/test_batch_datasets_main.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,47 @@ def test_batch_datasets(mock_client, mock_publish, db_session):
5454
"shared.dataset_service.main.BatchExecutionService.save",
5555
return_value=None,
5656
):
57-
batch_datasets(Mock())
57+
mock_request = MagicMock()
58+
mock_request.get_json = MagicMock(return_value={})
59+
batch_datasets(mock_request)
60+
assert mock_publish.call_count == 5
61+
# loop over mock_publish.call_args_list and check that the stable_id of the feed is in the list of
62+
# active feeds
63+
for i in range(3):
64+
message = json.loads(
65+
mock_publish.call_args_list[i][0][2].decode("utf-8")
66+
)
67+
assert message["feed_stable_id"] in [feed.stable_id for feed in feeds]
68+
69+
70+
@mock.patch.dict(
71+
os.environ,
72+
{
73+
"FEEDS_DATABASE_URL": default_db_url,
74+
"FEEDS_PUBSUB_TOPIC_NAME": "test_topic",
75+
"ENVIRONMENT": "test",
76+
"FEEDS_LIMIT": "5",
77+
},
78+
)
79+
@patch("main.publish")
80+
@patch("main.get_pubsub_client")
81+
@with_db_session(db_url=default_db_url)
82+
def test_batch_datasets_w_feed_ids(mock_client, mock_publish, db_session):
83+
mock_client.return_value = MagicMock()
84+
feeds = get_non_deprecated_feeds(db_session)
85+
with patch(
86+
"shared.dataset_service.main.BatchExecutionService.__init__",
87+
return_value=None,
88+
):
89+
with patch(
90+
"shared.dataset_service.main.BatchExecutionService.save",
91+
return_value=None,
92+
):
93+
mock_request = MagicMock()
94+
mock_request.get_json = MagicMock(
95+
return_value={"feed_stable_ids": [feed.stable_id for feed in feeds]}
96+
)
97+
batch_datasets(mock_request)
5898
assert mock_publish.call_count == 5
5999
# loop over mock_publish.call_args_list and check that the stable_id of the feed is in the list of
60100
# active feeds

functions-python/batch_process_dataset/function_config.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"description": "Process datasets from the feed passed in the Pub/Sub event",
44
"entry_point": "process_dataset",
55
"timeout": 540,
6-
"memory": "2Gi",
6+
"memory": "8Gi",
77
"trigger_http": true,
88
"include_folders": ["helpers", "dataset_service"],
99
"include_api_folders": ["database_gen", "database", "common"],
@@ -20,5 +20,5 @@
2020
"max_instance_request_concurrency": 1,
2121
"max_instance_count": 5,
2222
"min_instance_count": 0,
23-
"available_cpu": 1
23+
"available_cpu": 2
2424
}

0 commit comments

Comments
 (0)