
Commit 800d355

feat: extract all stop locations from stops.txt (#921)
1 parent ed581df commit 800d355


49 files changed: +3884 −120 lines

api/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ extend-exclude = '''
     | build
     | dist
     | src/feeds_gen/*
-    | src/database_gen/*
+    | src/shared/database_gen/*
   )/
 )
 '''

api/src/feeds/impl/feeds_api_impl.py

Lines changed: 10 additions & 14 deletions
@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import List, Union, TypeVar
+from typing import List, Union, TypeVar, Optional
 
 from sqlalchemy import or_
 from sqlalchemy import select
@@ -44,7 +44,6 @@
 from utils.date_utils import valid_iso_date
 from utils.location_translation import (
     create_location_translation_object,
-    LocationTranslation,
     get_feeds_location_translations,
 )
 from utils.logger import Logger
@@ -130,42 +129,39 @@ def get_feeds(
     @with_db_session
     def get_gtfs_feed(self, id: str, db_session: Session) -> GtfsFeed:
         """Get the specified gtfs feed from the Mobility Database."""
-        feed, translations = self._get_gtfs_feed(id, db_session)
+        feed = self._get_gtfs_feed(id, db_session)
         if feed:
-            return GtfsFeedImpl.from_orm(feed, translations)
+            return GtfsFeedImpl.from_orm(feed)
         else:
             raise_http_error(404, gtfs_feed_not_found.format(id))
 
     @staticmethod
-    def _get_gtfs_feed(stable_id: str, db_session: Session) -> tuple[Gtfsfeed | None, dict[str, LocationTranslation]]:
+    def _get_gtfs_feed(stable_id: str, db_session: Session) -> Optional[Gtfsfeed]:
         results = (
             FeedFilter(
                 stable_id=stable_id,
                 status=None,
                 provider__ilike=None,
                 producer_url__ilike=None,
             )
-            .filter(db_session.query(Gtfsfeed, t_location_with_translations_en))
+            .filter(db_session.query(Gtfsfeed))
             .filter(
                 or_(
                     Gtfsfeed.operational_status == None,  # noqa: E711
                     Gtfsfeed.operational_status != "wip",
                     not is_user_email_restricted(),  # Allow all feeds to be returned if the user is not restricted
                 )
             )
-            .outerjoin(Location, Feed.locations)
-            .outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id)
             .options(
                 joinedload(Gtfsfeed.gtfsdatasets)
                 .joinedload(Gtfsdataset.validation_reports)
                 .joinedload(Validationreport.notices),
                 *get_joinedload_options(),
             )
         ).all()
-        if len(results) > 0 and results[0].Gtfsfeed:
-            translations = {result[1]: create_location_translation_object(result) for result in results}
-            return results[0].Gtfsfeed, translations
-        return None, {}
+        if len(results) == 0:
+            return None
+        return results[0]
 
     @with_db_session
     def get_gtfs_feed_datasets(
@@ -393,8 +389,8 @@ def _get_response(feed_query: Query, impl_cls: type[T], db_session: "Session") -
     @with_db_session
     def get_gtfs_feed_gtfs_rt_feeds(self, id: str, db_session: Session) -> List[GtfsRTFeed]:
         """Get a list of GTFS Realtime related to a GTFS feed."""
-        feed, translations = self._get_gtfs_feed(id, db_session)
+        feed = self._get_gtfs_feed(id, db_session)
         if feed:
-            return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed, translations) for gtfs_rt_feed in feed.gtfs_rt_feeds]
+            return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed) for gtfs_rt_feed in feed.gtfs_rt_feeds]
         else:
             raise_http_error(404, gtfs_feed_not_found.format(id))
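
With the location-translation join removed, _get_gtfs_feed now returns the feed itself or None. A minimal sketch of the new call contract, assuming a db_session obtained elsewhere (FeedsApiImpl is the presumed enclosing class and the stable id is a placeholder):

    feed = FeedsApiImpl._get_gtfs_feed("mdb-123", db_session)  # Optional[Gtfsfeed]
    if feed is None:
        raise_http_error(404, gtfs_feed_not_found.format("mdb-123"))
    return GtfsFeedImpl.from_orm(feed)  # no translations argument anymore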

functions-python/.flake8

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 [flake8]
 max-line-length = 120
-exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,venv,build,.*,database_gen,feeds_operations_gen
+exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,venv,build,.*,database_gen,feeds_operations_gen,shared
 # Ignored because conflict with black
 extend-ignore = E203

functions-python/extract_location/src/main.py

Lines changed: 5 additions & 11 deletions
@@ -7,10 +7,6 @@
 
 import functions_framework
 from cloudevents.http import CloudEvent
-from google.cloud import pubsub_v1
-from sqlalchemy import or_
-from sqlalchemy.orm import joinedload
-
 from shared.database_gen.sqlacodegen_models import Gtfsdataset
 from shared.dataset_service.main import (
     DatasetTraceService,
@@ -22,10 +18,14 @@
 from shared.helpers.database import Database
 from shared.helpers.logger import Logger
 from shared.helpers.parser import jsonify_pubsub
+from sqlalchemy import or_
+from sqlalchemy.orm import joinedload
+
 from bounding_box.bounding_box_extractor import (
     create_polygon_wkt_element,
     update_dataset_bounding_box,
 )
+from shared.helpers.pub_sub import publish_messages
 from reverse_geolocation.location_extractor import update_location, reverse_coords
 from stops_utils import get_gtfs_feed_bounds_and_points
 
@@ -232,11 +232,5 @@ def extract_location_batch(_):
         pass
 
     # Trigger update location for each dataset by publishing to Pub/Sub
-    publisher = pubsub_v1.PublisherClient()
-    topic_path = publisher.topic_path(os.getenv("PROJECT_ID"), pubsub_topic_name)
-    for data in datasets_data:
-        message_data = json.dumps(data).encode("utf-8")
-        future = publisher.publish(topic_path, message_data)
-        logging.info(f"Published message to Pub/Sub with ID: {future.result()}")
-
+    publish_messages(datasets_data, os.getenv("PROJECT_ID"), pubsub_topic_name)
     return f"Batch function triggered for {len(datasets_data)} datasets.", 200

functions-python/extract_location/tests/test_location_extraction.py

Lines changed: 3 additions & 30 deletions
@@ -267,11 +267,11 @@ def test_extract_location_exception_2(
         },
     )
     @patch("main.Database")
-    @patch("main.pubsub_v1.PublisherClient")
+    @patch("main.publish_messages")
     @patch("main.Logger")
     @patch("uuid.uuid4")
     def test_extract_location_batch(
-        self, uuid_mock, logger_mock, publisher_client_mock, database_mock
+        self, uuid_mock, logger_mock, publish_messages_mock, database_mock
     ):
         mock_session = MagicMock()
         mock_dataset1 = Gtfsdataset(
@@ -302,37 +302,10 @@ def test_extract_location_batch(
             mock_session
         )
 
-        mock_publisher = MagicMock()
-        publisher_client_mock.return_value = mock_publisher
-        mock_future = MagicMock()
-        mock_future.result.return_value = "message_id"
-        mock_publisher.publish.return_value = mock_future
-
         response = extract_location_batch(None)
 
         logger_mock.init_logger.assert_called_once()
-        mock_publisher.publish.assert_any_call(
-            mock.ANY,
-            json.dumps(
-                {
-                    "stable_id": "1",
-                    "dataset_id": "stable_1",
-                    "url": "http://example.com/1",
-                    "execution_id": "batch-uuid",
-                }
-            ).encode("utf-8"),
-        )
-        mock_publisher.publish.assert_any_call(
-            mock.ANY,
-            json.dumps(
-                {
-                    "stable_id": "2",
-                    "dataset_id": "stable_2",
-                    "url": "http://example.com/2",
-                    "execution_id": "batch-uuid",
-                }
-            ).encode("utf-8"),
-        )
+        publish_messages_mock.assert_called_once()
         self.assertEqual(response, ("Batch function triggered for 2 datasets.", 200))
 
     @mock.patch.dict(

functions-python/helpers/database.py

Lines changed: 15 additions & 7 deletions
@@ -18,6 +18,7 @@
 import logging
 import os
 import threading
+from functools import wraps
 from typing import Optional, ContextManager
 
 from sqlalchemy import create_engine, text, event, Engine
@@ -85,7 +86,7 @@ def mapper_configure_listener(mapper, class_):
     event.listen(mapper, "mapper_configured", mapper_configure_listener)
 
 
-def with_db_session(func):
+def with_db_session(_func=None, *, echo=True):
     """
     Decorator to handle the session management for the decorated function.
 
@@ -103,16 +104,23 @@ def with_db_session(func):
     - If 'db_session' is already provided, it simply calls the decorated function with the existing session.
     """
 
-    def wrapper(*args, **kwargs):
-        db_session = kwargs.get("db_session")
-        if db_session is None:
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if "db_session" in kwargs:
+                return func(*args, **kwargs)
+
             db = Database()
-            with db.start_db_session() as session:
+            with db.start_db_session(echo=echo) as session:
                 kwargs["db_session"] = session
                 return func(*args, **kwargs)
-        return func(*args, **kwargs)
 
-    return wrapper
+        return wrapper
+
+    if _func is None:
+        return decorator
+    else:
+        return decorator(_func)
 
 
 class Database:
functions-python/helpers/logger.py

Lines changed: 3 additions & 0 deletions
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 
 import google.cloud.logging
 from google.cloud.logging_v2 import Client
@@ -47,6 +48,8 @@ def init_logger() -> Client:
     """
     Initializes the logger
     """
+    if os.getenv("DEBUG", "False") == "True":
+        return None
     client = google.cloud.logging.Client()
     client.get_default_handler()
     client.setup_logging()
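
The DEBUG guard gives local runs an escape hatch from Cloud Logging. A sketch of the intended effect, assuming init_logger is invoked as defined in the diff and callers tolerate the None return:

    import os

    os.environ["DEBUG"] = "True"  # e.g. exported in a local shell or .env
    client = init_logger()        # returns None before any google.cloud.logging.Client is built
    assert client is None         # so no GCP credentials are needed locally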

functions-python/helpers/pub_sub.py

Lines changed: 15 additions & 0 deletions
@@ -13,7 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import json
+import logging
 import uuid
+from typing import Dict, List
 
 from google.cloud import pubsub_v1
 from google.cloud.pubsub_v1 import PublisherClient
@@ -43,3 +46,15 @@ def get_execution_id(request, prefix: str) -> str:
     trace_id = request.headers.get("X-Cloud-Trace-Context")
     execution_id = f"{prefix}-{trace_id}" if trace_id else f"{prefix}-{uuid.uuid4()}"
     return execution_id
+
+
+def publish_messages(data: List[Dict], project_id, topic_name) -> None:
+    """
+    Publishes the given data to the Pub/Sub topic.
+    """
+    publisher = get_pubsub_client()
+    topic_path = publisher.topic_path(project_id, topic_name)
+    for element in data:
+        message_data = json.dumps(element).encode("utf-8")
+        future = publish(publisher, topic_path, message_data)
+        logging.info(f"Published message to Pub/Sub with ID: {future.result()}")

functions-python/helpers/requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -25,5 +25,5 @@ google-api-core
 google-cloud-firestore
 google-cloud-bigquery
 
-#Additional package
-pycountry
+# Additional package
+pycountry
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+[run]
+omit =
+    */test*/*
+    */helpers/*
+    */database_gen/*
+
+[report]
+exclude_lines =
+    if __name__ == .__main__.:
