Skip to content

Commit 5d4910b

Browse files
Sample script to compare (#428)
* Sample script to compare
1 parent 55cbb98 commit 5d4910b

File tree

3 files changed

+121
-0
lines changed

3 files changed

+121
-0
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
2+
# Salt Compare Tool
3+
4+
A lightweight Python script to compare encrypted and unencrypted salt files stored in an S3 bucket.
5+
6+
## Description
7+
8+
This script fetches two S3 objects (e.g. encrypted and unencrypted salts) and compares them. It validates basic input rules like prefix formatting and key pattern before proceeding.
9+
10+
## Usage
11+
12+
### Run the script:
13+
14+
Login to AWS account
15+
`pip install requirements.txt`
16+
`python script.py <key> <bucket> <region_name> [prefix]`
17+
18+
- `encrypted_file` – Required. Must start with `salt`. Example: `salts/encrypted/12_private/salts.txt.1745532777048` (To query multiple files you can use `salts/encrypted/12_private/*`)
19+
- `bucket` – Required. Name of the S3 bucket.
20+
- `region_name` – Required. AWS region of the S3 bucket (e.g. `us-east-1`)
21+
- `prefix` – Optional. S3 path prefix. If provided, it **must end with `/`**.
22+
23+
## For Other Decryption Comparisons
24+
25+
You can use the **same logic** for other types of decryption and comparison. The only change is in how the **unencrypted file name** is generated in salt_compare.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
cryptography
2+
boto3
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import json
2+
import base64
3+
from typing import IO
4+
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
5+
from cryptography.hazmat.backends import default_backend
6+
import boto3
7+
import sys
8+
9+
class AesGcm:
10+
@staticmethod
11+
def decrypt(encrypted_data: bytes, nonce: bytes, key: bytes):
12+
if len(nonce) != 12:
13+
raise ValueError("Nonce must be 12 bytes for AES-GCM")
14+
cipher = Cipher(algorithms.AES(key), modes.GCM(nonce), backend=default_backend())
15+
decryptor = cipher.decryptor()
16+
try:
17+
return decryptor.update(encrypted_data) + decryptor.finalize()
18+
except Exception:
19+
raise ValueError("Invalid GCM tag during decryption")
20+
21+
def _get_encryption_secret(key_id, bucket, prefix, region_name):
22+
print("Fetching secret key for ", key_id)
23+
s3 = boto3.client('s3', region_name=region_name)
24+
response = s3.get_object(Bucket=bucket, Key=f"{prefix}cloud_encryption_keys/cloud_encryption_keys.json")
25+
data = json.load(response['Body'])
26+
_map = {item['id']: item for item in data}
27+
return _map.get(key_id).get('secret')
28+
29+
def _decrypt_input_stream(input_stream: IO[bytes], bucket, prefix, region_name) -> str:
30+
try:
31+
data = json.load(input_stream)
32+
except json.JSONDecodeError as e:
33+
raise ValueError(f"Failed to parse JSON: {e}")
34+
key_id = data.get("key_id")
35+
encrypted_payload_b64 = data.get("encrypted_payload")
36+
if key_id is None or encrypted_payload_b64 is None:
37+
raise ValueError("Failed to parse JSON")
38+
39+
decryption_key = _get_encryption_secret(key_id, bucket, prefix, region_name)
40+
try:
41+
secret_bytes = base64.b64decode(decryption_key)
42+
encrypted_bytes = base64.b64decode(encrypted_payload_b64)
43+
nonce = encrypted_bytes[:12]
44+
ciphertext = encrypted_bytes[12:-16]
45+
auth_tag = encrypted_bytes[-16:]
46+
cipher = Cipher(algorithms.AES(secret_bytes), modes.GCM(nonce, auth_tag), backend=default_backend())
47+
decryptor = cipher.decryptor()
48+
decrypted_bytes = decryptor.update(ciphertext) + decryptor.finalize()
49+
return decrypted_bytes.decode("utf-8")
50+
except Exception as e:
51+
raise ValueError(f"An error occurred during decryption: {e}")
52+
53+
def salt_compare(key, prefix, bucket, region_name):
54+
s3 = boto3.client('s3', region_name=region_name)
55+
key = f"{prefix}{key}"
56+
print("Key is ", key)
57+
base_path = '/'.join(key.split('/')[:-3])
58+
file_name = key.split('/')[-1:][0]
59+
unencrypted = f'{base_path}/{file_name}'
60+
print(f"Comparing {key} with {unencrypted}")
61+
response = s3.get_object(Bucket=bucket, Key=key)
62+
encrypted = _decrypt_input_stream(response['Body'], bucket=bucket, prefix=prefix, region_name=region_name)
63+
response = s3.get_object(Bucket=bucket, Key=unencrypted)
64+
unencrypted = response['Body'].read().decode('utf-8')
65+
return (encrypted==unencrypted)
66+
67+
def _get_most_recent_files(bucket, prefix, key):
68+
s3 = boto3.client("s3")
69+
paginator = s3.get_paginator("list_objects_v2")
70+
page_iterator = paginator.paginate(Bucket=bucket, Prefix=f"{prefix}{key[:-2]}/")
71+
n, all_files = 5 , []
72+
for i, page in enumerate(page_iterator):
73+
if i >= n:
74+
break
75+
all_files.extend(page.get("Contents", []))
76+
recent_files = sorted(all_files, key=lambda x: x["LastModified"], reverse=True)
77+
recent_files = list(map(lambda x: x['Key'], recent_files))
78+
recent_files = list(filter(lambda x: "metadata" not in x, recent_files))
79+
return recent_files[:10]
80+
81+
if __name__ == '__main__':
82+
encrypted_file = sys.argv[1]
83+
bucket = sys.argv[2]
84+
region_name = sys.argv[3]
85+
prefix = sys.argv[4] if len(sys.argv) > 4 else ''
86+
if prefix != '' and prefix[-1]!='/':
87+
raise "prefix should terminate with /"
88+
if not encrypted_file.startswith("salt"):
89+
raise "only salts supported"
90+
if encrypted_file[-2:] == '/*':
91+
for recent in _get_most_recent_files(bucket=bucket, prefix=prefix, key=encrypted_file):
92+
print(salt_compare(key=recent, prefix=prefix, bucket=bucket, region_name=region_name))
93+
else:
94+
print(salt_compare(key=encrypted_file, prefix=prefix, bucket=bucket, region_name=region_name))

0 commit comments

Comments
 (0)