diff --git a/.github/workflows/base-lambdas-reusable-deploy-all.yml b/.github/workflows/base-lambdas-reusable-deploy-all.yml index 96714ee89..119ea3193 100644 --- a/.github/workflows/base-lambdas-reusable-deploy-all.yml +++ b/.github/workflows/base-lambdas-reusable-deploy-all.yml @@ -691,8 +691,8 @@ jobs: python_version: ${{ inputs.python_version }} build_branch: ${{ inputs.build_branch }} sandbox: ${{ inputs.sandbox }} - lambda_handler_name: transfer_kill_switch_handler - lambda_aws_name: TransferKillSwitch + lambda_handler_name: transfer_family_kill_switch_handler + lambda_aws_name: TransferFamilyKillSwitch lambda_layer_names: "core_lambda_layer" secrets: AWS_ASSUME_ROLE: ${{ secrets.AWS_ASSUME_ROLE }} diff --git a/lambdas/handlers/transfer_kill_switch_handler.py b/lambdas/handlers/transfer_family_kill_switch_handler.py similarity index 78% rename from lambdas/handlers/transfer_kill_switch_handler.py rename to lambdas/handlers/transfer_family_kill_switch_handler.py index 5c33b3d3b..8baccdbfd 100644 --- a/lambdas/handlers/transfer_kill_switch_handler.py +++ b/lambdas/handlers/transfer_family_kill_switch_handler.py @@ -1,4 +1,4 @@ -from services.expedite_kill_switch_service import ExpediteKillSwitchService +from services.expedite_transfer_family_kill_switch_service import ExpediteKillSwitchService from utils.decorators.handle_lambda_exceptions import handle_lambda_exceptions from utils.decorators.set_audit_arg import set_request_context_for_logging diff --git a/lambdas/services/expedite_kill_switch_service.py b/lambdas/services/expedite_transfer_family_kill_switch_service.py similarity index 79% rename from lambdas/services/expedite_kill_switch_service.py rename to lambdas/services/expedite_transfer_family_kill_switch_service.py index 540adfe1b..b3998d2bd 100644 --- a/lambdas/services/expedite_kill_switch_service.py +++ b/lambdas/services/expedite_transfer_family_kill_switch_service.py @@ -2,6 +2,7 @@ import os import boto3 + from utils.audit_logging_setup import LoggingService logger = LoggingService(__name__) @@ -15,10 +16,10 @@ def response(message: str): "body": json.dumps({"message": message}), } - class ExpediteKillSwitchService: def __init__(self): self.transfer_client = boto3.client("transfer") + self.cloudwatch = boto3.client("cloudwatch") self.staging_bucket = os.environ.get("STAGING_STORE_BUCKET_NAME", "") self.workspace = os.environ.get("WORKSPACE", "") @@ -36,7 +37,7 @@ def handle_sns_event(self, event: dict): "body": json.dumps( { "message": ( - "Kill switch disabled – no Transfer server ID discovered" + "Transfer family kill switch disabled – no Transfer server ID discovered" ) } ), @@ -123,8 +124,9 @@ def is_quarantine_expedite(self, bucket: str, key: str) -> bool: return False quarantine_prefix = f"{self.staging_bucket}/expedite/" - return bucket.startswith("cloudstoragesecquarantine-") and key.startswith( - quarantine_prefix + return ( + bucket.startswith("cloudstoragesecquarantine-") + and key.startswith(quarantine_prefix) ) def get_transfer_server_id(self) -> str: @@ -182,10 +184,19 @@ def stop_transfer_family_server(self, server_id: str): ) self.transfer_client.stop_server(ServerId=server_id) - logger.warning( f"Transfer Family server {server_id} STOPPED due to virus scan trigger" ) + try: + self.report_kill_switch_activated(server_id=server_id) + except Exception as metric_exc: + logger.error( + f"Failed to publish kill switch metric: {metric_exc}," + f" leading to failing to inform that kill switch has been activated" + ) + return response( + f"Server {server_id} stopped, but failed to alert the team" + ) return response(f"Server {server_id} stopped") except self.transfer_client.exceptions.ResourceNotFoundException: @@ -195,3 +206,28 @@ def stop_transfer_family_server(self, server_id: str): except Exception as exc: logger.error(f"Failed to stop Transfer Family server: {exc}") return response("Failed to stop server") + + def report_kill_switch_activated(self, server_id: str): + try: + self.cloudwatch.put_metric_data( + Namespace="Custom/TransferFamilyKillSwitch", + MetricData=[ + { + "MetricName": "ServerStopped", + "Dimensions": [ + {"Name": "Workspace", "Value": self.workspace or "unknown"}, + ], + "Value": 1.0, + "Unit": "Count", + } + ], + ) + except Exception as metric_exc: + logger.error( + f"Failed to publish kill switch metric: {metric_exc}," + f" leading to failing to inform that kill switch has been activated" + ) + + logger.warning( + f"Transfer Family server {server_id} STOPPED due to infected expedite upload" + ) diff --git a/lambdas/tests/unit/handlers/test_transfer_kill_switch_handler.py b/lambdas/tests/unit/handlers/test_transfer_kill_switch_handler.py index 72ad4030b..074b5733e 100644 --- a/lambdas/tests/unit/handlers/test_transfer_kill_switch_handler.py +++ b/lambdas/tests/unit/handlers/test_transfer_kill_switch_handler.py @@ -1,12 +1,12 @@ import pytest -from handlers.transfer_kill_switch_handler import lambda_handler +from handlers.transfer_family_kill_switch_handler import lambda_handler @pytest.fixture def mock_service(mocker): service_instance = mocker.Mock() mocker.patch( - "handlers.transfer_kill_switch_handler.ExpediteKillSwitchService", + "handlers.transfer_family_kill_switch_handler.ExpediteKillSwitchService", return_value=service_instance, ) return service_instance diff --git a/lambdas/tests/unit/services/test_expedite_kill_switch_service.py b/lambdas/tests/unit/services/test_expedite_kill_switch_service.py index 8e6fdaeaa..bed4631ea 100644 --- a/lambdas/tests/unit/services/test_expedite_kill_switch_service.py +++ b/lambdas/tests/unit/services/test_expedite_kill_switch_service.py @@ -1,7 +1,7 @@ import json import pytest -from services.expedite_kill_switch_service import ( +from services.expedite_transfer_family_kill_switch_service import ( ExpediteKillSwitchService, EXPECTED_SCAN_RESULTS, response, @@ -11,7 +11,7 @@ @pytest.fixture def mock_transfer_client(mocker): transfer_client = mocker.Mock() - + cloudwatch_client = mocker.Mock() class ResourceNotFoundException(Exception): pass @@ -22,10 +22,12 @@ class ResourceNotFoundException(Exception): def _client(service_name, *_, **__): if service_name == "transfer": return transfer_client + if service_name == "cloudwatch": + return cloudwatch_client raise AssertionError(f"Unexpected boto3 client requested: {service_name}") mocker.patch( - "services.expedite_kill_switch_service.boto3.client", + "services.expedite_transfer_family_kill_switch_service.boto3.client", side_effect=_client, ) return transfer_client @@ -93,7 +95,7 @@ def test_handle_sns_event_no_servers_disables_kill_switch( assert ( extract_message(resp) - == "Kill switch disabled – no Transfer server ID discovered" + == "Transfer family kill switch disabled – no Transfer server ID discovered" ) mock_transfer_client.describe_server.assert_not_called() mock_transfer_client.stop_server.assert_not_called() @@ -243,4 +245,124 @@ def test_stop_transfer_family_server_handles_generic_exception( resp = service.stop_transfer_family_server("srv-error") mock_transfer_client.stop_server.assert_not_called() - assert extract_message(resp) == "Failed to stop server" \ No newline at end of file + assert extract_message(resp) == "Failed to stop server" + +def test_handle_scan_message_ignores_irrelevant_scan_result(service, mocker): + message = { + "scanResult": "Clean", + "bucket": "cloudstoragesecquarantine-abc", + "key": "pre-prod-staging-bulk-store/expedite/folder/file.pdf", + } + + mock_stop = mocker.patch.object( + service, "stop_transfer_family_server", autospec=True + ) + + resp = service.handle_scan_message(server_id="srv-abc", message=message) + + assert extract_message(resp) == "Scan result not relevant, no action taken" + mock_stop.assert_not_called() + + +def test_handle_scan_message_returns_invalid_payload_when_bucket_missing(service, mocker): + message = { + "scanResult": "Infected", + "key": "pre-prod-staging-bulk-store/expedite/folder/file.pdf", + } + + mock_stop = mocker.patch.object( + service, "stop_transfer_family_server", autospec=True + ) + + resp = service.handle_scan_message(server_id="srv-abc", message=message) + + assert extract_message(resp) == "Invalid payload (missing bucket/key)" + mock_stop.assert_not_called() + + +def test_handle_scan_message_returns_invalid_payload_when_key_missing(service, mocker): + message = { + "scanResult": "Infected", + "bucket": "cloudstoragesecquarantine-abc", + } + + mock_stop = mocker.patch.object( + service, "stop_transfer_family_server", autospec=True + ) + + resp = service.handle_scan_message(server_id="srv-abc", message=message) + + assert extract_message(resp) == "Invalid payload (missing bucket/key)" + mock_stop.assert_not_called() + + +def test_handle_scan_message_not_quarantine_expedite(service, mocker): + message = { + "scanResult": "Infected", + "bucket": "some-other-bucket", + "key": "pre-prod-staging-bulk-store/expedite/folder/file.pdf", + } + + mock_stop = mocker.patch.object( + service, "stop_transfer_family_server", autospec=True + ) + + resp = service.handle_scan_message(server_id="srv-abc", message=message) + + assert extract_message(resp) == "Not an expedite file, no action taken" + mock_stop.assert_not_called() + + +def test_handle_scan_message_non_infected_expedite(service, mocker): + message = { + "scanResult": "Error", + "bucket": "cloudstoragesecquarantine-abc", + "key": "pre-prod-staging-bulk-store/expedite/folder/file.pdf", + } + + mock_stop = mocker.patch.object( + service, "stop_transfer_family_server", autospec=True + ) + + resp = service.handle_scan_message(server_id="srv-abc", message=message) + + assert ( + extract_message(resp) + == "Non-infected result for expedite file, no kill switch action" + ) + mock_stop.assert_not_called() + +def test_extract_sns_message_returns_none_on_invalid_json(service): + event = { + "Records": [ + { + "Sns": { + "Message": "not-json-at-all", + } + } + ] + } + + msg = service.extract_sns_message(event) + + assert msg is None + +def test_stop_transfer_family_server_handles_metric_failure( + service, mock_transfer_client, mocker +): + mock_transfer_client.describe_server.return_value = {"Server": {"State": "ONLINE"}} + + mocker.patch.object( + service, + "report_kill_switch_activated", + side_effect=Exception("metric failed"), + ) + + resp = service.stop_transfer_family_server("srv-xyz") + + mock_transfer_client.describe_server.assert_called_once_with(ServerId="srv-xyz") + mock_transfer_client.stop_server.assert_called_once_with(ServerId="srv-xyz") + assert ( + extract_message(resp) + == "Server srv-xyz stopped, but failed to alert the team" + )