diff --git a/app/dao/fact_notification_status_dao.py b/app/dao/fact_notification_status_dao.py index 9746d191a4..b3ad308fc9 100644 --- a/app/dao/fact_notification_status_dao.py +++ b/app/dao/fact_notification_status_dao.py @@ -135,10 +135,12 @@ def fetch_notification_status_for_service_for_day(bst_day, service_id): ) -def fetch_notification_status_for_service_for_today_and_7_previous_days(service_id, by_template=False, limit_days=7): +def fetch_notification_status_for_service_for_today_and_7_previous_days( + service_id, by_template=False, limit_days=7, session=db.session +): start_date = midnight_n_days_ago(limit_days) now = datetime.utcnow() - stats_for_7_days = db.session.query( + stats_for_7_days = session.query( FactNotificationStatus.notification_type.label("notification_type"), FactNotificationStatus.notification_status.label("status"), *([FactNotificationStatus.template_id.label("template_id")] if by_template else []), @@ -150,7 +152,7 @@ def fetch_notification_status_for_service_for_today_and_7_previous_days(service_ ) stats_for_today = ( - db.session.query( + session.query( Notification.notification_type.cast(db.Text), Notification.status, *([Notification.template_id] if by_template else []), @@ -169,7 +171,7 @@ def fetch_notification_status_for_service_for_today_and_7_previous_days(service_ all_stats_table = stats_for_7_days.union_all(stats_for_today).subquery() aggregation = ( - db.session.query( + session.query( *([all_stats_table.c.template_id] if by_template else []), all_stats_table.c.notification_type, all_stats_table.c.status, @@ -183,7 +185,7 @@ def fetch_notification_status_for_service_for_today_and_7_previous_days(service_ .subquery() ) - query = db.session.query( + query = session.query( *( [Template.name.label("template_name"), Template.is_precompiled_letter, aggregation.c.template_id] if by_template diff --git a/app/template_statistics/rest.py b/app/template_statistics/rest.py index a9271f3f05..3bbe2a00ba 100644 --- a/app/template_statistics/rest.py +++ b/app/template_statistics/rest.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from app import db from app.dao.fact_notification_status_dao import ( fetch_notification_status_for_service_for_today_and_7_previous_days, ) @@ -26,7 +27,7 @@ def get_template_statistics_for_service_by_day(service_id): if whole_days < 0 or whole_days > 7: raise InvalidRequest({"whole_days": ["whole_days must be between 0 and 7"]}, status_code=400) data = fetch_notification_status_for_service_for_today_and_7_previous_days( - service_id, by_template=True, limit_days=whole_days + service_id, by_template=True, limit_days=whole_days, session=db.session_bulk ) return jsonify( diff --git a/tests/app/dao/test_fact_notification_status_dao.py b/tests/app/dao/test_fact_notification_status_dao.py index d042fb880b..834da476af 100644 --- a/tests/app/dao/test_fact_notification_status_dao.py +++ b/tests/app/dao/test_fact_notification_status_dao.py @@ -5,6 +5,7 @@ import pytest from freezegun import freeze_time +from app import db from app.constants import ( EMAIL_TYPE, KEY_TYPE_TEAM, @@ -41,6 +42,7 @@ create_service, create_template, ) +from tests.utils import QueryRecorder def test_fetch_notification_status_for_service_by_month(notify_db_session): @@ -134,7 +136,16 @@ def test_fetch_notification_status_for_service_for_day(notify_db_session): @freeze_time("2018-10-31T18:00:00") -def test_fetch_notification_status_for_service_for_today_and_7_previous_days(notify_db_session): +@pytest.mark.parametrize( + "sess,expected_bind_key", + ( + (db.session, None), + (db.session_bulk, "bulk"), + ), +) +def test_fetch_notification_status_for_service_for_today_and_7_previous_days( + notify_db_session, sess, expected_bind_key +): service_1 = create_service(service_name="service_1") sms_template = create_template(service=service_1, template_type=SMS_TYPE) sms_template_2 = create_template(service=service_1, template_type=SMS_TYPE) @@ -154,10 +165,14 @@ def test_fetch_notification_status_for_service_for_today_and_7_previous_days(not # too early, shouldn't be included create_notification(service_1.templates[0], created_at=datetime(2018, 10, 30, 12, 0, 0), status="delivered") - results = sorted( - fetch_notification_status_for_service_for_today_and_7_previous_days(service_1.id), - key=lambda x: (x.notification_type, x.status), - ) + service_id = service_1.id + with QueryRecorder() as query_recorder: + results = sorted( + fetch_notification_status_for_service_for_today_and_7_previous_days(service_id=service_id, session=sess), + key=lambda x: (x.notification_type, x.status), + ) + + assert {query_info.bind_key for query_info in query_recorder.queries} == {expected_bind_key} assert len(results) == 4 diff --git a/tests/app/template_statistics/test_rest.py b/tests/app/template_statistics/test_rest.py index fd4d0903e7..a96cdc8826 100644 --- a/tests/app/template_statistics/test_rest.py +++ b/tests/app/template_statistics/test_rest.py @@ -62,7 +62,9 @@ def test_get_template_statistics_for_service_by_day_accepts_old_query_string( @freeze_time("2018-01-02 12:00:00") -def test_get_template_statistics_for_service_by_day_goes_to_db(admin_request, mocker, sample_template): +def test_get_template_statistics_for_service_by_day_goes_to_db( + admin_request, mocker, sample_template, notify_db_session_bulk +): # first time it is called redis returns data, second time returns none mock_dao = mocker.patch( "app.template_statistics.rest.fetch_notification_status_for_service_for_today_and_7_previous_days", @@ -94,7 +96,9 @@ def test_get_template_statistics_for_service_by_day_goes_to_db(admin_request, mo } ] # dao only called for 2nd, since redis returned values for first call - mock_dao.assert_called_once_with(str(sample_template.service_id), limit_days=1, by_template=True) + mock_dao.assert_called_once_with( + str(sample_template.service_id), limit_days=1, by_template=True, session=notify_db_session_bulk + ) def test_get_template_statistics_for_service_by_day_returns_empty_list_if_no_templates( diff --git a/tests/conftest.py b/tests/conftest.py index b06c21d57b..29dde0c3a9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -143,6 +143,17 @@ def sms_providers(_notify_db): get_provider_details_by_identifier("firetext").priority = 0 +@pytest.fixture(scope="function") +def notify_db_session_bulk(_notify_db, sms_providers): + """ + This fixture clears down all non static data after your test run. It yields the SQLAlchemy bulk session variable, + which is used for bulk/replica DB operations. Use this session to manually route queries to the replica database. + """ + yield _notify_db.session_bulk + + _clean_database(_notify_db) + + @pytest.fixture(scope="function") def notify_db_session(_notify_db, sms_providers): """ diff --git a/tests/utils.py b/tests/utils.py index 914f274001..14c4de8e54 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,14 +1,46 @@ -import flask_sqlalchemy +from dataclasses import dataclass + +from sqlalchemy import event + +from app import db + + +@dataclass +class QueryInfo: + statement: str + parameters: tuple | dict | None + bind_key: str | None class QueryRecorder: def __init__(self): - self.queries = [] - self._count_on_enter = None + self.queries: list[QueryInfo] = [] + self._listeners = [] def __enter__(self): - self._count_on_enter = len(flask_sqlalchemy.record_queries.get_recorded_queries()) + # Register listeners for all engines to capture bind_key + for bind_key, engine in db.engines.items(): + listener = self._listener(bind_key) + event.listen(engine, "before_cursor_execute", listener) + self._listeners.append((engine, listener)) return self def __exit__(self, exc_type, exc_val, exc_tb): - self.queries = flask_sqlalchemy.record_queries.get_recorded_queries()[self._count_on_enter :] + # Remove all listeners + for engine, listener in self._listeners: + event.remove(engine, "before_cursor_execute", listener) + self._listeners.clear() + + def _listener(self, bind_key): + """Create a listener function that captures the bind_key in its closure.""" + + def listener(conn, cursor, statement, parameters, context, executemany): + self.queries.append( + QueryInfo( + statement=statement, + parameters=parameters, + bind_key=bind_key, + ) + ) + + return listener