
Commit 6f9d0e9

feat: reverse geolocation per polygon strategy (#1318)
1 parent 80f7b80 commit 6f9d0e9

26 files changed (+2202 −692 lines)

api/src/shared/common/gcp_utils.py

Lines changed: 14 additions & 8 deletions
@@ -1,8 +1,13 @@
+import json
 import logging
 import os
 from google.cloud import tasks_v2
 from google.protobuf.timestamp_pb2 import Timestamp
 
+REFRESH_VIEW_TASK_EXECUTOR_BODY = json.dumps(
+    {"task": "refresh_materialized_view", "payload": {"dry_run": False}}
+).encode()
+
 
 def create_refresh_materialized_view_task():
     """
@@ -39,20 +44,19 @@ def create_refresh_materialized_view_task():
     logging.debug("Queue name from env: %s", queue)
     gcp_region = os.getenv("GCP_REGION")
     environment_name = os.getenv("ENVIRONMENT")
-    url = f"https://{gcp_region}-" f"{project}.cloudfunctions.net/" f"tasks-executor-{environment_name}"
-
+    url = f"https://{gcp_region}-" f"{project}.cloudfunctions.net/" f"tasks_executor-{environment_name}"
     # Enqueue the task
     try:
         create_http_task_with_name(
             client=tasks_v2.CloudTasksClient(),
-            body=b"",
+            body=REFRESH_VIEW_TASK_EXECUTOR_BODY,
             url=url,
             project_id=project,
             gcp_region=gcp_region,
             queue_name=queue,
             task_name=task_name,
             task_time=proto_time,
-            http_method=tasks_v2.HttpMethod.GET,
+            http_method=tasks_v2.HttpMethod.POST,
         )
         logging.info("Scheduled refresh materialized view task for %s", task_name)
         return {"message": "Refresh task for %s scheduled." % task_name}, 200
@@ -95,10 +99,12 @@ def create_http_task_with_name(
             headers={"Content-Type": "application/json"},
         ),
     )
-    logging.info("Task created with task_name: %s", task_name)
     try:
         response = client.create_task(parent=parent, task=task)
+        logging.info("Task created with task_name: %s", task_name)
     except Exception as e:
-        logging.error("Error creating task: %s", e)
-        logging.error("response: %s", response)
-    logging.info("Successfully created task in create_http_task_with_name")
+        if "Requested entity already exists" in str(e):
+            logging.info("Task already exists for %s, skipping.", task_name)
+        else:
+            logging.error("Error creating task: %s", e)
+            logging.error("response: %s", response)

functions-python/batch_datasets/src/main.py

Lines changed: 2 additions & 1 deletion
@@ -29,7 +29,8 @@
 from sqlalchemy.orm import Session
 
 from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gtfsdataset
-from shared.dataset_service.main import BatchExecutionService, BatchExecution
+from shared.dataset_service.dataset_service_commons import BatchExecution
+from shared.dataset_service.main import BatchExecutionService
 from shared.database.database import with_db_session
 from shared.helpers.logger import init_logger
 
functions-python/dataset_service/dataset_service_commons.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+from dataclasses import dataclass
+from datetime import datetime
+from enum import Enum
+from typing import Optional
+
+
+# Status of the dataset trace
+class Status(Enum):
+    FAILED = "FAILED"
+    SUCCESS = "SUCCESS"
+    PUBLISHED = "PUBLISHED"
+    NOT_PUBLISHED = "NOT_PUBLISHED"
+    PROCESSING = "PROCESSING"
+
+
+# Stage of the pipeline
+class PipelineStage(Enum):
+    DATASET_PROCESSING = "DATASET_PROCESSING"
+    LOCATION_EXTRACTION = "LOCATION_EXTRACTION"
+    GBFS_VALIDATION = "GBFS_VALIDATION"
+
+
+# Dataset trace class to store the trace of a dataset
+@dataclass
+class DatasetTrace:
+    stable_id: str
+    status: Status
+    timestamp: datetime
+    dataset_id: Optional[str] = None
+    trace_id: Optional[str] = None
+    execution_id: Optional[str] = None
+    file_sha256_hash: Optional[str] = None
+    hosted_url: Optional[str] = None
+    pipeline_stage: PipelineStage = PipelineStage.DATASET_PROCESSING
+    error_message: Optional[str] = None
+
+
+# Batch execution class to store the trace of a batch execution
+@dataclass
+class BatchExecution:
+    execution_id: str
+    timestamp: datetime
+    feeds_total: int
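These models are a straight move of the definitions previously in dataset_service/main.py (removed below). A hedged usage sketch, with illustrative values only:

from datetime import datetime

trace = DatasetTrace(
    stable_id="example-feed-id",  # illustrative value, not from the commit
    status=Status.PROCESSING,
    timestamp=datetime.utcnow(),
    pipeline_stage=PipelineStage.LOCATION_EXTRACTION,
)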

functions-python/dataset_service/main.py

Lines changed: 19 additions & 43 deletions
@@ -13,15 +13,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import importlib
 import logging
 import uuid
-from datetime import datetime
-from enum import Enum
-from dataclasses import dataclass, asdict
-from typing import Optional, Final
+from dataclasses import asdict
+from typing import Final
 from google.cloud import datastore
 from google.cloud.datastore import Client
 
+# This allows the module to be run as a script or imported as a module
+if __package__ is None or __package__ == "":
+    import os
+    import sys
+
+    sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+    import dataset_service_commons
+else:
+    dataset_service_commons = importlib.import_module(
+        ".dataset_service_commons", package=__package__
+    )
+
+Status = dataset_service_commons.Status
+PipelineStage = dataset_service_commons.PipelineStage
+BatchExecution = dataset_service_commons.BatchExecution
+DatasetTrace = dataset_service_commons.DatasetTrace
 
 # This files contains the dataset trace and batch execution models and services.
 # The dataset trace is used to store the trace of a dataset and the batch execution
@@ -30,45 +45,6 @@
 # The persistent layer used is Google Cloud Datastore.
 
 
-# Status of the dataset trace
-class Status(Enum):
-    FAILED = "FAILED"
-    SUCCESS = "SUCCESS"
-    PUBLISHED = "PUBLISHED"
-    NOT_PUBLISHED = "NOT_PUBLISHED"
-    PROCESSING = "PROCESSING"
-
-
-# Stage of the pipeline
-class PipelineStage(Enum):
-    DATASET_PROCESSING = "DATASET_PROCESSING"
-    LOCATION_EXTRACTION = "LOCATION_EXTRACTION"
-    GBFS_VALIDATION = "GBFS_VALIDATION"
-
-
-# Dataset trace class to store the trace of a dataset
-@dataclass
-class DatasetTrace:
-    stable_id: str
-    status: Status
-    timestamp: datetime
-    dataset_id: Optional[str] = None
-    trace_id: Optional[str] = None
-    execution_id: Optional[str] = None
-    file_sha256_hash: Optional[str] = None
-    hosted_url: Optional[str] = None
-    pipeline_stage: PipelineStage = PipelineStage.DATASET_PROCESSING
-    error_message: Optional[str] = None
-
-
-# Batch execution class to store the trace of a batch execution
-@dataclass
-class BatchExecution:
-    execution_id: str
-    timestamp: datetime
-    feeds_total: int
-
-
 dataset_trace_collection: Final[str] = "dataset_trace"
 batch_execution_collection: Final[str] = "batch_execution"
 
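The conditional import above is what lets the same file work both as part of the shared package and as a standalone script. A hedged illustration of the two entry points (the package-style import mirrors the batch_datasets change above; the script invocation is an assumption):

# Imported as a package module (as batch_datasets/src/main.py now does):
from shared.dataset_service.main import BatchExecutionService

# Run as a script from functions-python/dataset_service/, where __package__ is
# empty and dataset_service_commons is resolved via the sys.path append:
#   python main.py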

functions-python/dataset_service/tests/__init__.py

Whitespace-only changes.

functions-python/dataset_service/tests/test_dataset_service.py

Lines changed: 2 additions & 7 deletions
@@ -2,13 +2,8 @@
 from datetime import datetime
 from unittest.mock import patch, MagicMock
 
-from main import (
-    DatasetTrace,
-    DatasetTraceService,
-    Status,
-    BatchExecutionService,
-    BatchExecution,
-)
+from dataset_service_commons import DatasetTrace, Status, BatchExecution
+from main import DatasetTraceService, BatchExecutionService
 
 
 class TestDatasetService(unittest.TestCase):

functions-python/helpers/.coveragerc

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ omit =
     database.py
     */database_gen/*
     */dataset_service/*
+    */shared/common/*
 
 [report]
 exclude_lines =

functions-python/helpers/locations.py

Lines changed: 99 additions & 1 deletion
@@ -1,8 +1,13 @@
 from enum import Enum
 from typing import Dict, Optional
+
+from geoalchemy2 import WKTElement
 from sqlalchemy.orm import Session
+from sqlalchemy import func, cast
+from geoalchemy2.types import Geography
+
 import pycountry
-from shared.database_gen.sqlacodegen_models import Feed, Location
+from shared.database_gen.sqlacodegen_models import Feed, Location, Geopolygon
 import logging
 
 
@@ -11,8 +16,14 @@ class ReverseGeocodingStrategy(str, Enum):
     Enum for reverse geocoding strategies.
     """
 
+    # Per point strategy uses point-in-polygon to find the location for each point
+    # It queries the database for each point, which can be slow for large datasets
    PER_POINT = "per-point"
 
+    # Per polygon strategy uses point-in-polygon to find the location for each point
+    # It queries the database for each polygon, which can be faster for large datasets
+    PER_POLYGON = "per-polygon"
+
 
 
 def get_country_code(country_name: str) -> Optional[str]:
     """
@@ -133,3 +144,90 @@ def translate_feed_locations(feed: Feed, location_translations: Dict):
             if location_translation["country_translation"]
             else location.country
         )
+
+
+def to_shapely(g):
+    """
+    Convert a GeoAlchemy WKB/WKT element or WKT string into a Shapely geometry.
+    If it's already a Shapely geometry, return it as-is.
+    """
+    # Import here to avoid adding an unnecessary dependency to GCP functions that do not use it
+    from shapely import wkt as shapely_wkt
+    from geoalchemy2 import WKTElement, WKBElement
+    from geoalchemy2.shape import to_shape
+
+    if isinstance(g, WKBElement):
+        return to_shape(g)
+    if isinstance(g, WKTElement):
+        return shapely_wkt.loads(g.data)
+    if isinstance(g, str):
+        # assume WKT
+        return shapely_wkt.loads(g)
+    return g  # assume already shapely
+
+
+def select_highest_level_polygon(geopolygons: list[Geopolygon]) -> Optional[Geopolygon]:
+    """
+    Select the geopolygon with the highest admin_level from a list of geopolygons.
+    Admin levels are compared, with NULL treated as the lowest priority.
+    """
+    if not geopolygons:
+        return None
+    # Treat NULL admin_level as the lowest priority
+    return max(
+        geopolygons, key=lambda g: (-1 if g.admin_level is None else g.admin_level)
+    )
+
+
+def select_lowest_level_polygon(geopolygons: list[Geopolygon]) -> Optional[Geopolygon]:
+    """
+    Select the geopolygon with the lowest admin_level from a list of geopolygons.
+    Admin levels are compared, with NULL treated as the lowest priority.
+    """
+    if not geopolygons:
+        return None
+    # Treat NULL admin_level as the lowest priority
+    return min(
+        geopolygons, key=lambda g: (100 if g.admin_level is None else g.admin_level)
+    )
+
+
+def get_country_code_from_polygons(geopolygons: list[Geopolygon]) -> Optional[str]:
+    """
+    Given a list of Geopolygon records, return the country code
+    (ISO 3166-1 alpha-2) from the most likely polygon.
+
+    Args:
+        geopolygons: List of Geopolygon objects carrying
+            'admin_level' and 'iso_3166_1_code'
+
+    Returns:
+        A two-letter country code string or None if not found
+    """
+    country_polygons = [g for g in geopolygons if g.iso_3166_1_code]
+    if not country_polygons:
+        return None
+
+    # Prefer the one with the lowest admin_level (the most general, i.e. country level)
+    lowest_admin_level_polygon = select_lowest_level_polygon(country_polygons)
+    return lowest_admin_level_polygon.iso_3166_1_code
+
+
+def get_geopolygons_covers(stop_point: WKTElement, db_session: Session):
+    """
+    Get all geopolygons that cover a given point using BigQuery-compatible semantics.
+    """
+    # BigQuery-compatible point-in-polygon (geodesic + border-inclusive)
+    geopolygons = (
+        db_session.query(Geopolygon)
+        # optional prefilter to use the GiST index on geometry (fast)
+        .filter(func.ST_Intersects(Geopolygon.geometry, stop_point))
+        # exact check matching BigQuery's GEOGRAPHY semantics
+        .filter(
+            func.ST_Covers(
+                cast(Geopolygon.geometry, Geography(srid=4326)),
+                cast(stop_point, Geography(srid=4326)),
+            )
+        ).all()
+    )
+    return geopolygons
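Taken together, these helpers support the new per-polygon reverse geocoding flow. A hedged sketch of how they could be composed; the session handling and coordinates are placeholders, and this is not the strategy implementation added elsewhere in this commit:

from geoalchemy2 import WKTElement

def locate_point(db_session, lon: float, lat: float):
    point = WKTElement(f"POINT({lon} {lat})", srid=4326)
    covering = get_geopolygons_covers(point, db_session)
    most_specific = select_highest_level_polygon(covering)   # e.g. a municipality
    country_code = get_country_code_from_polygons(covering)  # ISO 3166-1 alpha-2
    return most_specific, country_code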

functions-python/helpers/logger.py

Lines changed: 2 additions & 2 deletions
@@ -17,8 +17,6 @@
 import logging
 import threading
 
-import google.cloud.logging
-
 from shared.common.logging_utils import get_env_logging_level
 
 
@@ -61,6 +59,8 @@ def init_logger():
     if _logging_initialized:
         return
     try:
+        import google.cloud.logging
+
         client = google.cloud.logging.Client()
         client.setup_logging()
     except Exception as error:

functions-python/helpers/requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -27,4 +27,5 @@ google-cloud-firestore
 google-cloud-bigquery
 
 # Additional package
-pycountry
+pycountry
+shapely
