Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion cftemplates/snapshots_tool_rds_source.json
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@
"Default": "NO",
"Description": "Set to the region where your RDS instances run, only if such region does not support Step Functions. Leave as NO otherwise"
},
"BackupKmsArn": {
"Type": "String",
"Default": "NO",
"AllowedPattern": "arn:aws:kms:*",
"Description": "Set ARN of kms key to copy snapshot with shared key (remember, destination account should be access it)"
},
"DeleteOldSnapshots": {
"Type": "String",
"Default": "TRUE",
Expand Down Expand Up @@ -341,7 +347,10 @@
},
"TAGGEDINSTANCE": {
"Ref": "TaggedInstance"
}
},
"BACKUP_KMS":{
"Ref": "BackupKmsArn"
}
}
},
"Role": {
Expand Down
45 changes: 45 additions & 0 deletions lambda/share_snapshots_rds/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
REGION = os.getenv('REGION_OVERRIDE').strip()
else:
REGION = os.getenv('AWS_DEFAULT_REGION')
BACKUP_KMS = os.getenv('BACKUP_KMS')

SUPPORTED_ENGINES = [ 'mariadb', 'sqlserver-se', 'sqlserver-ee', 'sqlserver-ex', 'sqlserver-web', 'mysql', 'oracle-se', 'oracle-se1', 'oracle-se2', 'oracle-ee', 'postgres' ]

Expand All @@ -38,6 +39,7 @@


def lambda_handler(event, context):
now = datetime.now()
pending_snapshots = 0
client = boto3.client('rds', region_name=REGION)
response = paginate_api_call(client, 'describe_db_snapshots', 'DBSnapshots', SnapshotType='manual')
Expand All @@ -50,8 +52,51 @@ def lambda_handler(event, context):
ResourceName=snapshot_arn)

if snapshot_object['Status'].lower() == 'available' and search_tag_shared(response_tags):

snapshot_info = client.describe_db_snapshots(
DBSnapshotIdentifier=snapshot_arn
)
timestamp_format = now.strftime('%Y-%m-%d-%H-%M')
targetSnapshot = snapshot_info['DBSnapshots'][0]['DBInstanceIdentifier'] + '-' + timestamp_format
logger.info('snapshot_info:{}'.format(snapshot_info))

if snapshot_info['DBSnapshots'][0]['Encrypted'] == True:
kms = get_kms_type(snapshot_info['DBSnapshots'][0]['KmsKeyId'],REGION)
else:
kms = False


logger.info('Checking Snapshot: {}'.format(snapshot_identifier))


if kms is True and BACKUP_KMS is not '':
try:
copy_status = client.copy_db_snapshot(
SourceDBSnapshotIdentifier=snapshot_arn,
TargetDBSnapshotIdentifier=targetSnapshot,
KmsKeyId=BACKUP_KMS,
CopyTags=True
)
pass
except Exception as e:
logger.error('Exception copy {}: {}'.format(snapshot_arn, e))
pending_snapshots += 1
pass
else:
modify_status = client.add_tags_to_resource(
ResourceName=snapshot_arn,
Tags=[
{
'Key': 'shareAndCopy',
'Value': 'No'
}
]
)


try:
# Share snapshot with dest_account
logger.info('Sharing snapshot: {}'.format(snapshot_identifier))
response_modify = client.modify_db_snapshot_attribute(
DBSnapshotIdentifier=snapshot_identifier,
AttributeName='restore',
Expand Down
17 changes: 17 additions & 0 deletions lambda/snapshots_tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,23 @@ class SnapshotToolException(Exception):
pass


def get_kms_type(kmskeyid,REGION):

keys = re.findall(r'([^\/]+$)',kmskeyid)
client = boto3.client('kms', region_name=REGION)

for key in keys:
response = client.describe_key(
KeyId=key
)
#print(response)
kms_owner = response['KeyMetadata']['KeyManager']

if kms_owner != 'AWS':

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replace this with return kms_owner == 'AWS' ?

return False
else:
return True

def search_tag_copydbsnapshot(response):
# Takes a list_tags_for_resource response and searches for our CopyDBSnapshot tag
try:
Expand Down
1 change: 0 additions & 1 deletion lambda/take_snapshots_rds/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def lambda_handler(event, context):
filtered_snapshots = get_own_snapshots_source(PATTERN, paginate_api_call(client, 'describe_db_snapshots', 'DBSnapshots'), BACKUP_INTERVAL)

for db_instance in filtered_instances:

timestamp_format = now.strftime(TIMESTAMP_FORMAT)

if requires_backup(BACKUP_INTERVAL, db_instance, filtered_snapshots):
Expand Down