diff --git a/lambdas/services/expedite_transfer_family_kill_switch_service.py b/lambdas/services/expedite_transfer_family_kill_switch_service.py index b3998d2bd..672156e73 100644 --- a/lambdas/services/expedite_transfer_family_kill_switch_service.py +++ b/lambdas/services/expedite_transfer_family_kill_switch_service.py @@ -8,6 +8,8 @@ logger = LoggingService(__name__) EXPECTED_SCAN_RESULTS = {"Infected", "Error", "Unscannable", "Suspicious"} +STOP_WORTHY_SCAN_RESULTS = {"Infected", "Unscannable", "Suspicious"} + def response(message: str): @@ -29,29 +31,22 @@ def handle_sns_event(self, event: dict): server_id = self.get_transfer_server_id() if not server_id: - logger.warning( - "No Transfer Family server ID resolved from AWS – kill switch disabled." - ) - return { - "statusCode": 200, - "body": json.dumps( - { - "message": ( - "Transfer family kill switch disabled – no Transfer server ID discovered" - ) - } - ), - } + logger.warning("No Transfer Family server ID resolved from AWS – kill switch disabled.") + return response("Transfer family kill switch disabled – no Transfer server ID discovered") - logger.warning( - "Initiating Transfer Family shutdown.", - { - "server_id": server_id, - "workspace": self.workspace, - }, - ) + message = self.extract_sns_message(event) + if not message: + logger.error("Unable to parse SNS message JSON; not taking action") + return response("Invalid SNS message; no action taken") - return self.stop_transfer_family_server(server_id) + scan_result = message.get("scanResult") + + if scan_result in STOP_WORTHY_SCAN_RESULTS: + logger.warning("Stopping Transfer Family due to scan result", {"scanResult": scan_result}) + return self.stop_transfer_family_server(server_id) + + logger.info("Scan result not actionable; no kill switch action", {"scanResult": scan_result}) + return response("No action taken") def handle_scan_message(self, server_id: str, message: dict): scan_result = message.get("scanResult") diff --git a/lambdas/tests/unit/services/test_expedite_kill_switch_service.py b/lambdas/tests/unit/services/test_expedite_kill_switch_service.py index bed4631ea..c68011d24 100644 --- a/lambdas/tests/unit/services/test_expedite_kill_switch_service.py +++ b/lambdas/tests/unit/services/test_expedite_kill_switch_service.py @@ -4,6 +4,7 @@ from services.expedite_transfer_family_kill_switch_service import ( ExpediteKillSwitchService, EXPECTED_SCAN_RESULTS, + STOP_WORTHY_SCAN_RESULTS, response, ) @@ -12,6 +13,7 @@ def mock_transfer_client(mocker): transfer_client = mocker.Mock() cloudwatch_client = mocker.Mock() + class ResourceNotFoundException(Exception): pass @@ -61,6 +63,7 @@ def sns_event(): def extract_message(resp): return json.loads(resp["body"])["message"] + def test_response_builds_expected_http_shape(): msg = "hello world" resp = response(msg) @@ -71,7 +74,7 @@ def test_response_builds_expected_http_shape(): def test_handle_sns_event_happy_path_infected_expedite( - service, sns_event, mock_transfer_client + service, sns_event, mock_transfer_client ): mock_transfer_client.list_servers.return_value = { "Servers": [{"ServerId": "srv-12345"}] @@ -87,22 +90,22 @@ def test_handle_sns_event_happy_path_infected_expedite( def test_handle_sns_event_no_servers_disables_kill_switch( - service, sns_event, mock_transfer_client + service, sns_event, mock_transfer_client ): mock_transfer_client.list_servers.return_value = {"Servers": []} resp = service.handle_sns_event(sns_event) assert ( - extract_message(resp) - == "Transfer family kill switch disabled – no Transfer server ID discovered" + extract_message(resp) + == "Transfer family kill switch disabled – no Transfer server ID discovered" ) mock_transfer_client.describe_server.assert_not_called() mock_transfer_client.stop_server.assert_not_called() def test_get_transfer_server_id_happy_path_reads_from_list_servers( - service, mock_transfer_client + service, mock_transfer_client ): mock_transfer_client.list_servers.return_value = { "Servers": [{"ServerId": " srv-9999 "}] @@ -115,7 +118,7 @@ def test_get_transfer_server_id_happy_path_reads_from_list_servers( def test_get_transfer_server_id_returns_empty_when_no_servers( - service, mock_transfer_client + service, mock_transfer_client ): mock_transfer_client.list_servers.return_value = {"Servers": []} @@ -125,7 +128,7 @@ def test_get_transfer_server_id_returns_empty_when_no_servers( def test_get_transfer_server_id_returns_empty_on_generic_error( - service, mock_transfer_client + service, mock_transfer_client ): mock_transfer_client.list_servers.side_effect = Exception("boom") @@ -133,8 +136,9 @@ def test_get_transfer_server_id_returns_empty_on_generic_error( assert server_id == "" + def test_handle_scan_message_calls_stop_server_for_infected_expedite( - service, sns_event, mocker + service, sns_event, mocker ): message = json.loads(sns_event["Records"][0]["Sns"]["Message"]) server_id = "srv-abc" @@ -150,6 +154,7 @@ def test_handle_scan_message_calls_stop_server_for_infected_expedite( mock_stop.assert_called_once_with(server_id) assert extract_message(resp) == "Server stopped" + def test_is_relevant_scan_result_true_for_expected_values(service): for value in EXPECTED_SCAN_RESULTS: assert service.is_relevant_scan_result(value) is True @@ -160,6 +165,7 @@ def test_is_relevant_scan_result_false_for_other_values(service): assert service.is_relevant_scan_result("") is False assert service.is_relevant_scan_result(None) is False + def test_has_required_fields_true_when_bucket_and_key_present(service): assert service.has_required_fields("bucket", "key") is True @@ -170,6 +176,7 @@ def test_has_required_fields_false_when_bucket_or_key_missing(service): assert service.has_required_fields(None, "key") is False assert service.has_required_fields("bucket", None) is False + def test_is_quarantine_expedite_true_for_valid_quarantine_key(service): bucket = "cloudstoragesecquarantine-xyz" key = "pre-prod-staging-bulk-store/expedite/path/file.pdf" @@ -185,7 +192,7 @@ def test_is_quarantine_expedite_false_for_non_quarantine_bucket(service): def test_is_quarantine_expedite_false_if_staging_bucket_not_set( - mock_transfer_client, monkeypatch + mock_transfer_client, monkeypatch ): monkeypatch.delenv("STAGING_STORE_BUCKET_NAME", raising=False) monkeypatch.setenv("WORKSPACE", "pre-prod") @@ -213,8 +220,9 @@ def test_extract_sns_message_returns_none_for_invalid_shapes(service): assert service.extract_sns_message({"Records": [{}]}) is None assert service.extract_sns_message({"Records": [{"Sns": {}}]}) is None + def test_stop_transfer_family_server_happy_path_stops_server( - service, mock_transfer_client + service, mock_transfer_client ): mock_transfer_client.describe_server.return_value = {"Server": {"State": "ONLINE"}} @@ -226,7 +234,7 @@ def test_stop_transfer_family_server_happy_path_stops_server( def test_stop_transfer_family_server_returns_not_found_if_server_missing( - service, mock_transfer_client + service, mock_transfer_client ): NotFound = mock_transfer_client.exceptions.ResourceNotFoundException mock_transfer_client.describe_server.side_effect = NotFound() @@ -238,7 +246,7 @@ def test_stop_transfer_family_server_returns_not_found_if_server_missing( def test_stop_transfer_family_server_handles_generic_exception( - service, mock_transfer_client + service, mock_transfer_client ): mock_transfer_client.describe_server.side_effect = Exception("boom") @@ -247,6 +255,7 @@ def test_stop_transfer_family_server_handles_generic_exception( mock_transfer_client.stop_server.assert_not_called() assert extract_message(resp) == "Failed to stop server" + def test_handle_scan_message_ignores_irrelevant_scan_result(service, mocker): message = { "scanResult": "Clean", @@ -327,11 +336,12 @@ def test_handle_scan_message_non_infected_expedite(service, mocker): resp = service.handle_scan_message(server_id="srv-abc", message=message) assert ( - extract_message(resp) - == "Non-infected result for expedite file, no kill switch action" + extract_message(resp) + == "Non-infected result for expedite file, no kill switch action" ) mock_stop.assert_not_called() + def test_extract_sns_message_returns_none_on_invalid_json(service): event = { "Records": [ @@ -347,8 +357,9 @@ def test_extract_sns_message_returns_none_on_invalid_json(service): assert msg is None + def test_stop_transfer_family_server_handles_metric_failure( - service, mock_transfer_client, mocker + service, mock_transfer_client, mocker ): mock_transfer_client.describe_server.return_value = {"Server": {"State": "ONLINE"}} @@ -362,7 +373,155 @@ def test_stop_transfer_family_server_handles_metric_failure( mock_transfer_client.describe_server.assert_called_once_with(ServerId="srv-xyz") mock_transfer_client.stop_server.assert_called_once_with(ServerId="srv-xyz") + assert ( + extract_message(resp) + == "Server srv-xyz stopped, but failed to alert the team" + ) + +def test_handle_sns_event_stops_for_each_stop_worthy_scan_result( + service, sns_event, mock_transfer_client, monkeypatch +): + mock_transfer_client.list_servers.return_value = { + "Servers": [{"ServerId": "srv-12345"}] + } + mock_transfer_client.describe_server.return_value = {"Server": {"State": "ONLINE"}} + + for scan_result in STOP_WORTHY_SCAN_RESULTS: + message = json.loads(sns_event["Records"][0]["Sns"]["Message"]) + message["scanResult"] = scan_result + + event = { + "Records": [ + { + "Sns": { + "Message": json.dumps(message), + } + } + ] + } + + resp = service.handle_sns_event(event) + + assert extract_message(resp) == "Server srv-12345 stopped" + + assert mock_transfer_client.stop_server.call_count == len(STOP_WORTHY_SCAN_RESULTS) + + +def test_handle_sns_event_does_not_stop_for_error_scan_result( + service, sns_event, mock_transfer_client +): + mock_transfer_client.list_servers.return_value = { + "Servers": [{"ServerId": "srv-12345"}] + } + + message = json.loads(sns_event["Records"][0]["Sns"]["Message"]) + message["scanResult"] = "Error" + event = { + "Records": [ + { + "Sns": { + "Message": json.dumps(message), + } + } + ] + } + + resp = service.handle_sns_event(event) + + assert extract_message(resp) == "No action taken" + mock_transfer_client.describe_server.assert_not_called() + mock_transfer_client.stop_server.assert_not_called() + + +def test_handle_sns_event_does_not_stop_for_clean_or_unknown_scan_result( + service, sns_event, mock_transfer_client +): + mock_transfer_client.list_servers.return_value = { + "Servers": [{"ServerId": "srv-12345"}] + } + + for scan_result in ["Clean", "CLEAN", "Unknown", "NoThreatsFound"]: + message = json.loads(sns_event["Records"][0]["Sns"]["Message"]) + message["scanResult"] = scan_result + + event = { + "Records": [ + { + "Sns": { + "Message": json.dumps(message), + } + } + ] + } + + resp = service.handle_sns_event(event) + assert extract_message(resp) == "No action taken" + + mock_transfer_client.describe_server.assert_not_called() + mock_transfer_client.stop_server.assert_not_called() + + +def test_handle_sns_event_returns_no_action_when_scan_result_missing( + service, sns_event, mock_transfer_client +): + mock_transfer_client.list_servers.return_value = { + "Servers": [{"ServerId": "srv-12345"}] + } + + message = json.loads(sns_event["Records"][0]["Sns"]["Message"]) + message.pop("scanResult", None) + + event = { + "Records": [ + { + "Sns": { + "Message": json.dumps(message), + } + } + ] + } + + resp = service.handle_sns_event(event) + + assert extract_message(resp) == "No action taken" + mock_transfer_client.describe_server.assert_not_called() + mock_transfer_client.stop_server.assert_not_called() + + +def test_handle_sns_event_returns_invalid_sns_message_on_bad_json( + service, mock_transfer_client +): + mock_transfer_client.list_servers.return_value = { + "Servers": [{"ServerId": "srv-12345"}] + } + + event = { + "Records": [ + { + "Sns": { + "Message": "not-json-at-all", + } + } + ] + } + + resp = service.handle_sns_event(event) + + assert extract_message(resp) == "Invalid SNS message; no action taken" + mock_transfer_client.describe_server.assert_not_called() + mock_transfer_client.stop_server.assert_not_called() + + +def test_handle_sns_event_no_servers_message_contract( + service, sns_event, mock_transfer_client +): + mock_transfer_client.list_servers.return_value = {"Servers": []} + + resp = service.handle_sns_event(sns_event) + assert ( extract_message(resp) - == "Server srv-xyz stopped, but failed to alert the team" + == "Transfer family kill switch disabled – no Transfer server ID discovered" ) + mock_transfer_client.describe_server.assert_not_called() + mock_transfer_client.stop_server.assert_not_called()