Skip to content
4 changes: 2 additions & 2 deletions .github/workflows/base-lambdas-reusable-deploy-all.yml
Original file line number Diff line number Diff line change
Expand Up @@ -691,8 +691,8 @@ jobs:
python_version: ${{ inputs.python_version }}
build_branch: ${{ inputs.build_branch }}
sandbox: ${{ inputs.sandbox }}
lambda_handler_name: transfer_kill_switch_handler
lambda_aws_name: TransferKillSwitch
lambda_handler_name: transfer_family_kill_switch_handler
lambda_aws_name: TransferFamilyKillSwitch
lambda_layer_names: "core_lambda_layer"
secrets:
AWS_ASSUME_ROLE: ${{ secrets.AWS_ASSUME_ROLE }}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from services.expedite_kill_switch_service import ExpediteKillSwitchService
from services.expedite_transfer_family_kill_switch_service import ExpediteKillSwitchService
from utils.decorators.handle_lambda_exceptions import handle_lambda_exceptions
from utils.decorators.set_audit_arg import set_request_context_for_logging

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os

import boto3

from utils.audit_logging_setup import LoggingService

logger = LoggingService(__name__)
Expand All @@ -15,10 +16,10 @@ def response(message: str):
"body": json.dumps({"message": message}),
}


class ExpediteKillSwitchService:
def __init__(self):
self.transfer_client = boto3.client("transfer")
self.cloudwatch = boto3.client("cloudwatch")

self.staging_bucket = os.environ.get("STAGING_STORE_BUCKET_NAME", "")
self.workspace = os.environ.get("WORKSPACE", "")
Expand All @@ -36,7 +37,7 @@ def handle_sns_event(self, event: dict):
"body": json.dumps(
{
"message": (
"Kill switch disabled – no Transfer server ID discovered"
"Transfer family kill switch disabled – no Transfer server ID discovered"
)
}
),
Expand Down Expand Up @@ -123,8 +124,9 @@ def is_quarantine_expedite(self, bucket: str, key: str) -> bool:
return False

quarantine_prefix = f"{self.staging_bucket}/expedite/"
return bucket.startswith("cloudstoragesecquarantine-") and key.startswith(
quarantine_prefix
return (
bucket.startswith("cloudstoragesecquarantine-")
and key.startswith(quarantine_prefix)
)

def get_transfer_server_id(self) -> str:
Expand Down Expand Up @@ -182,10 +184,19 @@ def stop_transfer_family_server(self, server_id: str):
)

self.transfer_client.stop_server(ServerId=server_id)

logger.warning(
f"Transfer Family server {server_id} STOPPED due to virus scan trigger"
)
try:
self.report_kill_switch_activated(server_id=server_id)
except Exception as metric_exc:
logger.error(
f"Failed to publish kill switch metric: {metric_exc},"
f" leading to failing to inform that kill switch has been activated"
)
return response(
f"Server {server_id} stopped, but failed to alert the team"
)
return response(f"Server {server_id} stopped")

except self.transfer_client.exceptions.ResourceNotFoundException:
Expand All @@ -195,3 +206,28 @@ def stop_transfer_family_server(self, server_id: str):
except Exception as exc:
logger.error(f"Failed to stop Transfer Family server: {exc}")
return response("Failed to stop server")

def report_kill_switch_activated(self, server_id: str):
    """Record that the kill switch stopped a Transfer Family server.

    Emits a warning log naming the stopped server, then publishes a
    ``ServerStopped`` data point (value 1, unit ``Count``) to the
    ``Custom/TransferFamilyKillSwitch`` CloudWatch namespace, dimensioned by
    ``self.workspace`` (``"unknown"`` when the workspace is empty).

    Args:
        server_id: ID of the server that was stopped. Only used in the log
            line; it is not attached to the metric.

    Raises:
        Exception: whatever ``put_metric_data`` raises. The failure is
            deliberately NOT swallowed here: ``stop_transfer_family_server``
            wraps this call in its own ``except Exception`` so it can log the
            failure and tell the caller that the team was not alerted.
            Catching the error in this method made that branch unreachable.
    """
    # Log before publishing so the warning is still emitted when the metric
    # call fails and the exception propagates.
    logger.warning(
        f"Transfer Family server {server_id} STOPPED due to infected expedite upload"
    )

    self.cloudwatch.put_metric_data(
        Namespace="Custom/TransferFamilyKillSwitch",
        MetricData=[
            {
                "MetricName": "ServerStopped",
                "Dimensions": [
                    {"Name": "Workspace", "Value": self.workspace or "unknown"},
                ],
                "Value": 1.0,
                "Unit": "Count",
            }
        ],
    )
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import pytest
from handlers.transfer_kill_switch_handler import lambda_handler
from handlers.transfer_family_kill_switch_handler import lambda_handler


@pytest.fixture
def mock_service(mocker):
service_instance = mocker.Mock()
mocker.patch(
"handlers.transfer_kill_switch_handler.ExpediteKillSwitchService",
"handlers.transfer_family_kill_switch_handler.ExpediteKillSwitchService",
return_value=service_instance,
)
return service_instance
Expand Down
132 changes: 127 additions & 5 deletions lambdas/tests/unit/services/test_expedite_kill_switch_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import pytest

from services.expedite_kill_switch_service import (
from services.expedite_transfer_family_kill_switch_service import (
ExpediteKillSwitchService,
EXPECTED_SCAN_RESULTS,
response,
Expand All @@ -11,7 +11,7 @@
@pytest.fixture
def mock_transfer_client(mocker):
transfer_client = mocker.Mock()

cloudwatch_client = mocker.Mock()
class ResourceNotFoundException(Exception):
pass

Expand All @@ -22,10 +22,12 @@ class ResourceNotFoundException(Exception):
def _client(service_name, *_, **__):
if service_name == "transfer":
return transfer_client
if service_name == "cloudwatch":
return cloudwatch_client
raise AssertionError(f"Unexpected boto3 client requested: {service_name}")

mocker.patch(
"services.expedite_kill_switch_service.boto3.client",
"services.expedite_transfer_family_kill_switch_service.boto3.client",
side_effect=_client,
)
return transfer_client
Expand Down Expand Up @@ -93,7 +95,7 @@ def test_handle_sns_event_no_servers_disables_kill_switch(

assert (
extract_message(resp)
== "Kill switch disabled – no Transfer server ID discovered"
== "Transfer family kill switch disabled – no Transfer server ID discovered"
)
mock_transfer_client.describe_server.assert_not_called()
mock_transfer_client.stop_server.assert_not_called()
Expand Down Expand Up @@ -243,4 +245,124 @@ def test_stop_transfer_family_server_handles_generic_exception(
resp = service.stop_transfer_family_server("srv-error")

mock_transfer_client.stop_server.assert_not_called()
assert extract_message(resp) == "Failed to stop server"
assert extract_message(resp) == "Failed to stop server"

def test_handle_scan_message_ignores_irrelevant_scan_result(service, mocker):
    """A "Clean" scan result returns the "not relevant" message and never stops the server."""
    stop_spy = mocker.patch.object(
        service, "stop_transfer_family_server", autospec=True
    )
    payload = dict(
        scanResult="Clean",
        bucket="cloudstoragesecquarantine-abc",
        key="pre-prod-staging-bulk-store/expedite/folder/file.pdf",
    )

    result = service.handle_scan_message(server_id="srv-abc", message=payload)

    stop_spy.assert_not_called()
    assert extract_message(result) == "Scan result not relevant, no action taken"


def test_handle_scan_message_returns_invalid_payload_when_bucket_missing(service, mocker):
    """An "Infected" message without a "bucket" entry is rejected as an invalid payload."""
    stop_spy = mocker.patch.object(
        service, "stop_transfer_family_server", autospec=True
    )
    payload = dict(
        scanResult="Infected",
        key="pre-prod-staging-bulk-store/expedite/folder/file.pdf",
    )

    result = service.handle_scan_message(server_id="srv-abc", message=payload)

    stop_spy.assert_not_called()
    assert extract_message(result) == "Invalid payload (missing bucket/key)"


def test_handle_scan_message_returns_invalid_payload_when_key_missing(service, mocker):
    """An "Infected" message without a "key" entry is rejected as an invalid payload."""
    stop_spy = mocker.patch.object(
        service, "stop_transfer_family_server", autospec=True
    )
    payload = dict(
        scanResult="Infected",
        bucket="cloudstoragesecquarantine-abc",
    )

    result = service.handle_scan_message(server_id="srv-abc", message=payload)

    stop_spy.assert_not_called()
    assert extract_message(result) == "Invalid payload (missing bucket/key)"


def test_handle_scan_message_not_quarantine_expedite(service, mocker):
    """An "Infected" result from a bucket named "some-other-bucket" takes no kill switch action."""
    stop_spy = mocker.patch.object(
        service, "stop_transfer_family_server", autospec=True
    )
    payload = dict(
        scanResult="Infected",
        bucket="some-other-bucket",
        key="pre-prod-staging-bulk-store/expedite/folder/file.pdf",
    )

    result = service.handle_scan_message(server_id="srv-abc", message=payload)

    stop_spy.assert_not_called()
    assert extract_message(result) == "Not an expedite file, no action taken"


def test_handle_scan_message_non_infected_expedite(service, mocker):
    """An "Error" result for a quarantined expedite key does not stop the server."""
    stop_spy = mocker.patch.object(
        service, "stop_transfer_family_server", autospec=True
    )
    payload = dict(
        scanResult="Error",
        bucket="cloudstoragesecquarantine-abc",
        key="pre-prod-staging-bulk-store/expedite/folder/file.pdf",
    )
    expected_text = "Non-infected result for expedite file, no kill switch action"

    result = service.handle_scan_message(server_id="srv-abc", message=payload)

    stop_spy.assert_not_called()
    assert extract_message(result) == expected_text

def test_extract_sns_message_returns_none_on_invalid_json(service):
    """An SNS record whose "Message" is not valid JSON yields ``None``."""
    sns_record = {"Sns": {"Message": "not-json-at-all"}}

    parsed = service.extract_sns_message({"Records": [sns_record]})

    assert parsed is None

def test_stop_transfer_family_server_handles_metric_failure(
    service, mock_transfer_client, mocker
):
    """When reporting the activation raises, the server is still stopped and the response says the team was not alerted."""
    online_state = {"Server": {"State": "ONLINE"}}
    mock_transfer_client.describe_server.return_value = online_state
    mocker.patch.object(
        service,
        "report_kill_switch_activated",
        side_effect=Exception("metric failed"),
    )
    expected_text = "Server srv-xyz stopped, but failed to alert the team"

    outcome = service.stop_transfer_family_server("srv-xyz")

    mock_transfer_client.describe_server.assert_called_once_with(ServerId="srv-xyz")
    mock_transfer_client.stop_server.assert_called_once_with(ServerId="srv-xyz")
    assert extract_message(outcome) == expected_text
Loading