
Commit 12b6511

feat: update update_feed_status to return status diff counts (#992)
1 parent 6168402 commit 12b6511

File tree

5 files changed: +178 additions, −50 deletions

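Before this change, update_feed_status returned a summary string on success; with this commit it returns the per-status counts of feeds whose status actually changed, as the main.py diff below shows. A minimal sketch of the two response shapes (the counts are illustrative, not from a real run):

# Illustrative only: shape of the handler's return value before and after this commit.
old_response = ("Script executed successfully. 5 feeds updated", 200)
new_response = ({"inactive": 3, "active": 2, "future": 1}, 200)  # status -> number of feeds that changed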

functions-python/backfill_dataset_service_date_range/tests/test_backfill_dataset_service_date_range_main.py

Lines changed: 2 additions & 2 deletions
@@ -420,7 +420,7 @@ def test_backfill_dataset_service_date_range(mock_backfill_datasets, mock_logger
     with patch.dict(os.environ, {"FEEDS_DATABASE_URL": default_db_url}):
         response_body, status_code = backfill_dataset_service_date_range(None)

-    mock_backfill_datasets.asser_called_once()
+    mock_backfill_datasets.assert_called_once()
     assert response_body == "Script executed successfully. 5 datasets updated"
     assert status_code == 200

@@ -435,7 +435,7 @@ def test_backfill_dataset_service_date_range_error_raised(
     with patch.dict(os.environ, {"FEEDS_DATABASE_URL": default_db_url}):
         response_body, status_code = backfill_dataset_service_date_range(None)

-    mock_backfill_datasets.asser_called_once()
+    mock_backfill_datasets.assert_called_once()
     assert (
         response_body
         == "Error setting the datasets service date range values: Mocked exception"

functions-python/test_utils/database_utils.py

Lines changed: 2 additions & 2 deletions
@@ -50,10 +50,10 @@ def get_testing_engine() -> Engine:
     return db._get_engine(echo=False)


-def get_testing_session() -> Session:
+def get_testing_session(echo: bool = False) -> Session:
     """Returns a SQLAlchemy session for the test db."""
     db = Database(database_url=default_db_url)
-    return db._get_session(echo=False)()
+    return db._get_session(echo=echo)()


def clean_testing_db():
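A hedged usage sketch of the new echo parameter, assuming the helper is imported the same way the tests in this commit import it (the flag is presumably forwarded to SQLAlchemy's engine echo setting by the Database helper):

from test_shared.test_utils.database_utils import get_testing_session

session = get_testing_session()                    # default: no SQL echoed, same behaviour as before
verbose_session = get_testing_session(echo=True)   # opt in to SQL logging while debugging a test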

functions-python/update_feed_status/src/main.py

Lines changed: 25 additions & 19 deletions
@@ -5,7 +5,7 @@
 from shared.helpers.logger import Logger
 from shared.helpers.database import Database
 from typing import TYPE_CHECKING
-from sqlalchemy import case, text
+from sqlalchemy import text
 from shared.database_gen.sqlacodegen_models import Gtfsdataset, Feed, t_feedsearch
 from shared.helpers.database import refresh_materialized_view

@@ -33,32 +33,40 @@ def update_feed_statuses_query(session: "Session"):
         .subquery()
     )

-    new_status = case(
+    status_conditions = [
         (
             latest_dataset_subq.c.service_date_range_end < today_utc,
-            text("'inactive'::status"),
+            "inactive",
         ),
         (
             latest_dataset_subq.c.service_date_range_start > today_utc,
-            text("'future'::status"),
+            "future",
         ),
         (
             (latest_dataset_subq.c.service_date_range_start <= today_utc)
             & (latest_dataset_subq.c.service_date_range_end >= today_utc),
-            text("'active'::status"),
+            "active",
         ),
-    )
+    ]

     try:
-        updated_count = (
-            session.query(Feed)
-            .filter(
-                Feed.status != text("'deprecated'::status"),
-                Feed.status != text("'development'::status"),
-                Feed.id == latest_dataset_subq.c.feed_id,
+        diff_counts: dict[str, int] = {}
+
+        for service_date_conditions, status in status_conditions:
+            diff_counts[status] = (
+                session.query(Feed)
+                .filter(
+                    Feed.id == latest_dataset_subq.c.feed_id,
+                    Feed.status != text("'deprecated'::status"),
+                    Feed.status != text("'development'::status"),
+                    # We filter out feeds that already have the status so that the
+                    # update count reflects the number of feeds that actually
+                    # changed status.
+                    Feed.status != text("'%s'::status" % status),
+                    service_date_conditions,
+                )
+                .update({Feed.status: status}, synchronize_session=False)
             )
-            .update({Feed.status: new_status}, synchronize_session=False)
-        )
     except Exception as e:
         logging.error(f"Error updating feed statuses: {e}")
         raise Exception(f"Error updating feed statuses: {e}")
@@ -68,7 +76,7 @@ def update_feed_statuses_query(session: "Session"):
         refresh_materialized_view(session, t_feedsearch.name)
         logging.info("Feed Database changes committed.")
         session.close()
-        return updated_count
+        return diff_counts
     except Exception as e:
         logging.error("Error committing changes:", e)
         session.rollback()
@@ -81,14 +89,12 @@ def update_feed_status(_):
     """Updates the Feed status based on the latets dataset service date range."""
     Logger.init_logger()
     db = Database(database_url=os.getenv("FEEDS_DATABASE_URL"))
-    update_count = 0
     try:
         with db.start_db_session() as session:
             logging.info("Database session started.")
-            update_count = update_feed_statuses_query(session)
+            diff_counts = update_feed_statuses_query(session)
+            return diff_counts, 200

     except Exception as error:
         logging.error(f"Error updating the feed statuses: {error}")
         return f"Error updating the feed statuses: {error}", 500
-
-    return f"Script executed successfully. {update_count} feeds updated", 200
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+#
+# MobilityData 2025
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime, timedelta
+from uuid import uuid4
+from shared.database_gen.sqlacodegen_models import (
+    Feed,
+    Gtfsdataset,
+)
+from test_shared.test_utils.database_utils import (
+    clean_testing_db,
+    get_testing_session,
+)
+
+future_date = datetime.now() + timedelta(days=15)
+past_date = datetime.now() - timedelta(days=15)
+
+
+def make_dataset(
+    feed_id: str, latest: bool, start: datetime, end: datetime
+) -> Gtfsdataset:
+    return Gtfsdataset(
+        id=str(uuid4()),
+        feed_id=feed_id,
+        latest=latest,
+        service_date_range_start=start,
+        service_date_range_end=end,
+    )
+
+
+def populate_database():
+    session = get_testing_session()
+
+    id_range_by_status = {
+        "inactive": (0, 6),
+        "active": (7, 10),
+        "deprecated": (11, 15),
+        "development": (16, 17),
+        "future": (18, 29),
+    }
+    for status, (a, b) in id_range_by_status.items():
+        for _id in map(str, range(a, b + 1)):
+            session.add(Feed(id=str(_id), status=status))
+
+    # -> inactive
+    for _id in [
+        "0",  # already inactive
+        "7",
+        "8",
+        "22",
+    ]:
+        session.add(make_dataset(_id, True, past_date, past_date))
+
+    # -> active
+    for _id in [
+        "2",
+        "9",  # already active
+        "12",  # deprecated
+        "16",  # development
+        "25",
+    ]:
+        session.add(make_dataset(_id, True, past_date, future_date))
+
+    # -> future
+    for _id in [
+        "10",
+    ]:
+        session.add(make_dataset(_id, False, past_date, future_date))
+        session.add(make_dataset(_id, True, future_date, future_date))
+
+    session.commit()
+
+
+def pytest_sessionstart(session):
+    clean_testing_db()
+    populate_database()
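Working the arithmetic out from the seed data above: feeds 7, 8 and 22 get an expired latest dataset and are not already inactive (feed 0 is), feeds 2 and 25 get a currently valid dataset (feed 9 is already active, feeds 12 and 16 are deprecated/development and therefore filtered out), and feed 10's latest dataset starts in the future. That is what the integration test below asserts; a quick cross-check:

expected_status_changes = {
    "7": "inactive", "8": "inactive", "22": "inactive",  # expired latest dataset
    "2": "active", "25": "active",                       # currently valid latest dataset
    "10": "future",                                      # latest dataset starts in the future
}
expected_diff_counts = {"inactive": 3, "active": 2, "future": 1}
counts = {}
for status in expected_status_changes.values():
    counts[status] = counts.get(status, 0) + 1
assert counts == expected_diff_counts
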
Lines changed: 60 additions & 27 deletions
@@ -1,31 +1,63 @@
 from unittest.mock import patch, MagicMock
-from test_shared.test_utils.database_utils import default_db_url
-from main import update_feed_status, update_feed_statuses_query
+from test_shared.test_utils.database_utils import default_db_url, get_testing_session
+from main import (
+    update_feed_status,
+    update_feed_statuses_query,
+)
+from shared.database_gen.sqlacodegen_models import Feed
 from datetime import date, timedelta
+from typing import Iterator, NamedTuple

 import os

-
-def test_update_feed_status_return():
-    mock_session = MagicMock()
-
-    today = date(2025, 3, 1)
-
-    mock_subquery = MagicMock()
-    mock_subquery.c.feed_id = 1
-    mock_subquery.c.service_date_range_start = today - timedelta(days=10)
-    mock_subquery.c.service_date_range_end = today + timedelta(days=10)
-
-    mock_query = mock_session.query.return_value
-    mock_query.filter.return_value.subquery.return_value = mock_subquery
-
-    mock_update_query = mock_session.query.return_value.filter.return_value
-    mock_update_query.update.return_value = 3
-
-    updated_count = update_feed_statuses_query(mock_session)
-
-    assert updated_count == 3
-    mock_session.commit.assert_called_once()
+from sqlalchemy import text
+from sqlalchemy.orm import Session
+
+
+class PartialFeed(NamedTuple):
+    """
+    Subset of the Feed entity with only the fields queried in `fetch_feeds`.
+    """
+
+    id: str
+    status: str
+
+
+def fetch_feeds(session: Session) -> Iterator[PartialFeed]:
+    # When adding or removing fields here, `PartialFeed` should be updated to
+    # match, for type safety.
+    query = session.query(Feed.id, Feed.status).filter(
+        Feed.status != text("'deprecated'::status"),
+        Feed.status != text("'development'::status"),
+    )
+    for feed in query:
+        yield PartialFeed(id=feed.id, status=feed.status)
+
+
+def test_update_feed_status():
+    session = get_testing_session()
+    feeds_before: dict[str, PartialFeed] = {f.id: f for f in fetch_feeds(session)}
+    result = dict(update_feed_statuses_query(session))
+    assert result == {
+        "inactive": 3,
+        "active": 2,
+        "future": 1,
+    }
+
+    feeds_after: dict[str, PartialFeed] = {f.id: f for f in fetch_feeds(session)}
+    expected_status_changes = {
+        "2": "active",
+        "7": "inactive",
+        "8": "inactive",
+        "10": "future",
+        "22": "inactive",
+        "25": "active",
+    }
+    for feed_id, feed_before in feeds_before.items():
+        feed_after = feeds_after[feed_id]
+        assert feed_after.status == expected_status_changes.get(
+            feed_id, feed_before.status
+        )


 def test_update_feed_status_failed_query():
@@ -53,13 +85,14 @@ def test_update_feed_status_failed_query():
 @patch("main.Logger", autospec=True)
 @patch("main.update_feed_statuses_query")
 def test_updated_feed_status(mock_update_query, mock_logger):
-    mock_update_query.return_value = 5
+    return_value = {"active": 5}
+    mock_update_query.return_value = return_value

     with patch.dict(os.environ, {"FEEDS_DATABASE_URL": default_db_url}):
         response_body, status_code = update_feed_status(None)

-    mock_update_query.asser_called_once()
-    assert response_body == "Script executed successfully. 5 feeds updated"
+    mock_update_query.assert_called_once()
+    assert response_body == return_value
     assert status_code == 200


@@ -71,6 +104,6 @@ def test_updated_feed_status_error_raised(mock_update_query, mock_logger):
     with patch.dict(os.environ, {"FEEDS_DATABASE_URL": default_db_url}):
         response_body, status_code = update_feed_status(None)

-    mock_update_query.asser_called_once()
+    mock_update_query.assert_called_once()
     assert response_body == "Error updating the feed statuses: Mocked exception"
     assert status_code == 500
