Skip to content

Commit d125ddd

Browse files
committed
Fix GTFS-RT update endpoint when static reference is changed
1 parent a28edc7 commit d125ddd

File tree

9 files changed

+127
-17
lines changed

9 files changed

+127
-17
lines changed

functions-python/helpers/database.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,40 @@
1919
from typing import Final
2020

2121
from sqlalchemy import create_engine, text, event
22-
from sqlalchemy.orm import sessionmaker, mapper
22+
from sqlalchemy.orm import sessionmaker, mapper, class_mapper
2323
import logging
2424

25+
from database_gen.sqlacodegen_models import Feed, Gtfsfeed, Gtfsrealtimefeed, Gbfsfeed
26+
2527
DB_REUSE_SESSION: Final[str] = "DB_REUSE_SESSION"
2628
lock = threading.Lock()
2729
global_session = None
2830

2931

32+
def configure_polymorphic_mappers():
33+
"""
34+
Configure the polymorphic mappers allowing polymorphic values on relationships.
35+
"""
36+
feed_mapper = class_mapper(Feed)
37+
# Configure the polymorphic mapper using date_type as discriminator for the Feed class
38+
feed_mapper.polymorphic_on = Feed.data_type
39+
feed_mapper.polymorphic_identity = Feed.__tablename__.lower()
40+
41+
gtfsfeed_mapper = class_mapper(Gtfsfeed)
42+
gtfsfeed_mapper.inherits = feed_mapper
43+
gtfsfeed_mapper.polymorphic_identity = Gtfsfeed.__tablename__.lower()
44+
45+
gtfsrealtimefeed_mapper = class_mapper(Gtfsrealtimefeed)
46+
gtfsrealtimefeed_mapper.inherits = feed_mapper
47+
gtfsrealtimefeed_mapper.polymorphic_identity = (
48+
Gtfsrealtimefeed.__tablename__.lower()
49+
)
50+
51+
gbfsfeed_mapper = class_mapper(Gbfsfeed)
52+
gbfsfeed_mapper.inherits = feed_mapper
53+
gbfsfeed_mapper.polymorphic_identity = Gbfsfeed.__tablename__.lower()
54+
55+
3056
def set_cascade(mapper, class_):
3157
"""
3258
Set cascade for relationships in Gtfsfeed.
@@ -43,7 +69,16 @@ def set_cascade(mapper, class_):
4369
rel.cascade = "all, delete-orphan"
4470

4571

46-
event.listen(mapper, "mapper_configured", set_cascade)
72+
def mapper_configure_listener(mapper, class_):
73+
"""
74+
Mapper configure listener
75+
"""
76+
set_cascade(mapper, class_)
77+
configure_polymorphic_mappers()
78+
79+
80+
# Add the mapper_configure_listener to the mapper_configured event
81+
event.listen(mapper, "mapper_configured", mapper_configure_listener)
4782

4883

4984
def get_db_engine(database_url: str = None, echo: bool = True):

functions-python/operations_api/src/feeds_operations/impl/models/entity_type_impl.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,17 @@ def from_orm(cls, obj: EntityTypeOrm | None) -> EntityType | None:
2626
return EntityType(obj.name.lower())
2727

2828
@classmethod
29-
def to_orm(cls, entity_type: EntityType) -> EntityTypeOrm:
29+
def to_orm(cls, entity_type: EntityType, session) -> EntityTypeOrm:
3030
"""
3131
Convert a Pydantic model to a SQLAlchemy row object.
3232
"""
33-
return EntityTypeOrm(name=entity_type.name.upper())
33+
result = (
34+
session.query(EntityTypeOrm)
35+
.filter(EntityTypeOrm.name == entity_type.name)
36+
.first()
37+
)
38+
return (
39+
result
40+
if result is not None
41+
else EntityTypeOrm(name=entity_type.name.upper())
42+
)

functions-python/operations_api/src/feeds_operations/impl/models/update_request_gtfs_rt_feed_impl.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,10 @@ def to_orm(
148148
entity.entitytypes = (
149149
[]
150150
if update_request.entity_types is None
151-
else [EntityTypeImpl.to_orm(item) for item in update_request.entity_types]
151+
else [
152+
EntityTypeImpl.to_orm(item, session)
153+
for item in update_request.entity_types
154+
]
152155
)
153156
entity.gtfs_feeds = (
154157
[]
Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +0,0 @@
1-
import sys
2-
3-
sys.path.append("..")
4-
5-
import os
6-
7-
print(os.getcwd())

functions-python/operations_api/tests/conftest.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,24 @@
4747
status="active",
4848
feed_contact_email="feed_contact_email",
4949
provider="provider",
50-
# gtfs_rt_feeds=[feed_mdb_41],
50+
gtfs_rt_feeds=[feed_mdb_41],
51+
)
52+
53+
feed_mdb_400 = Gtfsfeed(
54+
id="mdb-400",
55+
data_type="gtfs",
56+
feed_name="London Transit Commission",
57+
note="note",
58+
producer_url="producer_url",
59+
authentication_type="1",
60+
authentication_info_url="authentication_info_url",
61+
api_key_parameter_name="api_key_parameter_name",
62+
license_url="license_url",
63+
stable_id="mdb-400",
64+
status="active",
65+
feed_contact_email="feed_contact_email",
66+
provider="provider",
67+
gtfs_rt_feeds=[],
5168
)
5269

5370

@@ -62,6 +79,7 @@ def populate_database():
6279
session.add(feed_mdb_41)
6380
# session.flush()
6481
session.add(feed_mdb_40)
82+
session.add(feed_mdb_400)
6583
session.commit()
6684

6785

functions-python/operations_api/tests/feeds_operations/impl/models/test_entity_type_impl.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from unittest.mock import Mock
2+
13
from database_gen.sqlacodegen_models import Entitytype
24
from feeds_operations.impl.models.entity_type_impl import EntityTypeImpl
35
from feeds_operations_gen.models.entity_type import EntityType
@@ -16,5 +18,10 @@ def test_from_orm_none():
1618

1719
def test_to_orm():
1820
entity_type = EntityType("vp")
19-
result = EntityTypeImpl.to_orm(entity_type)
20-
assert result.name == "VP"
21+
session = Mock()
22+
mock_query = Mock()
23+
resulting_entity = Mock()
24+
mock_query.filter.return_value.first.return_value = resulting_entity
25+
session.query.return_value = mock_query
26+
result = EntityTypeImpl.to_orm(entity_type, session)
27+
assert result == resulting_entity

functions-python/operations_api/tests/feeds_operations/impl/models/test_update_request_gtfs_rt_feed_impl.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Redirectingid,
55
Externalid,
66
Gtfsrealtimefeed,
7+
Entitytype,
78
)
89
from feeds_operations.impl.models.update_request_gtfs_rt_feed_impl import (
910
UpdateRequestGtfsRtFeedImpl,
@@ -85,10 +86,16 @@ def test_to_orm():
8586
)
8687
entity = Gtfsrealtimefeed(id="1", stable_id="stable_id", data_type="gtfs")
8788
target_feed = Gtfsfeed(id=2, stable_id="target_stable_id")
89+
resulting_entity = Entitytype(name="VP")
90+
8891
session = MagicMock()
89-
session.query.return_value.filter.return_value.first.return_value = target_feed
92+
session.query.return_value.filter.return_value.first.side_effect = [
93+
target_feed,
94+
resulting_entity,
95+
]
9096

9197
result = UpdateRequestGtfsRtFeedImpl.to_orm(update_request, entity, session)
98+
assert result is not None
9299
assert result.status == "active"
93100
assert result.provider == "provider"
94101
assert result.feed_name == "feed_name"
@@ -121,8 +128,12 @@ def test_to_orm_invalid_source_info():
121128
)
122129
entity = Gtfsrealtimefeed(id="1", stable_id="stable_id", data_type="gtfs")
123130
target_feed = Gtfsfeed(id=2, stable_id="target_stable_id")
131+
124132
session = MagicMock()
125-
session.query.return_value.filter.return_value.first.return_value = target_feed
133+
session.query.return_value.filter.return_value.first.side_effect = [
134+
target_feed,
135+
None,
136+
]
126137

127138
result = UpdateRequestGtfsRtFeedImpl.to_orm(update_request, entity, session)
128139

functions-python/operations_api/tests/feeds_operations/impl/test_feeds_operations_impl_gtfs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from database_gen.sqlacodegen_models import Gtfsfeed
1010
from feeds_operations.impl.feeds_operations_impl import OperationsApiImpl
1111
from feeds_operations_gen.models.authentication_type import AuthenticationType
12+
from feeds_operations_gen.models.external_id import ExternalId
1213
from feeds_operations_gen.models.feed_status import FeedStatus
1314
from feeds_operations_gen.models.source_info import SourceInfo
1415
from feeds_operations_gen.models.update_request_gtfs_feed import UpdateRequestGtfsFeed
@@ -64,6 +65,12 @@ async def test_update_gtfs_feed_no_changes(_, update_request_gtfs_feed):
6465
@pytest.mark.asyncio
6566
async def test_update_gtfs_feed_field_change(_, update_request_gtfs_feed):
6667
update_request_gtfs_feed.feed_name = "New feed name"
68+
update_request_gtfs_feed.external_ids = [
69+
ExternalId(
70+
external_id="new_external_id",
71+
source="new_source",
72+
)
73+
]
6774
with get_testing_session() as session:
6875
api = OperationsApiImpl()
6976
response: Response = await api.update_gtfs_feed(update_request_gtfs_feed)

functions-python/operations_api/tests/feeds_operations/impl/test_feeds_operations_impl_gtfs_rt.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,30 @@ async def test_update_gtfs_feed_field_change(_, update_request_gtfs_rt_feed):
6464
.one()
6565
)
6666
assert db_feed.feed_name == "New feed name"
67+
68+
69+
@patch("helpers.logger.Logger")
70+
@mock.patch.dict(
71+
os.environ,
72+
{
73+
"FEEDS_DATABASE_URL": default_db_url,
74+
},
75+
)
76+
@pytest.mark.asyncio
77+
async def test_update_gtfs_feed_static_change(_, update_request_gtfs_rt_feed):
78+
update_request_gtfs_rt_feed.feed_references = ["mdb-400"]
79+
with get_testing_session() as session:
80+
api = OperationsApiImpl()
81+
response: Response = await api.update_gtfs_rt_feed(update_request_gtfs_rt_feed)
82+
assert response.status_code == 200
83+
84+
db_feed = (
85+
session.query(Gtfsrealtimefeed)
86+
.filter(Gtfsrealtimefeed.stable_id == feed_mdb_41.stable_id)
87+
.one()
88+
)
89+
assert len(db_feed.gtfs_feeds) == 1
90+
feed = next(
91+
(feed for feed in db_feed.gtfs_feeds if feed.stable_id == "mdb-400"), None
92+
)
93+
assert feed is not None, "Feed with stable ID 'mdb-400' not found"

0 commit comments

Comments
 (0)