diff --git a/.github/workflows/base-lambdas-reusable-deploy-all.yml b/.github/workflows/base-lambdas-reusable-deploy-all.yml index ef061f863..068ae6fc3 100644 --- a/.github/workflows/base-lambdas-reusable-deploy-all.yml +++ b/.github/workflows/base-lambdas-reusable-deploy-all.yml @@ -697,6 +697,20 @@ jobs: secrets: AWS_ASSUME_ROLE: ${{ secrets.AWS_ASSUME_ROLE }} + deploy_transfer_kill_switch_lambda: + name: Deploy transfer kill switch lambda + uses: ./.github/workflows/base-lambdas-reusable-deploy.yml + with: + environment: ${{ inputs.environment }} + python_version: ${{ inputs.python_version }} + build_branch: ${{ inputs.build_branch }} + sandbox: ${{ inputs.sandbox }} + lambda_handler_name: transfer_family_kill_switch_handler + lambda_aws_name: TransferFamilyKillSwitch + lambda_layer_names: "core_lambda_layer" + secrets: + AWS_ASSUME_ROLE: ${{ secrets.AWS_ASSUME_ROLE }} + deploy_search_document_review_lambda: name: Deploy Search Document Review uses: ./.github/workflows/base-lambdas-reusable-deploy.yml diff --git a/lambdas/handlers/transfer_family_kill_switch_handler.py b/lambdas/handlers/transfer_family_kill_switch_handler.py new file mode 100644 index 000000000..8baccdbfd --- /dev/null +++ b/lambdas/handlers/transfer_family_kill_switch_handler.py @@ -0,0 +1,10 @@ +from services.expedite_transfer_family_kill_switch_service import ExpediteKillSwitchService +from utils.decorators.handle_lambda_exceptions import handle_lambda_exceptions +from utils.decorators.set_audit_arg import set_request_context_for_logging + + +@handle_lambda_exceptions +@set_request_context_for_logging +def lambda_handler(event, context): + service = ExpediteKillSwitchService() + return service.handle_sns_event(event) diff --git a/lambdas/services/expedite_transfer_family_kill_switch_service.py b/lambdas/services/expedite_transfer_family_kill_switch_service.py new file mode 100644 index 000000000..b3998d2bd --- /dev/null +++ b/lambdas/services/expedite_transfer_family_kill_switch_service.py @@ -0,0 +1,233 @@ +import json +import os + +import boto3 + +from utils.audit_logging_setup import LoggingService + +logger = LoggingService(__name__) + +EXPECTED_SCAN_RESULTS = {"Infected", "Error", "Unscannable", "Suspicious"} + + +def response(message: str): + return { + "statusCode": 200, + "body": json.dumps({"message": message}), + } + +class ExpediteKillSwitchService: + def __init__(self): + self.transfer_client = boto3.client("transfer") + self.cloudwatch = boto3.client("cloudwatch") + + self.staging_bucket = os.environ.get("STAGING_STORE_BUCKET_NAME", "") + self.workspace = os.environ.get("WORKSPACE", "") + + def handle_sns_event(self, event: dict): + logger.info("Received SNS virus scan notification event", {"event": event}) + + server_id = self.get_transfer_server_id() + if not server_id: + logger.warning( + "No Transfer Family server ID resolved from AWS – kill switch disabled." + ) + return { + "statusCode": 200, + "body": json.dumps( + { + "message": ( + "Transfer family kill switch disabled – no Transfer server ID discovered" + ) + } + ), + } + + logger.warning( + "Initiating Transfer Family shutdown.", + { + "server_id": server_id, + "workspace": self.workspace, + }, + ) + + return self.stop_transfer_family_server(server_id) + + def handle_scan_message(self, server_id: str, message: dict): + scan_result = message.get("scanResult") + bucket = message.get("bucket") + key = message.get("key") + + if not self.is_relevant_scan_result(scan_result): + logger.info( + f"Ignoring scan result '{scan_result}' – not one of {EXPECTED_SCAN_RESULTS}" + ) + return response("Scan result not relevant, no action taken") + + if not self.has_required_fields(bucket, key): + logger.error("SNS payload missing required 'bucket' or 'key' fields") + return response("Invalid payload (missing bucket/key)") + + if not self.is_quarantine_expedite(bucket, key): + logger.info( + "Scan notification is not for an expedite file – no kill switch action", + { + "bucket": bucket, + "key": key, + "staging_bucket": self.staging_bucket, + "workspace": self.workspace, + }, + ) + return response("Not an expedite file, no action taken") + + if scan_result != "Infected": + logger.warning( + "Non-clean scan result for expedite file, but not 'Infected' – no kill switch action", + { + "scanResult": scan_result, + "bucket": bucket, + "key": key, + "workspace": self.workspace, + }, + ) + return response( + "Non-infected result for expedite file, no kill switch action" + ) + + logger.warning( + "Initiating Transfer Family shutdown.", + { + "server_id": server_id, + "bucket": bucket, + "key": key, + "scanResult": scan_result, + "workspace": self.workspace, + }, + ) + + return self.stop_transfer_family_server(server_id) + + def is_relevant_scan_result(self, scan_result: str) -> bool: + return scan_result in EXPECTED_SCAN_RESULTS + + def has_required_fields(self, bucket: str, key: str) -> bool: + return bool(bucket and key) + + def is_quarantine_expedite(self, bucket: str, key: str) -> bool: + """ + Example quarantine: + bucket = cloudstoragesecquarantine-... + key = "pre-prod-staging-bulk-store/expedite/..." + Where key starts with "-staging-bulk-store/expedite/" + """ + if not self.staging_bucket: + return False + + quarantine_prefix = f"{self.staging_bucket}/expedite/" + return ( + bucket.startswith("cloudstoragesecquarantine-") + and key.startswith(quarantine_prefix) + ) + + def get_transfer_server_id(self) -> str: + """ + Discover Transfer Family servers in this account/region and return + the first ServerId, or "" if none exist or an error occurs. + """ + try: + resp = self.transfer_client.list_servers(MaxResults=1) + servers = resp.get("Servers", []) + if not servers: + logger.warning( + "No AWS Transfer Family servers found in account/region " + "– kill switch disabled." + ) + return "" + + server_id = servers[0]["ServerId"].strip() + logger.info( + "Resolved Transfer server ID via list_servers", + {"server_id": server_id}, + ) + return server_id + + except Exception as exc: + logger.error(f"Failed to list Transfer Family servers: {exc}") + return "" + + def extract_sns_message(self, event): + try: + records = event.get("Records") + if not records: + return None + + sns_record = records[0].get("Sns") + if not sns_record: + return None + + raw_message = sns_record.get("Message") + if not raw_message: + return None + + return json.loads(raw_message) + + except Exception as exc: + logger.error(f"Failed to parse SNS message: {exc}") + return None + + def stop_transfer_family_server(self, server_id: str): + try: + desc = self.transfer_client.describe_server(ServerId=server_id) + logger.info( + "Transfer Family server found", + {"server_id": server_id, "state": desc["Server"]["State"]}, + ) + + self.transfer_client.stop_server(ServerId=server_id) + logger.warning( + f"Transfer Family server {server_id} STOPPED due to virus scan trigger" + ) + try: + self.report_kill_switch_activated(server_id=server_id) + except Exception as metric_exc: + logger.error( + f"Failed to publish kill switch metric: {metric_exc}," + f" leading to failing to inform that kill switch has been activated" + ) + return response( + f"Server {server_id} stopped, but failed to alert the team" + ) + return response(f"Server {server_id} stopped") + + except self.transfer_client.exceptions.ResourceNotFoundException: + logger.error(f"Transfer Family server '{server_id}' not found") + return response("Server not found") + + except Exception as exc: + logger.error(f"Failed to stop Transfer Family server: {exc}") + return response("Failed to stop server") + + def report_kill_switch_activated(self, server_id: str): + try: + self.cloudwatch.put_metric_data( + Namespace="Custom/TransferFamilyKillSwitch", + MetricData=[ + { + "MetricName": "ServerStopped", + "Dimensions": [ + {"Name": "Workspace", "Value": self.workspace or "unknown"}, + ], + "Value": 1.0, + "Unit": "Count", + } + ], + ) + except Exception as metric_exc: + logger.error( + f"Failed to publish kill switch metric: {metric_exc}," + f" leading to failing to inform that kill switch has been activated" + ) + + logger.warning( + f"Transfer Family server {server_id} STOPPED due to infected expedite upload" + ) diff --git a/lambdas/tests/unit/handlers/test_transfer_kill_switch_handler.py b/lambdas/tests/unit/handlers/test_transfer_kill_switch_handler.py new file mode 100644 index 000000000..074b5733e --- /dev/null +++ b/lambdas/tests/unit/handlers/test_transfer_kill_switch_handler.py @@ -0,0 +1,31 @@ +import pytest +from handlers.transfer_family_kill_switch_handler import lambda_handler + + +@pytest.fixture +def mock_service(mocker): + service_instance = mocker.Mock() + mocker.patch( + "handlers.transfer_family_kill_switch_handler.ExpediteKillSwitchService", + return_value=service_instance, + ) + return service_instance + + +@pytest.fixture +def context(mocker): + context = mocker.Mock() + context.aws_request_id = "test-request-id" + return context + + +def test_lambda_handler_delegates_to_service_handle_sns_event(mock_service, context): + event = {"Records": []} + expected_response = {"statusCode": 200, "body": '{"message": "ok"}'} + + mock_service.handle_sns_event.return_value = expected_response + + resp = lambda_handler(event, context) + + mock_service.handle_sns_event.assert_called_once_with(event) + assert resp == expected_response diff --git a/lambdas/tests/unit/services/test_bulk_upload_metadata_processor_service.py b/lambdas/tests/unit/services/test_bulk_upload_metadata_processor_service.py index 7d90ab8d2..ab11ea59e 100644 --- a/lambdas/tests/unit/services/test_bulk_upload_metadata_processor_service.py +++ b/lambdas/tests/unit/services/test_bulk_upload_metadata_processor_service.py @@ -8,9 +8,10 @@ import pytest from botocore.exceptions import ClientError +from freezegun import freeze_time + from enums.upload_status import UploadStatus from enums.virus_scan_result import VirusScanResult -from freezegun import freeze_time from models.staging_metadata import ( METADATA_FILENAME, BulkUploadQueueMetadata, diff --git a/lambdas/tests/unit/services/test_expedite_kill_switch_service.py b/lambdas/tests/unit/services/test_expedite_kill_switch_service.py new file mode 100644 index 000000000..bed4631ea --- /dev/null +++ b/lambdas/tests/unit/services/test_expedite_kill_switch_service.py @@ -0,0 +1,368 @@ +import json +import pytest + +from services.expedite_transfer_family_kill_switch_service import ( + ExpediteKillSwitchService, + EXPECTED_SCAN_RESULTS, + response, +) + + +@pytest.fixture +def mock_transfer_client(mocker): + transfer_client = mocker.Mock() + cloudwatch_client = mocker.Mock() + class ResourceNotFoundException(Exception): + pass + + transfer_client.exceptions = mocker.Mock( + ResourceNotFoundException=ResourceNotFoundException + ) + + def _client(service_name, *_, **__): + if service_name == "transfer": + return transfer_client + if service_name == "cloudwatch": + return cloudwatch_client + raise AssertionError(f"Unexpected boto3 client requested: {service_name}") + + mocker.patch( + "services.expedite_transfer_family_kill_switch_service.boto3.client", + side_effect=_client, + ) + return transfer_client + + +@pytest.fixture +def service(mock_transfer_client, monkeypatch): + monkeypatch.setenv("STAGING_STORE_BUCKET_NAME", "pre-prod-staging-bulk-store") + monkeypatch.setenv("WORKSPACE", "pre-prod") + return ExpediteKillSwitchService() + + +@pytest.fixture +def sns_event(): + message = { + "scanResult": "Infected", + "bucket": "cloudstoragesecquarantine-abc", + "key": "pre-prod-staging-bulk-store/expedite/folder/file.pdf", + } + return { + "Records": [ + { + "Sns": { + "Message": json.dumps(message), + } + } + ] + } + + +def extract_message(resp): + return json.loads(resp["body"])["message"] + +def test_response_builds_expected_http_shape(): + msg = "hello world" + resp = response(msg) + + assert resp["statusCode"] == 200 + body = json.loads(resp["body"]) + assert body == {"message": msg} + + +def test_handle_sns_event_happy_path_infected_expedite( + service, sns_event, mock_transfer_client +): + mock_transfer_client.list_servers.return_value = { + "Servers": [{"ServerId": "srv-12345"}] + } + mock_transfer_client.describe_server.return_value = {"Server": {"State": "ONLINE"}} + + resp = service.handle_sns_event(sns_event) + + mock_transfer_client.list_servers.assert_called_once_with(MaxResults=1) + mock_transfer_client.describe_server.assert_called_once_with(ServerId="srv-12345") + mock_transfer_client.stop_server.assert_called_once_with(ServerId="srv-12345") + assert extract_message(resp) == "Server srv-12345 stopped" + + +def test_handle_sns_event_no_servers_disables_kill_switch( + service, sns_event, mock_transfer_client +): + mock_transfer_client.list_servers.return_value = {"Servers": []} + + resp = service.handle_sns_event(sns_event) + + assert ( + extract_message(resp) + == "Transfer family kill switch disabled – no Transfer server ID discovered" + ) + mock_transfer_client.describe_server.assert_not_called() + mock_transfer_client.stop_server.assert_not_called() + + +def test_get_transfer_server_id_happy_path_reads_from_list_servers( + service, mock_transfer_client +): + mock_transfer_client.list_servers.return_value = { + "Servers": [{"ServerId": " srv-9999 "}] + } + + server_id = service.get_transfer_server_id() + + mock_transfer_client.list_servers.assert_called_once_with(MaxResults=1) + assert server_id == " srv-9999 ".strip() + + +def test_get_transfer_server_id_returns_empty_when_no_servers( + service, mock_transfer_client +): + mock_transfer_client.list_servers.return_value = {"Servers": []} + + server_id = service.get_transfer_server_id() + + assert server_id == "" + + +def test_get_transfer_server_id_returns_empty_on_generic_error( + service, mock_transfer_client +): + mock_transfer_client.list_servers.side_effect = Exception("boom") + + server_id = service.get_transfer_server_id() + + assert server_id == "" + +def test_handle_scan_message_calls_stop_server_for_infected_expedite( + service, sns_event, mocker +): + message = json.loads(sns_event["Records"][0]["Sns"]["Message"]) + server_id = "srv-abc" + + mock_stop = mocker.patch.object( + service, + "stop_transfer_family_server", + return_value=response("Server stopped"), + ) + + resp = service.handle_scan_message(server_id=server_id, message=message) + + mock_stop.assert_called_once_with(server_id) + assert extract_message(resp) == "Server stopped" + +def test_is_relevant_scan_result_true_for_expected_values(service): + for value in EXPECTED_SCAN_RESULTS: + assert service.is_relevant_scan_result(value) is True + + +def test_is_relevant_scan_result_false_for_other_values(service): + assert service.is_relevant_scan_result("CLEAN") is False + assert service.is_relevant_scan_result("") is False + assert service.is_relevant_scan_result(None) is False + +def test_has_required_fields_true_when_bucket_and_key_present(service): + assert service.has_required_fields("bucket", "key") is True + + +def test_has_required_fields_false_when_bucket_or_key_missing(service): + assert service.has_required_fields("", "key") is False + assert service.has_required_fields("bucket", "") is False + assert service.has_required_fields(None, "key") is False + assert service.has_required_fields("bucket", None) is False + +def test_is_quarantine_expedite_true_for_valid_quarantine_key(service): + bucket = "cloudstoragesecquarantine-xyz" + key = "pre-prod-staging-bulk-store/expedite/path/file.pdf" + + assert service.is_quarantine_expedite(bucket, key) is True + + +def test_is_quarantine_expedite_false_for_non_quarantine_bucket(service): + bucket = "some-other-bucket" + key = "pre-prod-staging-bulk-store/expedite/path/file.pdf" + + assert service.is_quarantine_expedite(bucket, key) is False + + +def test_is_quarantine_expedite_false_if_staging_bucket_not_set( + mock_transfer_client, monkeypatch +): + monkeypatch.delenv("STAGING_STORE_BUCKET_NAME", raising=False) + monkeypatch.setenv("WORKSPACE", "pre-prod") + + service = ExpediteKillSwitchService() + + bucket = "cloudstoragesecquarantine-xyz" + key = "pre-prod-staging-bulk-store/expedite/path/file.pdf" + + assert service.is_quarantine_expedite(bucket, key) is False + + +def test_extract_sns_message_parses_valid_event(service, sns_event): + msg = service.extract_sns_message(sns_event) + + assert isinstance(msg, dict) + assert msg["scanResult"] == "Infected" + assert msg["bucket"].startswith("cloudstoragesecquarantine-") + assert msg["key"].startswith("pre-prod-staging-bulk-store/expedite/") + + +def test_extract_sns_message_returns_none_for_invalid_shapes(service): + assert service.extract_sns_message({}) is None + assert service.extract_sns_message({"Records": []}) is None + assert service.extract_sns_message({"Records": [{}]}) is None + assert service.extract_sns_message({"Records": [{"Sns": {}}]}) is None + +def test_stop_transfer_family_server_happy_path_stops_server( + service, mock_transfer_client +): + mock_transfer_client.describe_server.return_value = {"Server": {"State": "ONLINE"}} + + resp = service.stop_transfer_family_server("srv-abc") + + mock_transfer_client.describe_server.assert_called_once_with(ServerId="srv-abc") + mock_transfer_client.stop_server.assert_called_once_with(ServerId="srv-abc") + assert extract_message(resp) == "Server srv-abc stopped" + + +def test_stop_transfer_family_server_returns_not_found_if_server_missing( + service, mock_transfer_client +): + NotFound = mock_transfer_client.exceptions.ResourceNotFoundException + mock_transfer_client.describe_server.side_effect = NotFound() + + resp = service.stop_transfer_family_server("srv-missing") + + mock_transfer_client.stop_server.assert_not_called() + assert extract_message(resp) == "Server not found" + + +def test_stop_transfer_family_server_handles_generic_exception( + service, mock_transfer_client +): + mock_transfer_client.describe_server.side_effect = Exception("boom") + + resp = service.stop_transfer_family_server("srv-error") + + mock_transfer_client.stop_server.assert_not_called() + assert extract_message(resp) == "Failed to stop server" + +def test_handle_scan_message_ignores_irrelevant_scan_result(service, mocker): + message = { + "scanResult": "Clean", + "bucket": "cloudstoragesecquarantine-abc", + "key": "pre-prod-staging-bulk-store/expedite/folder/file.pdf", + } + + mock_stop = mocker.patch.object( + service, "stop_transfer_family_server", autospec=True + ) + + resp = service.handle_scan_message(server_id="srv-abc", message=message) + + assert extract_message(resp) == "Scan result not relevant, no action taken" + mock_stop.assert_not_called() + + +def test_handle_scan_message_returns_invalid_payload_when_bucket_missing(service, mocker): + message = { + "scanResult": "Infected", + "key": "pre-prod-staging-bulk-store/expedite/folder/file.pdf", + } + + mock_stop = mocker.patch.object( + service, "stop_transfer_family_server", autospec=True + ) + + resp = service.handle_scan_message(server_id="srv-abc", message=message) + + assert extract_message(resp) == "Invalid payload (missing bucket/key)" + mock_stop.assert_not_called() + + +def test_handle_scan_message_returns_invalid_payload_when_key_missing(service, mocker): + message = { + "scanResult": "Infected", + "bucket": "cloudstoragesecquarantine-abc", + } + + mock_stop = mocker.patch.object( + service, "stop_transfer_family_server", autospec=True + ) + + resp = service.handle_scan_message(server_id="srv-abc", message=message) + + assert extract_message(resp) == "Invalid payload (missing bucket/key)" + mock_stop.assert_not_called() + + +def test_handle_scan_message_not_quarantine_expedite(service, mocker): + message = { + "scanResult": "Infected", + "bucket": "some-other-bucket", + "key": "pre-prod-staging-bulk-store/expedite/folder/file.pdf", + } + + mock_stop = mocker.patch.object( + service, "stop_transfer_family_server", autospec=True + ) + + resp = service.handle_scan_message(server_id="srv-abc", message=message) + + assert extract_message(resp) == "Not an expedite file, no action taken" + mock_stop.assert_not_called() + + +def test_handle_scan_message_non_infected_expedite(service, mocker): + message = { + "scanResult": "Error", + "bucket": "cloudstoragesecquarantine-abc", + "key": "pre-prod-staging-bulk-store/expedite/folder/file.pdf", + } + + mock_stop = mocker.patch.object( + service, "stop_transfer_family_server", autospec=True + ) + + resp = service.handle_scan_message(server_id="srv-abc", message=message) + + assert ( + extract_message(resp) + == "Non-infected result for expedite file, no kill switch action" + ) + mock_stop.assert_not_called() + +def test_extract_sns_message_returns_none_on_invalid_json(service): + event = { + "Records": [ + { + "Sns": { + "Message": "not-json-at-all", + } + } + ] + } + + msg = service.extract_sns_message(event) + + assert msg is None + +def test_stop_transfer_family_server_handles_metric_failure( + service, mock_transfer_client, mocker +): + mock_transfer_client.describe_server.return_value = {"Server": {"State": "ONLINE"}} + + mocker.patch.object( + service, + "report_kill_switch_activated", + side_effect=Exception("metric failed"), + ) + + resp = service.stop_transfer_family_server("srv-xyz") + + mock_transfer_client.describe_server.assert_called_once_with(ServerId="srv-xyz") + mock_transfer_client.stop_server.assert_called_once_with(ServerId="srv-xyz") + assert ( + extract_message(resp) + == "Server srv-xyz stopped, but failed to alert the team" + )