Skip to content

Commit 286402e

Browse files
[PRMP 862] Implement an AWS Transfer Family kill switch (#901)
1 parent ce7a52c commit 286402e

File tree

6 files changed

+658
-1
lines changed

6 files changed

+658
-1
lines changed

.github/workflows/base-lambdas-reusable-deploy-all.yml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,20 @@ jobs:
697697
secrets:
698698
AWS_ASSUME_ROLE: ${{ secrets.AWS_ASSUME_ROLE }}
699699

700+
deploy_transfer_kill_switch_lambda:
701+
name: Deploy transfer kill switch lambda
702+
uses: ./.github/workflows/base-lambdas-reusable-deploy.yml
703+
with:
704+
environment: ${{ inputs.environment }}
705+
python_version: ${{ inputs.python_version }}
706+
build_branch: ${{ inputs.build_branch }}
707+
sandbox: ${{ inputs.sandbox }}
708+
lambda_handler_name: transfer_family_kill_switch_handler
709+
lambda_aws_name: TransferFamilyKillSwitch
710+
lambda_layer_names: "core_lambda_layer"
711+
secrets:
712+
AWS_ASSUME_ROLE: ${{ secrets.AWS_ASSUME_ROLE }}
713+
700714
deploy_search_document_review_lambda:
701715
name: Deploy Search Document Review
702716
uses: ./.github/workflows/base-lambdas-reusable-deploy.yml
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from services.expedite_transfer_family_kill_switch_service import ExpediteKillSwitchService
2+
from utils.decorators.handle_lambda_exceptions import handle_lambda_exceptions
3+
from utils.decorators.set_audit_arg import set_request_context_for_logging
4+
5+
6+
@handle_lambda_exceptions
7+
@set_request_context_for_logging
8+
def lambda_handler(event, context):
9+
service = ExpediteKillSwitchService()
10+
return service.handle_sns_event(event)
Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
import json
2+
import os
3+
4+
import boto3
5+
6+
from utils.audit_logging_setup import LoggingService
7+
8+
logger = LoggingService(__name__)
9+
10+
EXPECTED_SCAN_RESULTS = {"Infected", "Error", "Unscannable", "Suspicious"}
11+
12+
13+
def response(message: str):
14+
return {
15+
"statusCode": 200,
16+
"body": json.dumps({"message": message}),
17+
}
18+
19+
class ExpediteKillSwitchService:
20+
def __init__(self):
21+
self.transfer_client = boto3.client("transfer")
22+
self.cloudwatch = boto3.client("cloudwatch")
23+
24+
self.staging_bucket = os.environ.get("STAGING_STORE_BUCKET_NAME", "")
25+
self.workspace = os.environ.get("WORKSPACE", "")
26+
27+
def handle_sns_event(self, event: dict):
28+
logger.info("Received SNS virus scan notification event", {"event": event})
29+
30+
server_id = self.get_transfer_server_id()
31+
if not server_id:
32+
logger.warning(
33+
"No Transfer Family server ID resolved from AWS – kill switch disabled."
34+
)
35+
return {
36+
"statusCode": 200,
37+
"body": json.dumps(
38+
{
39+
"message": (
40+
"Transfer family kill switch disabled – no Transfer server ID discovered"
41+
)
42+
}
43+
),
44+
}
45+
46+
logger.warning(
47+
"Initiating Transfer Family shutdown.",
48+
{
49+
"server_id": server_id,
50+
"workspace": self.workspace,
51+
},
52+
)
53+
54+
return self.stop_transfer_family_server(server_id)
55+
56+
def handle_scan_message(self, server_id: str, message: dict):
57+
scan_result = message.get("scanResult")
58+
bucket = message.get("bucket")
59+
key = message.get("key")
60+
61+
if not self.is_relevant_scan_result(scan_result):
62+
logger.info(
63+
f"Ignoring scan result '{scan_result}' – not one of {EXPECTED_SCAN_RESULTS}"
64+
)
65+
return response("Scan result not relevant, no action taken")
66+
67+
if not self.has_required_fields(bucket, key):
68+
logger.error("SNS payload missing required 'bucket' or 'key' fields")
69+
return response("Invalid payload (missing bucket/key)")
70+
71+
if not self.is_quarantine_expedite(bucket, key):
72+
logger.info(
73+
"Scan notification is not for an expedite file – no kill switch action",
74+
{
75+
"bucket": bucket,
76+
"key": key,
77+
"staging_bucket": self.staging_bucket,
78+
"workspace": self.workspace,
79+
},
80+
)
81+
return response("Not an expedite file, no action taken")
82+
83+
if scan_result != "Infected":
84+
logger.warning(
85+
"Non-clean scan result for expedite file, but not 'Infected' – no kill switch action",
86+
{
87+
"scanResult": scan_result,
88+
"bucket": bucket,
89+
"key": key,
90+
"workspace": self.workspace,
91+
},
92+
)
93+
return response(
94+
"Non-infected result for expedite file, no kill switch action"
95+
)
96+
97+
logger.warning(
98+
"Initiating Transfer Family shutdown.",
99+
{
100+
"server_id": server_id,
101+
"bucket": bucket,
102+
"key": key,
103+
"scanResult": scan_result,
104+
"workspace": self.workspace,
105+
},
106+
)
107+
108+
return self.stop_transfer_family_server(server_id)
109+
110+
def is_relevant_scan_result(self, scan_result: str) -> bool:
111+
return scan_result in EXPECTED_SCAN_RESULTS
112+
113+
def has_required_fields(self, bucket: str, key: str) -> bool:
114+
return bool(bucket and key)
115+
116+
def is_quarantine_expedite(self, bucket: str, key: str) -> bool:
117+
"""
118+
Example quarantine:
119+
bucket = cloudstoragesecquarantine-...
120+
key = "pre-prod-staging-bulk-store/expedite/..."
121+
Where key starts with "<workspace>-staging-bulk-store/expedite/"
122+
"""
123+
if not self.staging_bucket:
124+
return False
125+
126+
quarantine_prefix = f"{self.staging_bucket}/expedite/"
127+
return (
128+
bucket.startswith("cloudstoragesecquarantine-")
129+
and key.startswith(quarantine_prefix)
130+
)
131+
132+
def get_transfer_server_id(self) -> str:
133+
"""
134+
Discover Transfer Family servers in this account/region and return
135+
the first ServerId, or "" if none exist or an error occurs.
136+
"""
137+
try:
138+
resp = self.transfer_client.list_servers(MaxResults=1)
139+
servers = resp.get("Servers", [])
140+
if not servers:
141+
logger.warning(
142+
"No AWS Transfer Family servers found in account/region "
143+
"– kill switch disabled."
144+
)
145+
return ""
146+
147+
server_id = servers[0]["ServerId"].strip()
148+
logger.info(
149+
"Resolved Transfer server ID via list_servers",
150+
{"server_id": server_id},
151+
)
152+
return server_id
153+
154+
except Exception as exc:
155+
logger.error(f"Failed to list Transfer Family servers: {exc}")
156+
return ""
157+
158+
def extract_sns_message(self, event):
159+
try:
160+
records = event.get("Records")
161+
if not records:
162+
return None
163+
164+
sns_record = records[0].get("Sns")
165+
if not sns_record:
166+
return None
167+
168+
raw_message = sns_record.get("Message")
169+
if not raw_message:
170+
return None
171+
172+
return json.loads(raw_message)
173+
174+
except Exception as exc:
175+
logger.error(f"Failed to parse SNS message: {exc}")
176+
return None
177+
178+
def stop_transfer_family_server(self, server_id: str):
179+
try:
180+
desc = self.transfer_client.describe_server(ServerId=server_id)
181+
logger.info(
182+
"Transfer Family server found",
183+
{"server_id": server_id, "state": desc["Server"]["State"]},
184+
)
185+
186+
self.transfer_client.stop_server(ServerId=server_id)
187+
logger.warning(
188+
f"Transfer Family server {server_id} STOPPED due to virus scan trigger"
189+
)
190+
try:
191+
self.report_kill_switch_activated(server_id=server_id)
192+
except Exception as metric_exc:
193+
logger.error(
194+
f"Failed to publish kill switch metric: {metric_exc},"
195+
f" leading to failing to inform that kill switch has been activated"
196+
)
197+
return response(
198+
f"Server {server_id} stopped, but failed to alert the team"
199+
)
200+
return response(f"Server {server_id} stopped")
201+
202+
except self.transfer_client.exceptions.ResourceNotFoundException:
203+
logger.error(f"Transfer Family server '{server_id}' not found")
204+
return response("Server not found")
205+
206+
except Exception as exc:
207+
logger.error(f"Failed to stop Transfer Family server: {exc}")
208+
return response("Failed to stop server")
209+
210+
def report_kill_switch_activated(self, server_id: str):
211+
try:
212+
self.cloudwatch.put_metric_data(
213+
Namespace="Custom/TransferFamilyKillSwitch",
214+
MetricData=[
215+
{
216+
"MetricName": "ServerStopped",
217+
"Dimensions": [
218+
{"Name": "Workspace", "Value": self.workspace or "unknown"},
219+
],
220+
"Value": 1.0,
221+
"Unit": "Count",
222+
}
223+
],
224+
)
225+
except Exception as metric_exc:
226+
logger.error(
227+
f"Failed to publish kill switch metric: {metric_exc},"
228+
f" leading to failing to inform that kill switch has been activated"
229+
)
230+
231+
logger.warning(
232+
f"Transfer Family server {server_id} STOPPED due to infected expedite upload"
233+
)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pytest
2+
from handlers.transfer_family_kill_switch_handler import lambda_handler
3+
4+
5+
@pytest.fixture
6+
def mock_service(mocker):
7+
service_instance = mocker.Mock()
8+
mocker.patch(
9+
"handlers.transfer_family_kill_switch_handler.ExpediteKillSwitchService",
10+
return_value=service_instance,
11+
)
12+
return service_instance
13+
14+
15+
@pytest.fixture
16+
def context(mocker):
17+
context = mocker.Mock()
18+
context.aws_request_id = "test-request-id"
19+
return context
20+
21+
22+
def test_lambda_handler_delegates_to_service_handle_sns_event(mock_service, context):
23+
event = {"Records": []}
24+
expected_response = {"statusCode": 200, "body": '{"message": "ok"}'}
25+
26+
mock_service.handle_sns_event.return_value = expected_response
27+
28+
resp = lambda_handler(event, context)
29+
30+
mock_service.handle_sns_event.assert_called_once_with(event)
31+
assert resp == expected_response

lambdas/tests/unit/services/test_bulk_upload_metadata_processor_service.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88

99
import pytest
1010
from botocore.exceptions import ClientError
11+
from freezegun import freeze_time
12+
1113
from enums.upload_status import UploadStatus
1214
from enums.virus_scan_result import VirusScanResult
13-
from freezegun import freeze_time
1415
from models.staging_metadata import (
1516
METADATA_FILENAME,
1617
BulkUploadQueueMetadata,

0 commit comments

Comments
 (0)