Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 40 additions & 11 deletions app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@
from app.clients.email.aws_ses_stub import AwsSesStubClient
from app.clients.letter.dvla import DVLAClient
from app.clients.sms.firetext import FiretextClient
from app.clients.sms.firetext_stub import FiretextStubClient
from app.clients.sms.mmg import MMGClient
from app.clients.sms.mmg_stub import MMGStubClient
from app.session import BindForcingSession

Base = declarative_base()
Expand Down Expand Up @@ -90,6 +92,32 @@
memo_resetters.append(lambda: get_mmg_client.clear())
mmg_client = LocalProxy(get_mmg_client)

_firetext_stub_client_context_var: ContextVar[FiretextStubClient] = ContextVar("firetext_stub_client")
get_firetext_stub_client: LazyLocalGetter[FiretextStubClient] = LazyLocalGetter(
_firetext_stub_client_context_var,
lambda: FiretextStubClient(
current_app,
statsd_client=statsd_client,
stub_url=current_app.config["FIRETEXT_STUB_URL"],
),
expected_type=FiretextStubClient,
)
memo_resetters.append(lambda: get_firetext_stub_client.clear())
firetext_stub_client = LocalProxy(get_firetext_stub_client)

_mmg_stub_client_context_var: ContextVar[MMGStubClient] = ContextVar("mmg_stub_client")
get_mmg_stub_client: LazyLocalGetter[MMGStubClient] = LazyLocalGetter(
_mmg_stub_client_context_var,
lambda: MMGStubClient(
current_app,
statsd_client=statsd_client,
stub_url=current_app.config["MMG_STUB_URL"],
),
expected_type=MMGStubClient,
)
memo_resetters.append(lambda: get_mmg_stub_client.clear())
mmg_stub_client = LocalProxy(get_mmg_stub_client)

_aws_ses_client_context_var: ContextVar[AwsSesClient] = ContextVar("aws_ses_client")
get_aws_ses_client: LazyLocalGetter[AwsSesClient] = LazyLocalGetter(
_aws_ses_client_context_var,
Expand Down Expand Up @@ -119,18 +147,19 @@
_notification_provider_clients_context_var,
lambda: NotificationProviderClients(
sms_clients={
getter.expected_type.name: LocalProxy(getter)
for getter in (
get_firetext_client,
get_mmg_client,
)
},
"firetext": LocalProxy(get_firetext_client),
"mmg": LocalProxy(get_mmg_client),
}
| (
{"firetext-stub": LocalProxy(get_firetext_stub_client)}
if current_app.config.get("FIRETEXT_STUB_URL")
else {}
)
| ({"mmg-stub": LocalProxy(get_mmg_stub_client)} if current_app.config.get("MMG_STUB_URL") else {}),
email_clients={
getter.expected_type.name: LocalProxy(getter)
# If a stub url is provided for SES, then use the stub client rather
# than the real SES boto client
for getter in ((get_aws_ses_stub_client,) if current_app.config["SES_STUB_URL"] else (get_aws_ses_client,))
},
"ses": LocalProxy(get_aws_ses_client),
}
| ({"ses-stub": LocalProxy(get_aws_ses_stub_client)} if current_app.config.get("SES_STUB_URL") else {}),
),
)
memo_resetters.append(lambda: get_notification_provider_clients.clear())
Expand Down
43 changes: 40 additions & 3 deletions app/celery/provider_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from app.delivery import send_to_providers
from app.exceptions import NotificationTechnicalFailureException
from app.letters.utils import LetterPDFNotFound, find_letter_pdf_in_s3
from app.models import Notification
from app.provider_selection import get_allowed_providers


@notify_celery.task(
Expand Down Expand Up @@ -142,6 +144,9 @@ def deliver_letter(self, notification_id):
)
return

if _handle_requested_letter_provider(notification):
return

try:
file_bytes = find_letter_pdf_in_s3(notification).get()["Body"].read()
except (BotoClientError, LetterPDFNotFound) as e:
Expand Down Expand Up @@ -194,16 +199,48 @@ def deliver_letter(self, notification_id):
raise NotificationTechnicalFailureException(f"Error when sending letter notification {notification_id}") from e


def update_letter_to_sending(notification):
provider = get_provider_details_by_notification_type(LETTER_TYPE)[0]
def update_letter_to_sending(notification, provider_identifier=None):
if provider_identifier is None:
provider = get_provider_details_by_notification_type(LETTER_TYPE)[0]
provider_identifier = provider.identifier

notification.status = NOTIFICATION_SENDING
notification.sent_at = datetime.utcnow()
notification.sent_by = provider.identifier
notification.sent_by = provider_identifier

notifications_dao.dao_update_notification(notification)


def _handle_requested_letter_provider(notification: Notification) -> bool:
"""Handle explicit provider routing for letters.

Returns True if the notification was handled (e.g., dvla-stub), False otherwise.
"""
if not notification.provider_requested:
return False

allowed = get_allowed_providers(LETTER_TYPE)
if notification.provider_requested not in allowed:
current_app.logger.error(
"Requested provider %s is not available for letter notifications",
notification.provider_requested,
extra={
"notification_id": notification.id,
"provider_requested": notification.provider_requested,
},
)
update_notification_status_by_id(notification.id, NOTIFICATION_TECHNICAL_FAILURE)
raise NotificationTechnicalFailureException(
f"Requested provider {notification.provider_requested} is not available for letter notifications"
)

if notification.provider_requested == "dvla-stub":
update_letter_to_sending(notification, provider_identifier="dvla-stub")
return True

return False


def _get_callback_url(notification_id: UUID) -> str:
signed_notification_id = signing.encode(str(notification_id))

Expand Down
4 changes: 4 additions & 0 deletions app/celery/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def get_id_task_args_kwargs_for_job_row(row, template, job, service, sender_id=N
"personalisation": dict(row.personalisation),
# row.recipient_and_personalisation gets all columns for the row, even those not in template placeholders
"client_reference": dict(row.recipient_and_personalisation).get("reference", None),
"provider_requested": job.provider_requested,
}
)

Expand Down Expand Up @@ -346,6 +347,7 @@ def save_sms(
notification_id=notification_id,
reply_to_text=reply_to_text,
client_reference=notification.get("client_reference", None),
provider_requested=notification.get("provider_requested"),
**extra_args,
)

Expand Down Expand Up @@ -432,6 +434,7 @@ def save_email(self, service_id, notification_id, encoded_notification, sender_i
notification_id=notification_id,
reply_to_text=reply_to_text,
client_reference=notification.get("client_reference", None),
provider_requested=notification.get("provider_requested"),
)

provider_tasks.deliver_email.apply_async(
Expand Down Expand Up @@ -491,6 +494,7 @@ def save_letter(
client_reference=notification.get("client_reference", None),
reply_to_text=template.reply_to_text,
status=NOTIFICATION_CREATED,
provider_requested=notification.get("provider_requested"),
)

letters_pdf_tasks.get_pdf_for_templated_letter.apply_async(
Expand Down
2 changes: 1 addition & 1 deletion app/clients/email/aws_ses_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class AwsSesStubClient(EmailClient):
This class is not thread-safe.
"""

name = "ses"
name = "ses-stub"

def __init__(self, region, statsd_client, stub_url):
super().__init__()
Expand Down
60 changes: 60 additions & 0 deletions app/clients/sms/firetext_stub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import json
from time import monotonic

import requests
from flask import current_app

from app.clients.sms import SmsClient, SmsClientResponseException


class FiretextStubClientException(SmsClientResponseException):
pass


class FiretextStubClient(SmsClient):
"""
Firetext "stub" SMS client for sending SMS to a testing stub.

This class is not thread-safe.
"""

name = "firetext-stub"

def __init__(self, current_app, statsd_client, stub_url):
super().__init__(current_app, statsd_client)
self.url = stub_url
self.requests_session = requests.Session()

def try_send_sms(self, to, content, reference, international, sender):
"""
Send SMS to the Firetext stub endpoint.
"""
data = {
"from": sender,
"to": to,
"message": content,
"reference": reference,
}

try:
start_time = monotonic()
response = self.requests_session.request("POST", self.url, data=data, timeout=60)
response.raise_for_status()

try:
response_json = json.loads(response.text)
if response_json.get("code") != 0:
raise ValueError("Expected 'code' to be '0'")
except (ValueError, AttributeError, KeyError) as e:
raise FiretextStubClientException("Invalid response JSON from stub") from e

except Exception as e:
self.statsd_client.incr("clients.firetext_stub.error")
raise FiretextStubClientException(str(e)) from e
else:
elapsed_time = monotonic() - start_time
current_app.logger.info(
"Firetext stub request finished in %.4g seconds", elapsed_time, {"duration": elapsed_time}
)
self.statsd_client.timing("clients.firetext_stub.request-time", elapsed_time)
self.statsd_client.incr("clients.firetext_stub.success")
60 changes: 60 additions & 0 deletions app/clients/sms/mmg_stub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import json
from time import monotonic

import requests
from flask import current_app

from app.clients.sms import SmsClient, SmsClientResponseException


class MMGStubClientException(SmsClientResponseException):
pass


class MMGStubClient(SmsClient):
"""
MMG "stub" SMS client for sending SMS to a testing stub.

This class is not thread-safe.
"""

name = "mmg-stub"

def __init__(self, current_app, statsd_client, stub_url):
super().__init__(current_app, statsd_client)
self.url = stub_url
self.requests_session = requests.Session()

def try_send_sms(self, to, content, reference, international, sender):
"""
Send SMS to the MMG stub endpoint.
"""
data = {
"sender": sender,
"to": to,
"message": content,
"reference": reference,
}

try:
start_time = monotonic()
response = self.requests_session.request("POST", self.url, data=data, timeout=60)
response.raise_for_status()

try:
response_json = json.loads(response.text)
if "reference" not in response_json:
raise ValueError("Expected 'reference' in response")
except (ValueError, AttributeError, KeyError) as e:
raise MMGStubClientException("Invalid response JSON from stub") from e

except Exception as e:
self.statsd_client.incr("clients.mmg_stub.error")
raise MMGStubClientException(str(e)) from e
else:
elapsed_time = monotonic() - start_time
current_app.logger.info(
"MMG stub request finished in %.4g seconds", elapsed_time, {"duration": elapsed_time}
)
self.statsd_client.timing("clients.mmg_stub.request-time", elapsed_time)
self.statsd_client.incr("clients.mmg_stub.success")
4 changes: 4 additions & 0 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,10 @@ class Config:
MMG_URL = os.environ.get("MMG_URL", "https://api.mmg.co.uk/jsonv2a/api.php")
FIRETEXT_URL = os.environ.get("FIRETEXT_URL", "https://www.firetext.co.uk/api/sendsms/json")
SES_STUB_URL = os.environ.get("SES_STUB_URL")
FIRETEXT_STUB_URL = os.environ.get("FIRETEXT_STUB_URL")
MMG_STUB_URL = os.environ.get("MMG_STUB_URL")
LETTER_STUB_ENABLED = os.environ.get("LETTER_STUB_ENABLED", "0") == "1"
PROVIDER_OPTION_ENABLED = os.environ.get("PROVIDER_OPTION_ENABLED", "0") == "1"

DVLA_API_BASE_URL = os.environ.get("DVLA_API_BASE_URL", "https://uat.driver-vehicle-licensing.api.gov.uk")
DVLA_API_TLS_CIPHERS = os.environ.get("DVLA_API_TLS_CIPHERS")
Expand Down
1 change: 1 addition & 0 deletions app/dao/notifications_dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
"created_by_id",
"postage",
"document_download_count",
"provider_requested",
]


Expand Down
33 changes: 30 additions & 3 deletions app/delivery/send_to_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)
from app.exceptions import NotificationTechnicalFailureException
from app.models import Notification
from app.provider_selection import get_allowed_providers
from app.serialised_models import SerialisedProviders, SerialisedService, SerialisedTemplate


Expand All @@ -45,7 +46,9 @@ def send_sms_to_provider(notification: Notification) -> None:
return

if notification.status == "created":
provider = provider_to_use(SMS_TYPE, notification.international)
provider = provider_to_use(
SMS_TYPE, notification.international, provider_requested=notification.provider_requested
)

template_model = SerialisedTemplate.from_id_service_id_and_version(
template_id=notification.template_id, service_id=service.id, version=notification.template_version
Expand Down Expand Up @@ -136,7 +139,7 @@ def send_email_to_provider(notification):
technical_failure(notification=notification)
return
if notification.status == "created":
provider = provider_to_use(EMAIL_TYPE)
provider = provider_to_use(EMAIL_TYPE, provider_requested=notification.provider_requested)

template = SerialisedTemplate.from_id_service_id_and_version(
template_id=notification.template_id, service_id=service.id, version=notification.template_version
Expand Down Expand Up @@ -197,7 +200,31 @@ def update_notification_to_sending(notification, provider):
dao_update_notification(notification)


def provider_to_use(notification_type, international=False):
def provider_to_use(notification_type, international=False, provider_requested=None):
# If a provider was explicitly requested, enforce availability and configuration.
if provider_requested:
allowed = get_allowed_providers(notification_type, international=international)
if provider_requested not in allowed:
current_app.logger.error(
"Requested provider %s is not available for %s notifications",
provider_requested,
notification_type,
extra={"notification_type": notification_type, "provider_requested": provider_requested},
)
raise Exception(f"Requested provider {provider_requested} is not available for {notification_type}")

provider = notification_provider_clients.get_client_by_name_and_type(provider_requested, notification_type)
if not provider:
current_app.logger.error(
"Requested provider %s is not configured for %s notifications",
provider_requested,
notification_type,
extra={"notification_type": notification_type, "provider_requested": provider_requested},
)
raise Exception(f"Requested provider {provider_requested} is not configured for {notification_type}")

return provider

active_providers = [
p for p in SerialisedProviders.from_notification_type(notification_type, international) if p.active
]
Expand Down
Loading