Skip to content

Commit 509935c

Browse files
committed
[fun] location extraction process
1 parent 5a98509 commit 509935c

33 files changed

+1651
-752
lines changed

api/.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[flake8]
22
max-line-length = 120
3-
exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,venv,build,src/feeds_gen,src/database_gen
3+
exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,venv,build,src/feeds_gen,src/database_gen,src/shared/database_gen
44
# Ignored because conflict with black
55
extend-ignore = E203

api/src/feeds/impl/feeds_api_impl.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime
2-
from typing import List, Union, TypeVar
2+
from typing import List, Union, TypeVar, Optional
33

44
from sqlalchemy import or_
55
from sqlalchemy import select
@@ -43,7 +43,6 @@
4343
from utils.date_utils import valid_iso_date
4444
from utils.location_translation import (
4545
create_location_translation_object,
46-
LocationTranslation,
4746
get_feeds_location_translations,
4847
)
4948
from utils.logger import Logger
@@ -129,42 +128,39 @@ def get_feeds(
129128
@with_db_session
def get_gtfs_feed(self, id: str, db_session: Session) -> GtfsFeed:
    """Get the specified gtfs feed from the Mobility Database.

    Args:
        id: stable id of the GTFS feed.
        db_session: SQLAlchemy session (injected by @with_db_session when absent).

    Returns:
        The feed converted to the API model via GtfsFeedImpl.from_orm.

    Raises:
        HTTPException (404) when no visible feed matches the given id.
    """
    feed = self._get_gtfs_feed(id, db_session)
    if feed is None:
        # Guard clause: fail fast so the success path stays unindented.
        raise_http_error(404, gtfs_feed_not_found.format(id))
    return GtfsFeedImpl.from_orm(feed)
137136

138137
@staticmethod
def _get_gtfs_feed(stable_id: str, db_session: Session) -> Optional[Gtfsfeed]:
    """Fetch a single Gtfsfeed by stable id, or None when not found.

    Feeds with operational_status == "wip" are hidden from restricted users;
    datasets, their validation reports and notices are eagerly loaded to
    avoid N+1 queries on the callers' access paths.
    """
    results = (
        FeedFilter(
            stable_id=stable_id,
            status=None,
            provider__ilike=None,
            producer_url__ilike=None,
        )
        .filter(db_session.query(Gtfsfeed))
        .filter(
            or_(
                Gtfsfeed.operational_status == None,  # noqa: E711
                Gtfsfeed.operational_status != "wip",
                not is_user_email_restricted(),  # Allow all feeds to be returned if the user is not restricted
            )
        )
        .options(
            # Eager-load datasets -> validation reports -> notices in one round trip.
            joinedload(Gtfsfeed.gtfsdatasets)
            .joinedload(Gtfsdataset.validation_reports)
            .joinedload(Validationreport.notices),
            *BasicFeedImpl.get_joinedload_options(),
        )
    ).all()
    # Truthiness check instead of len(results) == 0 (PEP 8 idiom).
    if not results:
        return None
    return results[0]
168164

169165
@with_db_session
170166
def get_gtfs_feed_datasets(
@@ -389,8 +385,8 @@ def _get_response(feed_query: Query, impl_cls: type[T], db_session: "Session") -
389385
@with_db_session
def get_gtfs_feed_gtfs_rt_feeds(self, id: str, db_session: Session) -> List[GtfsRTFeed]:
    """Get a list of GTFS Realtime feeds related to a GTFS feed.

    Args:
        id: stable id of the parent GTFS feed.
        db_session: SQLAlchemy session (injected by @with_db_session when absent).

    Raises:
        HTTPException (404) when no visible feed matches the given id.
    """
    feed = self._get_gtfs_feed(id, db_session)
    if feed is None:
        # Guard clause keeps the success path flat.
        raise_http_error(404, gtfs_feed_not_found.format(id))
    return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed) for gtfs_rt_feed in feed.gtfs_rt_feeds]

functions-python/.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[flake8]
22
max-line-length = 120
3-
exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,venv,build,.*,database_gen,feeds_operations_gen
3+
exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,venv,build,.*,database_gen,feeds_operations_gen,shared
44
# Ignored because conflict with black
55
extend-ignore = E203

functions-python/extract_location/src/main.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,6 @@
77

88
import functions_framework
99
from cloudevents.http import CloudEvent
10-
from google.cloud import pubsub_v1
11-
from sqlalchemy import or_
12-
from sqlalchemy.orm import joinedload
13-
1410
from shared.database_gen.sqlacodegen_models import Gtfsdataset
1511
from shared.dataset_service.main import (
1612
DatasetTraceService,
@@ -22,10 +18,14 @@
2218
from shared.helpers.database import Database
2319
from shared.helpers.logger import Logger
2420
from shared.helpers.parser import jsonify_pubsub
21+
from sqlalchemy import or_
22+
from sqlalchemy.orm import joinedload
23+
2524
from bounding_box.bounding_box_extractor import (
2625
create_polygon_wkt_element,
2726
update_dataset_bounding_box,
2827
)
28+
from shared.helpers.pub_sub import publish_messages
2929
from reverse_geolocation.location_extractor import update_location, reverse_coords
3030
from stops_utils import get_gtfs_feed_bounds_and_points
3131

@@ -232,11 +232,5 @@ def extract_location_batch(_):
232232
pass
233233

234234
# Trigger update location for each dataset by publishing to Pub/Sub
235-
publisher = pubsub_v1.PublisherClient()
236-
topic_path = publisher.topic_path(os.getenv("PROJECT_ID"), pubsub_topic_name)
237-
for data in datasets_data:
238-
message_data = json.dumps(data).encode("utf-8")
239-
future = publisher.publish(topic_path, message_data)
240-
logging.info(f"Published message to Pub/Sub with ID: {future.result()}")
241-
235+
publish_messages(datasets_data, os.getenv("PROJECT_ID"), pubsub_topic_name)
242236
return f"Batch function triggered for {len(datasets_data)} datasets.", 200

functions-python/helpers/database.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
import threading
2121
from typing import Optional
22+
from functools import wraps
2223

2324
from sqlalchemy import create_engine, text, event, Engine
2425
from sqlalchemy.orm import sessionmaker, Session, mapper, class_mapper
@@ -85,7 +86,7 @@ def mapper_configure_listener(mapper, class_):
8586
event.listen(mapper, "mapper_configured", mapper_configure_listener)
8687

8788

88-
def with_db_session(_func=None, *, echo=True):
    """
    Decorator that manages the database session for the decorated function.

    Supports both invocation forms:

        @with_db_session
        def f(..., db_session: Session): ...

        @with_db_session(echo=False)
        def g(..., db_session: Session): ...

    Behavior:
    - If the keyword argument 'db_session' is None or missing, a new session is
      opened via Database().start_db_session(echo=echo), injected into kwargs,
      and closed when the wrapped call returns (the `with` block owns it).
    - If 'db_session' is already provided, the decorated function is called
      with the existing session untouched.
    """

    def decorator(func):
        @wraps(func)  # preserve the wrapped function's name/docstring/signature
        def wrapper(*args, **kwargs):
            db_session = kwargs.get("db_session")

            if db_session is None:
                db = Database()
                with db.start_db_session(echo=echo) as session:
                    kwargs["db_session"] = session
                    return func(*args, **kwargs)
            else:
                return func(*args, **kwargs)

        return wrapper

    # Allow the decorator to be used with or without parentheses:
    # bare @with_db_session passes the function directly as _func.
    if _func is None:
        return decorator
    else:
        return decorator(_func)
116127

117128

118129
class Database:

functions-python/helpers/logger.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16+
import os
1617

1718
import google.cloud.logging
1819
from google.cloud.logging_v2 import Client
@@ -47,6 +48,8 @@ def init_logger() -> Client:
4748
"""
4849
Initializes the logger
4950
"""
51+
if os.getenv("DEBUG", "False") == "True":
52+
return None
5053
client = google.cloud.logging.Client()
5154
client.get_default_handler()
5255
client.setup_logging()

functions-python/helpers/pub_sub.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16+
import json
1617
import uuid
18+
from typing import Dict, List
1719

1820
from google.cloud import pubsub_v1
1921
from google.cloud.pubsub_v1 import PublisherClient
@@ -43,3 +45,15 @@ def get_execution_id(request, prefix: str) -> str:
4345
trace_id = request.headers.get("X-Cloud-Trace-Context")
4446
execution_id = f"{prefix}-{trace_id}" if trace_id else f"{prefix}-{uuid.uuid4()}"
4547
return execution_id
48+
49+
50+
def publish_messages(data: List[Dict], project_id: str, topic_name: str) -> None:
    """
    Publishes the given data to the Pub/Sub topic.

    Each element of *data* is JSON-encoded (UTF-8) and published as one message.

    Args:
        data: list of JSON-serializable dicts, one per message.
        project_id: GCP project that owns the topic.
        topic_name: short Pub/Sub topic name (not the full path).
    """
    import logging  # function-scope: this module has no top-level logging import

    publisher = get_pubsub_client()
    topic_path = publisher.topic_path(project_id, topic_name)
    for element in data:
        message_data = json.dumps(element).encode("utf-8")
        future = publish(publisher, topic_path, message_data)
        # future.result() blocks until the broker acks, guaranteeing confirmed,
        # in-order delivery at the cost of throughput. Use logging (not print)
        # so messages reach Cloud Logging like the rest of the helpers.
        logging.info(f"Published message to Pub/Sub with ID: {future.result()}")

functions-python/helpers/utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import hashlib
1818
import logging
1919
import os
20+
from typing import List
2021

2122
import requests
2223
import urllib3
@@ -40,6 +41,37 @@ def create_bucket(bucket_name):
4041
logging.info(f"Bucket {bucket_name} already exists.")
4142

4243

44+
def cors_configuration(bucket_name, origin=None, method=None, max_age_seconds=3600):
    """Set a bucket's CORS policies configuration.

    Args:
        bucket_name: name of the GCS bucket to patch.
        origin: list of allowed origins; defaults to ["*"] (any origin).
        method: list of allowed HTTP methods; defaults to ["GET"].
        max_age_seconds: how long browsers may cache the preflight response.
    """
    # None-sentinel pattern instead of mutable list defaults, which would be
    # shared across calls and mutable by callers.
    origin = ["*"] if origin is None else origin
    method = ["GET"] if method is None else method
    storage_client = storage.Client()
    bucket = storage_client.get_bucket(bucket_name)
    bucket.cors = [
        {
            "origin": origin,
            "responseHeader": [
                "Content-Type",
                "Access-Control-Allow-Origin",
                "x-goog-resumable",
            ],
            "method": method,
            "maxAgeSeconds": max_age_seconds,
        }
    ]
    bucket.patch()
    # Log the actual policy; the previous message claimed "any origin" even
    # when a custom origin list was supplied.
    logging.info(
        f"CORS policy for {bucket_name} set to allow {method} requests from {origin}."
    )
64+
65+
66+
def list_blobs(bucket_name: str, prefix: str = "", suffix: str = "") -> List[storage.Blob]:
    """
    List all files in a GCP bucket with the given prefix and suffix.

    Args:
        bucket_name: name of the GCS bucket.
        prefix: server-side name prefix filter (passed to the API).
        suffix: client-side name suffix filter (e.g. ".zip"); "" matches all.

    Returns:
        Blobs whose names start with *prefix* and end with *suffix*.
    """
    storage_client = storage.Client()
    # Single pass: filter by suffix while iterating the API pages instead of
    # materializing an intermediate list first.
    return [
        blob
        for blob in storage_client.list_blobs(bucket_name, prefix=prefix)
        if blob.name.endswith(suffix)
    ]
73+
74+
4375
def download_url_content(url, with_retry=False):
4476
"""
4577
Downloads the content of a URL
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[run]
2+
omit =
3+
*/test*/*
4+
*/helpers/*
5+
*/database_gen/*
6+
7+
[report]
8+
exclude_lines =
9+
if __name__ == .__main__.:
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# todo

0 commit comments

Comments
 (0)