Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,16 @@ def get_feeds_data(
return [d for d in data if d["stops_url"]]


def parse_request_parameters(request: flask.Request) -> Tuple[List[str], bool]:
"""Parse the request parameters to get the country codes and whether to include only unprocessed feeds."""
def parse_request_parameters(request: flask.Request) -> Tuple[List[str], bool, bool]:
"""
Parse the request parameters.

Returns:
Tuple[List[str], bool, bool]: A tuple containing:
- country_codes: List of country codes to filter feeds
- include_only_unprocessed: Whether to include only unprocessed feeds
- use_cache: Whether to use cache for reverse geolocation
"""
json_request = request.get_json()
country_codes = json_request.get("country_codes", "").split(",")
country_codes = [code.strip().upper() for code in country_codes if code]
Expand All @@ -78,13 +86,16 @@ def parse_request_parameters(request: flask.Request) -> Tuple[List[str], bool]:
include_only_unprocessed = (
json_request.get("include_only_unprocessed", True) is True
)
return country_codes, include_only_unprocessed
use_cache = bool(json_request.get("use_cache", True))
return country_codes, include_only_unprocessed, use_cache


def reverse_geolocation_batch(request: flask.Request) -> Tuple[str, int]:
"""Batch function to trigger reverse geolocation for feeds."""
try:
country_codes, include_only_unprocessed = parse_request_parameters(request)
country_codes, include_only_unprocessed, use_cache = parse_request_parameters(
request
)
feeds_data = get_feeds_data(country_codes, include_only_unprocessed)
logging.info("Valid feeds with latest dataset: %s", len(feeds_data))

Expand All @@ -93,6 +104,7 @@ def reverse_geolocation_batch(request: flask.Request) -> Tuple[str, int]:
stable_id=feed["stable_id"],
dataset_id=feed["dataset_id"],
stops_url=feed["stops_url"],
use_cache=use_cache,
)
return f"Batch function triggered for {len(feeds_data)} feeds.", 200
except Exception as e:
Expand All @@ -104,13 +116,19 @@ def create_http_processor_task(
stable_id: str,
dataset_id: str,
stops_url: str,
use_cache: bool = True,
) -> None:
"""
Create a task to process a group of points.
"""
client = tasks_v2.CloudTasksClient()
body = json.dumps(
{"stable_id": stable_id, "stops_url": stops_url, "dataset_id": dataset_id}
{
"stable_id": stable_id,
"stops_url": stops_url,
"dataset_id": dataset_id,
"use_cache": use_cache,
}
).encode()
queue_name = os.getenv("QUEUE_NAME")
project_id = os.getenv("PROJECT_ID")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,12 @@ def test_parse_request_parameters(self):
}.get(value, default)
from reverse_geolocation_batch import parse_request_parameters

country_codes, include_only_unprocessed = parse_request_parameters(request)
country_codes, include_only_unprocessed, use_cache = parse_request_parameters(
request
)
self.assertEqual(["CA", "US"], country_codes)
self.assertTrue(include_only_unprocessed)
self.assertTrue(use_cache)

with pytest.raises(ValueError):
request.get_json.return_value.get = lambda value, default: {
Expand All @@ -121,7 +124,7 @@ def test_reverse_geolocation_batch(self, mock_parse_request, mock_get_feeds, _):
from reverse_geolocation_batch import reverse_geolocation_batch

request = MagicMock()
mock_parse_request.return_value = ["CA", "US"]
mock_parse_request.return_value = (["CA", "US"], False, False)
mock_get_feeds.return_value = [
{
"stable_id": "test_feed",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"JP": [7],
"CA": [6, 8],
"FR": []
"FR": [8],
"US": [5, 8]
}