Skip to content

Commit 1ed7081

Browse files
committed
added local tests
1 parent 67f7715 commit 1ed7081

File tree

5 files changed

+32
-112
lines changed

5 files changed

+32
-112
lines changed

api/src/scripts/populate_db_gtfs.py

Lines changed: 7 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -118,95 +118,32 @@ def process_entity_types(self, session: "Session", feed: Gtfsrealtimefeed, row,
118118
self.logger.warning(f"Entity types array is empty for feed {stable_id}")
119119
feed.entitytypes.clear()
120120

121-
# def process_feed_references(self, session: "Session"):
122-
# """
123-
# Process the feed references
124-
# """
125-
# self.logger.info("Processing feed references")
126-
# for index, row in self.df.iterrows():
127-
# stable_id = self.get_stable_id(row)
128-
# data_type = self.get_data_type(row)
129-
# if data_type != "gtfs_rt":
130-
# continue
131-
# gtfs_rt_feed = self.query_feed_by_stable_id(session, stable_id, "gtfs_rt")
132-
# static_reference = self.get_safe_value(row, "static_reference", "")
133-
# if static_reference:
134-
# try:
135-
# gtfs_stable_id = f"mdb-{int(float(static_reference))}"
136-
# except ValueError:
137-
# gtfs_stable_id = static_reference
138-
# gtfs_feed = self.query_feed_by_stable_id(session, gtfs_stable_id, "gtfs")
139-
# if not gtfs_feed:
140-
# self.logger.warning(f"Could not find static reference feed {gtfs_stable_id} for feed {stable_id}")
141-
# continue
142-
# already_referenced_ids = {ref.id for ref in gtfs_feed.gtfs_rt_feeds}
143-
# if gtfs_feed and gtfs_rt_feed.id not in already_referenced_ids:
144-
# gtfs_feed.gtfs_rt_feeds.append(gtfs_rt_feed)
145-
# # Flush to avoid FK violation
146-
# session.flush()
147-
148121
def process_feed_references(self, session: "Session"):
149122
"""
150-
Process the feed references for GTFS-RT feeds.
151-
152-
1. Uses 'static_reference' column if present.
153-
2. Falls back to matching static feeds by provider name.
123+
Process the feed references
154124
"""
155125
self.logger.info("Processing feed references")
156-
157126
for index, row in self.df.iterrows():
158127
stable_id = self.get_stable_id(row)
159128
data_type = self.get_data_type(row)
160-
161-
# Only process GTFS-RT feeds
162129
if data_type != "gtfs_rt":
163130
continue
164-
165131
gtfs_rt_feed = self.query_feed_by_stable_id(session, stable_id, "gtfs_rt")
166-
if not gtfs_rt_feed:
167-
self.logger.warning(f"Could not find GTFS-RT feed {stable_id}")
168-
continue
169-
170-
# Try static_reference column first
171-
static_reference = self.get_safe_value(row, "static_reference", "").strip()
172-
gtfs_feed = None
173-
132+
static_reference = self.get_safe_value(row, "static_reference", "")
174133
if static_reference:
175-
# Normalize stable_id
176134
try:
177135
gtfs_stable_id = f"mdb-{int(float(static_reference))}"
178136
except ValueError:
179137
gtfs_stable_id = static_reference
180-
181138
gtfs_feed = self.query_feed_by_stable_id(session, gtfs_stable_id, "gtfs")
182139
if not gtfs_feed:
183140
self.logger.warning(f"Could not find static reference feed {gtfs_stable_id} for feed {stable_id}")
184-
185-
# Fallback: match by provider if no static_reference or not found
186-
if not gtfs_feed:
187-
provider_value = (self.get_safe_value(row, "provider", "") or "").strip().lower()
188-
if provider_value:
189-
gtfs_feed = (
190-
session.query(Gtfsfeed)
191-
.filter(
192-
Gtfsfeed.data_type == "gtfs",
193-
func.lower(func.trim(Gtfsfeed.provider)) == provider_value,
194-
Gtfsfeed.stable_id != stable_id,
195-
)
196-
.first()
197-
)
198-
if not gtfs_feed:
199-
self.logger.warning(
200-
f"No static GTFS feed found for provider '{provider_value}' for feed {stable_id}"
201-
)
202-
203-
# Link the feeds if we have a valid static GTFS feed
204-
if gtfs_feed:
141+
continue
205142
already_referenced_ids = {ref.id for ref in gtfs_feed.gtfs_rt_feeds}
206-
if gtfs_rt_feed.id not in already_referenced_ids:
207-
gtfs_feed.gtfs_rt_feeds.append(gtfs_rt_feed)
208-
session.flush() # Avoid FK violations
209-
self.logger.info(f"Linked GTFS-RT feed {stable_id} to static feed {gtfs_feed.stable_id}")
143+
if gtfs_feed and gtfs_rt_feed.id not in already_referenced_ids:
144+
gtfs_rt_feed.gtfs_feeds = [gtfs_feed]
145+
#gtfs_feed.gtfs_rt_feeds.append(gtfs_rt_feed) # Flush to avoid FK violation
146+
session.flush()
210147

211148
def process_redirects(self, session: "Session"):
212149
"""

api/tests/integration/populate_tests/README.md

Whitespace-only changes.
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
mdb_source_id,data_type,entity_type,location.country_code,location.subdivision_name,location.municipality,provider,is_official,name,note,feed_contact_email,static_reference,urls.direct_download,urls.authentication_type,urls.authentication_info,urls.api_key_parameter_name,urls.latest,urls.license,location.bounding_box.minimum_latitude,location.bounding_box.maximum_latitude,location.bounding_box.minimum_longitude,location.bounding_box.maximum_longitude,location.bounding_box.extracted_on,status,features,redirect.id,redirect.comment
22
40,gtfs,,CA,Ontario,London,London Transit Commission,,,,[email protected],,http://www.londontransit.ca/gtfsfeed/google_transit.zip,0,,,https://storage.googleapis.com/storage/v1/b/mdb-latest/o/ca-ontario-london-transit-commission-gtfs-2.zip?alt=media,https://www.londontransit.ca/open-data/ltcs-open-data-terms-of-use/,42.905244,43.051188,-81.36311,-81.137591,2022-02-22T19:51:34+00:00,inactive,,,
33
50,gtfs,,CA,Ontario,Barrie,ZBarrie Transit,,,,,,http://www.myridebarrie.ca/gtfs/Google_transit.zip,,,,https://storage.googleapis.com/storage/v1/b/mdb-latest/o/ca-ontario-barrie-transit-gtfs-3.zip?alt=media,https://www.barrie.ca/services-payments/transportation-parking/barrie-transit/barrie-gtfs,44.3218044,44.42020676,-79.74063237,-79.61089569,2022-03-01T22:43:25+00:00,deprecated,,40|mdb-702,Some|Comment
4-
1562,gtfs-rt,sa,CA,BC,Vancouver,Vancouver-Transit(éèàçíóúČ),TRUE,Realtime(ŘŤÜÎ),,,40,http://foo.org/google_transit.zip,0,,,,,,,,,,active,,10,
4+
1562,gtfs-rt,sa,CA,BC,Vancouver,Vancouver-Transit(éèàçíóúČ),TRUE,Realtime(ŘŤÜÎ),,,50,http://foo.org/google_transit.zip,0,,,,,,,,,,active,,10,
55
1563,gtfs-rt,tu,US,SomeState,SomeCity,SomeCity Bus,FALSE,RT,,,mdb-50,http://bar.com,0,,,,,,,,,,inactive,,10,

api/tests/integration/populate_tests/test_populate.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,27 @@ def test_is_official_overwrite(client: TestClient, values):
5353
assert response.status_code == 200
5454
json_response = response.json()
5555
assert json_response["official"] is expected_official, values["assert_fail_message"]
56+
57+
def test_is_feed_reference_overwrite(client: TestClient):
58+
feed_id = "mdb-1562"
59+
response = client.request(
60+
"GET",
61+
"/v1/gtfs_rt_feeds/{id}".format(id=feed_id),
62+
headers=authHeaders,
63+
)
64+
json_response = response.json()
65+
assert json_response["feed_references"] == ["mdb-50"]
66+
67+
feed_id = "mdb-1563"
68+
response = client.request(
69+
"GET",
70+
"/v1/gtfs_rt_feeds/{id}".format(id=feed_id),
71+
headers=authHeaders,
72+
)
73+
json_response = response.json()
74+
assert json_response["feed_references"] == ["mdb-50"]
75+
76+
77+
78+
79+

api/tests/unittest/models/test_gtfs_rt_feed_impl.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -105,47 +105,6 @@ def test_from_orm_all_fields(self):
105105
result = GtfsRTFeedImpl.from_orm(gtfs_rt_feed_orm, db_session=mock_session)
106106
assert result == expected_gtfs_rt_feed_result
107107

108-
# def test_from_orm_feed_references_location_filter(self):
109-
# """
110-
# Test that feed_references are correctly filtered based on shared locations.
111-
# """
112-
# # Define locations
113-
# location_de = Location(id="loc_de", country_code="US", subdivision_name="Delaware")
114-
# location_md = Location(id="loc_md", country_code="US", subdivision_name="Maryland")
115-
# location_ia = Location(id="loc_ia", country_code="US", subdivision_name="Iowa")
116-
117-
# # Define the GTFS-RT feed (e.g., mdb-1771)
118-
# rt_feed = Gtfsrealtimefeed(
119-
# stable_id="mdb-1771",
120-
# provider="DART",
121-
# locations=[location_de, location_md],
122-
# entitytypes=[],
123-
# )
124-
125-
# # Define a correct related schedule feed (e.g., mdb-1235)
126-
# correct_schedule_feed = Gtfsfeed(stable_id="mdb-1235", provider="DART", locations=[location_de, location_md])
127-
128-
# # Define an incorrect schedule feed with a different location (e.g., mdb-193)
129-
# incorrect_schedule_feed = Gtfsfeed(stable_id="mdb-193", provider="DART", locations=[location_ia])
130-
131-
# # Mock the database session and its query
132-
# mock_session = MagicMock()
133-
# mock_query = MagicMock()
134-
# mock_session.query.return_value = mock_query
135-
# # The query inside from_orm should return both schedule feeds before filtering
136-
# mock_query.filter.return_value.options.return_value.all.return_value = [
137-
# correct_schedule_feed,
138-
# incorrect_schedule_feed,
139-
# ]
140-
141-
# # Execute the method
142-
# result = GtfsRTFeedImpl.from_orm(rt_feed, db_session=mock_session)
143-
144-
# # Assert that only the correct feed reference is included
145-
# self.assertIn("mdb-1235", result.feed_references)
146-
# self.assertNotIn("mdb-193", result.feed_references)
147-
# self.assertEqual(len(result.feed_references), 1)
148-
149108
def test_from_orm_empty_fields(self):
150109
"""Test the `from_orm` method with not provided fields."""
151110
# Test with empty fields and None values

0 commit comments

Comments
 (0)