Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 16 additions & 21 deletions lambdas/services/expedite_transfer_family_kill_switch_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
logger = LoggingService(__name__)

EXPECTED_SCAN_RESULTS = {"Infected", "Error", "Unscannable", "Suspicious"}
STOP_WORTHY_SCAN_RESULTS = {"Infected", "Unscannable", "Suspicious"}



def response(message: str):
Expand All @@ -29,29 +31,22 @@ def handle_sns_event(self, event: dict):

server_id = self.get_transfer_server_id()
if not server_id:
logger.warning(
"No Transfer Family server ID resolved from AWS – kill switch disabled."
)
return {
"statusCode": 200,
"body": json.dumps(
{
"message": (
"Transfer family kill switch disabled – no Transfer server ID discovered"
)
}
),
}
logger.warning("No Transfer Family server ID resolved from AWS – kill switch disabled.")
return response("Transfer family kill switch disabled – no Transfer server ID discovered")

logger.warning(
"Initiating Transfer Family shutdown.",
{
"server_id": server_id,
"workspace": self.workspace,
},
)
message = self.extract_sns_message(event)
if not message:
logger.error("Unable to parse SNS message JSON; not taking action")
return response("Invalid SNS message; no action taken")

return self.stop_transfer_family_server(server_id)
scan_result = message.get("scanResult")

if scan_result in STOP_WORTHY_SCAN_RESULTS:
logger.warning("Stopping Transfer Family due to scan result", {"scanResult": scan_result})
return self.stop_transfer_family_server(server_id)

logger.info("Scan result not actionable; no kill switch action", {"scanResult": scan_result})
return response("No action taken")

def handle_scan_message(self, server_id: str, message: dict):
scan_result = message.get("scanResult")
Expand Down
191 changes: 175 additions & 16 deletions lambdas/tests/unit/services/test_expedite_kill_switch_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from services.expedite_transfer_family_kill_switch_service import (
ExpediteKillSwitchService,
EXPECTED_SCAN_RESULTS,
STOP_WORTHY_SCAN_RESULTS,
response,
)

Expand All @@ -12,6 +13,7 @@
def mock_transfer_client(mocker):
transfer_client = mocker.Mock()
cloudwatch_client = mocker.Mock()

class ResourceNotFoundException(Exception):
pass

Expand Down Expand Up @@ -61,6 +63,7 @@ def sns_event():
def extract_message(resp):
return json.loads(resp["body"])["message"]


def test_response_builds_expected_http_shape():
msg = "hello world"
resp = response(msg)
Expand All @@ -71,7 +74,7 @@ def test_response_builds_expected_http_shape():


def test_handle_sns_event_happy_path_infected_expedite(
service, sns_event, mock_transfer_client
service, sns_event, mock_transfer_client
):
mock_transfer_client.list_servers.return_value = {
"Servers": [{"ServerId": "srv-12345"}]
Expand All @@ -87,22 +90,22 @@ def test_handle_sns_event_happy_path_infected_expedite(


def test_handle_sns_event_no_servers_disables_kill_switch(
service, sns_event, mock_transfer_client
service, sns_event, mock_transfer_client
):
mock_transfer_client.list_servers.return_value = {"Servers": []}

resp = service.handle_sns_event(sns_event)

assert (
extract_message(resp)
== "Transfer family kill switch disabled – no Transfer server ID discovered"
extract_message(resp)
== "Transfer family kill switch disabled – no Transfer server ID discovered"
)
mock_transfer_client.describe_server.assert_not_called()
mock_transfer_client.stop_server.assert_not_called()


def test_get_transfer_server_id_happy_path_reads_from_list_servers(
service, mock_transfer_client
service, mock_transfer_client
):
mock_transfer_client.list_servers.return_value = {
"Servers": [{"ServerId": " srv-9999 "}]
Expand All @@ -115,7 +118,7 @@ def test_get_transfer_server_id_happy_path_reads_from_list_servers(


def test_get_transfer_server_id_returns_empty_when_no_servers(
service, mock_transfer_client
service, mock_transfer_client
):
mock_transfer_client.list_servers.return_value = {"Servers": []}

Expand All @@ -125,16 +128,17 @@ def test_get_transfer_server_id_returns_empty_when_no_servers(


def test_get_transfer_server_id_returns_empty_on_generic_error(
service, mock_transfer_client
service, mock_transfer_client
):
mock_transfer_client.list_servers.side_effect = Exception("boom")

server_id = service.get_transfer_server_id()

assert server_id == ""


def test_handle_scan_message_calls_stop_server_for_infected_expedite(
service, sns_event, mocker
service, sns_event, mocker
):
message = json.loads(sns_event["Records"][0]["Sns"]["Message"])
server_id = "srv-abc"
Expand All @@ -150,6 +154,7 @@ def test_handle_scan_message_calls_stop_server_for_infected_expedite(
mock_stop.assert_called_once_with(server_id)
assert extract_message(resp) == "Server stopped"


def test_is_relevant_scan_result_true_for_expected_values(service):
for value in EXPECTED_SCAN_RESULTS:
assert service.is_relevant_scan_result(value) is True
Expand All @@ -160,6 +165,7 @@ def test_is_relevant_scan_result_false_for_other_values(service):
assert service.is_relevant_scan_result("") is False
assert service.is_relevant_scan_result(None) is False


def test_has_required_fields_true_when_bucket_and_key_present(service):
assert service.has_required_fields("bucket", "key") is True

Expand All @@ -170,6 +176,7 @@ def test_has_required_fields_false_when_bucket_or_key_missing(service):
assert service.has_required_fields(None, "key") is False
assert service.has_required_fields("bucket", None) is False


def test_is_quarantine_expedite_true_for_valid_quarantine_key(service):
bucket = "cloudstoragesecquarantine-xyz"
key = "pre-prod-staging-bulk-store/expedite/path/file.pdf"
Expand All @@ -185,7 +192,7 @@ def test_is_quarantine_expedite_false_for_non_quarantine_bucket(service):


def test_is_quarantine_expedite_false_if_staging_bucket_not_set(
mock_transfer_client, monkeypatch
mock_transfer_client, monkeypatch
):
monkeypatch.delenv("STAGING_STORE_BUCKET_NAME", raising=False)
monkeypatch.setenv("WORKSPACE", "pre-prod")
Expand Down Expand Up @@ -213,8 +220,9 @@ def test_extract_sns_message_returns_none_for_invalid_shapes(service):
assert service.extract_sns_message({"Records": [{}]}) is None
assert service.extract_sns_message({"Records": [{"Sns": {}}]}) is None


def test_stop_transfer_family_server_happy_path_stops_server(
service, mock_transfer_client
service, mock_transfer_client
):
mock_transfer_client.describe_server.return_value = {"Server": {"State": "ONLINE"}}

Expand All @@ -226,7 +234,7 @@ def test_stop_transfer_family_server_happy_path_stops_server(


def test_stop_transfer_family_server_returns_not_found_if_server_missing(
service, mock_transfer_client
service, mock_transfer_client
):
NotFound = mock_transfer_client.exceptions.ResourceNotFoundException
mock_transfer_client.describe_server.side_effect = NotFound()
Expand All @@ -238,7 +246,7 @@ def test_stop_transfer_family_server_returns_not_found_if_server_missing(


def test_stop_transfer_family_server_handles_generic_exception(
service, mock_transfer_client
service, mock_transfer_client
):
mock_transfer_client.describe_server.side_effect = Exception("boom")

Expand All @@ -247,6 +255,7 @@ def test_stop_transfer_family_server_handles_generic_exception(
mock_transfer_client.stop_server.assert_not_called()
assert extract_message(resp) == "Failed to stop server"


def test_handle_scan_message_ignores_irrelevant_scan_result(service, mocker):
message = {
"scanResult": "Clean",
Expand Down Expand Up @@ -327,11 +336,12 @@ def test_handle_scan_message_non_infected_expedite(service, mocker):
resp = service.handle_scan_message(server_id="srv-abc", message=message)

assert (
extract_message(resp)
== "Non-infected result for expedite file, no kill switch action"
extract_message(resp)
== "Non-infected result for expedite file, no kill switch action"
)
mock_stop.assert_not_called()


def test_extract_sns_message_returns_none_on_invalid_json(service):
event = {
"Records": [
Expand All @@ -347,8 +357,9 @@ def test_extract_sns_message_returns_none_on_invalid_json(service):

assert msg is None


def test_stop_transfer_family_server_handles_metric_failure(
service, mock_transfer_client, mocker
service, mock_transfer_client, mocker
):
mock_transfer_client.describe_server.return_value = {"Server": {"State": "ONLINE"}}

Expand All @@ -362,7 +373,155 @@ def test_stop_transfer_family_server_handles_metric_failure(

mock_transfer_client.describe_server.assert_called_once_with(ServerId="srv-xyz")
mock_transfer_client.stop_server.assert_called_once_with(ServerId="srv-xyz")
assert (
extract_message(resp)
== "Server srv-xyz stopped, but failed to alert the team"
)

def test_handle_sns_event_stops_for_each_stop_worthy_scan_result(
service, sns_event, mock_transfer_client, monkeypatch
):
mock_transfer_client.list_servers.return_value = {
"Servers": [{"ServerId": "srv-12345"}]
}
mock_transfer_client.describe_server.return_value = {"Server": {"State": "ONLINE"}}

for scan_result in STOP_WORTHY_SCAN_RESULTS:
message = json.loads(sns_event["Records"][0]["Sns"]["Message"])
message["scanResult"] = scan_result

event = {
"Records": [
{
"Sns": {
"Message": json.dumps(message),
}
}
]
}

resp = service.handle_sns_event(event)

assert extract_message(resp) == "Server srv-12345 stopped"

assert mock_transfer_client.stop_server.call_count == len(STOP_WORTHY_SCAN_RESULTS)


def test_handle_sns_event_does_not_stop_for_error_scan_result(
service, sns_event, mock_transfer_client
):
mock_transfer_client.list_servers.return_value = {
"Servers": [{"ServerId": "srv-12345"}]
}

message = json.loads(sns_event["Records"][0]["Sns"]["Message"])
message["scanResult"] = "Error"
event = {
"Records": [
{
"Sns": {
"Message": json.dumps(message),
}
}
]
}

resp = service.handle_sns_event(event)

assert extract_message(resp) == "No action taken"
mock_transfer_client.describe_server.assert_not_called()
mock_transfer_client.stop_server.assert_not_called()


def test_handle_sns_event_does_not_stop_for_clean_or_unknown_scan_result(
service, sns_event, mock_transfer_client
):
mock_transfer_client.list_servers.return_value = {
"Servers": [{"ServerId": "srv-12345"}]
}

for scan_result in ["Clean", "CLEAN", "Unknown", "NoThreatsFound"]:
message = json.loads(sns_event["Records"][0]["Sns"]["Message"])
message["scanResult"] = scan_result

event = {
"Records": [
{
"Sns": {
"Message": json.dumps(message),
}
}
]
}

resp = service.handle_sns_event(event)
assert extract_message(resp) == "No action taken"

mock_transfer_client.describe_server.assert_not_called()
mock_transfer_client.stop_server.assert_not_called()


def test_handle_sns_event_returns_no_action_when_scan_result_missing(
service, sns_event, mock_transfer_client
):
mock_transfer_client.list_servers.return_value = {
"Servers": [{"ServerId": "srv-12345"}]
}

message = json.loads(sns_event["Records"][0]["Sns"]["Message"])
message.pop("scanResult", None)

event = {
"Records": [
{
"Sns": {
"Message": json.dumps(message),
}
}
]
}

resp = service.handle_sns_event(event)

assert extract_message(resp) == "No action taken"
mock_transfer_client.describe_server.assert_not_called()
mock_transfer_client.stop_server.assert_not_called()


def test_handle_sns_event_returns_invalid_sns_message_on_bad_json(
service, mock_transfer_client
):
mock_transfer_client.list_servers.return_value = {
"Servers": [{"ServerId": "srv-12345"}]
}

event = {
"Records": [
{
"Sns": {
"Message": "not-json-at-all",
}
}
]
}

resp = service.handle_sns_event(event)

assert extract_message(resp) == "Invalid SNS message; no action taken"
mock_transfer_client.describe_server.assert_not_called()
mock_transfer_client.stop_server.assert_not_called()


def test_handle_sns_event_no_servers_message_contract(
service, sns_event, mock_transfer_client
):
mock_transfer_client.list_servers.return_value = {"Servers": []}

resp = service.handle_sns_event(sns_event)

assert (
extract_message(resp)
== "Server srv-xyz stopped, but failed to alert the team"
== "Transfer family kill switch disabled – no Transfer server ID discovered"
)
mock_transfer_client.describe_server.assert_not_called()
mock_transfer_client.stop_server.assert_not_called()
Loading