Skip to content

Commit e0e902f

Browse files
Refactor encryption, DP, and data handling; add tests
Replaced pickle-based serialization with JSON-safe methods for encrypted payloads and public keys. Updated the secure channel to use AES-GCM with HKDF key derivation and associated data (AAD), and added replay protection via per-client message counters. The differential privacy engine now supports explicit sensitivity setting and separate step accounting. Data loading now supports non-IID (label-skew) splits in addition to IID. Added unit tests for the cryptographic protocols, homomorphic encryption, and differential privacy. Model update payloads were renamed (e.g. `encrypted_gradients` → `model_params`) for consistency across client, server, and benchmark code.
1 parent 2896e32 commit e0e902f

File tree

12 files changed

+314
-93
lines changed

12 files changed

+314
-93
lines changed

.github/workflows/ci.yml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
name: CI
2+
3+
on:
4+
push:
5+
pull_request:
6+
7+
jobs:
8+
test:
9+
runs-on: ubuntu-latest
10+
strategy:
11+
fail-fast: false
12+
matrix:
13+
python-version: ["3.9", "3.10", "3.11"]
14+
steps:
15+
- uses: actions/checkout@v4
16+
- name: Set up Python
17+
uses: actions/setup-python@v5
18+
with:
19+
python-version: ${{ matrix.python-version }}
20+
- name: Install dependencies
21+
run: |
22+
python -m pip install --upgrade pip
23+
pip install -r requirements.txt
24+
- name: Run tests
25+
run: |
26+
pytest -q

requirements.txt

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Core Dependencies
2+
numpy>=1.24.0
3+
scikit-learn>=1.3.0
4+
pandas>=2.0.0
5+
6+
# Post-Quantum Cryptography
7+
dilithium-py>=1.1.0
8+
kyber-py>=0.1.0
9+
cryptography>=41.0.0
10+
11+
# Homomorphic Encryption
12+
phe>=1.5.0
13+
14+
# Visualization
15+
matplotlib>=3.7.0
16+
17+
# Optional
18+
scipy>=1.11.0
19+
20+
# Dev / Tests
21+
pytest>=7.4.0

src/benchmark.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ def measure_payload_size(payload):
5151
"""
5252
Measure the size of a payload in bytes.
5353
"""
54-
import pickle
55-
return len(pickle.dumps(payload))
54+
import json
55+
return len(json.dumps(payload, sort_keys=True, default=str).encode("utf-8"))
5656

5757

5858
def run_experiment(use_he, num_clients=3, num_rounds=5):
@@ -143,19 +143,22 @@ def run_experiment(use_he, num_clients=3, num_rounds=5):
143143
# Prepare Update
144144
model_update = {
145145
"client_id": cid,
146-
"encrypted_gradients": {"W": new_W, "b": new_b},
146+
"model_params": {"W": new_W, "b": new_b},
147147
"num_samples": len(X_local)
148148
}
149149

150150
# Secure Send (measure time and size)
151151
info = registry_info[cid]
152+
info.setdefault("counter", 0)
153+
info["counter"] += 1
152154

153155
encrypt_start = time.time()
154156
payload = client.secure_send_update(
155157
model_update,
156158
info["server_kyber_pk"],
157159
info["session_key"],
158-
use_he=use_he
160+
use_he=use_he,
161+
msg_counter=info["counter"],
159162
)
160163
encrypt_time = time.time() - encrypt_start
161164
round_encryption_time += encrypt_time

src/data_utils.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sklearn.preprocessing import StandardScaler, LabelEncoder
88
from sklearn.datasets import fetch_openml
99

10-
def load_and_preprocess_data(n_clients: int):
10+
def load_and_preprocess_data(n_clients: int, split_strategy: str = "iid", non_iid_label_skew: bool = False):
1111
"""
1212
Loads the Adult Income dataset, preprocesses it, and splits it for federated clients.
1313
@@ -38,24 +38,32 @@ def load_and_preprocess_data(n_clients: int):
3838
numeric_features = ['age', 'capital-gain', 'capital-loss', 'hours-per-week']
3939
X = X[numeric_features]
4040

41-
# 4. Scale features
42-
scaler = StandardScaler()
43-
X = scaler.fit_transform(X)
41+
# Split into train and test FIRST (avoid preprocessing leakage)
42+
X_train_df, X_test_df, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
4443

45-
# Split into train and test
46-
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
44+
# 4. Scale features (fit on train only)
45+
scaler = StandardScaler()
46+
X_train = scaler.fit_transform(X_train_df)
47+
X_test = scaler.transform(X_test_df)
4748

4849
# Split training data among clients
4950
# We'll do a simple IID split (random shuffle is implicit in train_test_split)
5051
client_datasets = []
5152
chunk_size = len(X_train) // n_clients
53+
indices = np.arange(len(X_train))
54+
55+
# Optional simple non-IID label skew for demo purposes:
56+
# sort by label and split contiguous blocks, producing label-imbalanced clients.
57+
if non_iid_label_skew or split_strategy.lower() in ("non-iid", "noniid", "label_skew"):
58+
indices = indices[np.argsort(np.array(y_train))]
5259

5360
for i in range(n_clients):
5461
start = i * chunk_size
5562
end = (i + 1) * chunk_size
56-
X_c = X_train[start:end]
57-
y_c = y_train[start:end]
58-
client_datasets.append((X_c, y_c.values)) # Convert y to numpy array
63+
idx = indices[start:end]
64+
X_c = X_train[idx]
65+
y_c = np.array(y_train)[idx]
66+
client_datasets.append((X_c, y_c)) # y already numpy
5967

6068
print(f"Data loaded. {n_clients} clients, {chunk_size} samples per client.")
6169
return client_datasets, (X_test, y_test)

src/differential_privacy.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ class DifferentialPrivacy:
1616
Based on: Dwork et al., "Differential Privacy: A Survey of Results"
1717
"""
1818

19-
def __init__(self, epsilon: float = 1.0, delta: float = 1e-5,
20-
sensitivity: float = 2.0, noise_type: str = 'laplace'):
19+
def __init__(self, epsilon: float = 1.0, delta: float = 1e-5,
20+
sensitivity: float = 1.0, noise_type: str = 'gaussian'):
2121
"""
2222
Initialize DP mechanism.
2323
@@ -70,7 +70,21 @@ def _calculate_scale(self) -> float:
7070
else:
7171
raise ValueError(f"Unknown noise type: {self.noise_type}")
7272

73-
def add_noise(self, gradient: np.ndarray) -> np.ndarray:
73+
def set_sensitivity(self, sensitivity: float) -> None:
74+
"""
75+
Update sensitivity (typically tied to clipping norm) and recompute scale.
76+
"""
77+
self.sensitivity = float(sensitivity)
78+
self.scale = self._calculate_scale()
79+
80+
def account_step(self, epsilon_spent: float | None = None) -> None:
81+
"""
82+
Account for one privacy mechanism application.
83+
"""
84+
self.privacy_spent += float(self.epsilon if epsilon_spent is None else epsilon_spent)
85+
self.rounds_executed += 1
86+
87+
def add_noise(self, gradient: np.ndarray, *, account: bool = False) -> np.ndarray:
7488
"""
7589
Add differential privacy noise to gradient.
7690
@@ -88,13 +102,9 @@ def add_noise(self, gradient: np.ndarray) -> np.ndarray:
88102
# Gaussian distribution: N(0, scale^2)
89103
noise = np.random.normal(0, self.scale, size=gradient.shape)
90104

91-
# Add noise to gradient
92105
noisy_gradient = gradient + noise
93-
94-
# Update privacy budget (add epsilon spent in this round)
95-
self.privacy_spent += self.epsilon
96-
self.rounds_executed += 1
97-
106+
if account:
107+
self.account_step()
98108
return noisy_gradient
99109

100110
def add_noise_to_dict(self, gradient_dict: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:

src/federated_client.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pickle
21
import numpy as np
32
from typing import Dict, Any, Union
43
from pqc_auth import PQCAuthenticator
@@ -81,15 +80,26 @@ def apply_differential_privacy(self,
8180

8281
# 3. Add Noise using the DP Engine
8382
# The engine handles the noise generation based on epsilon
84-
noisy_delta = self.dp_engine.add_noise(delta)
83+
# Tie DP sensitivity to clipping bound and avoid double-accounting:
84+
# - We add noise to multiple tensors (W and b), but account once per "client update".
85+
self.dp_engine.set_sensitivity(clipping_norm)
86+
noisy_delta = self.dp_engine.add_noise(delta, account=False)
8587

8688
# 4. Return differentially private weights
8789
return global_weights + noisy_delta
8890

89-
def secure_send_update(self, model_update: Dict[str, Any],
90-
server_kyber_pk: str,
91+
def account_privacy_step(self) -> None:
92+
"""
93+
Account for one DP mechanism application per client update.
94+
"""
95+
if self.use_dp:
96+
self.dp_engine.account_step()
97+
98+
def secure_send_update(self, model_update: Dict[str, Any],
99+
server_kyber_pk: str,
91100
session_key: bytes = None,
92-
use_he: bool = True) -> Dict[str, Any]:
101+
use_he: bool = True,
102+
msg_counter: int = 0) -> Dict[str, Any]:
93103
"""
94104
Encrypt and sign the model update.
95105
@@ -116,8 +126,8 @@ def secure_send_update(self, model_update: Dict[str, Any],
116126
temp_he.public_key = self.he_public_key
117127

118128
# Encrypt gradients with Paillier
119-
encrypted_gradients = {}
120-
for param_name, param_value in model_update['encrypted_gradients'].items():
129+
encrypted_params = {}
130+
for param_name, param_value in model_update['model_params'].items():
121131

122132
# --- FIX: Ensure input is always a NumPy array ---
123133
if isinstance(param_value, np.ndarray):
@@ -127,13 +137,13 @@ def secure_send_update(self, model_update: Dict[str, Any],
127137
# -----------------------------------------------
128138

129139
encrypted_list = temp_he.encrypt_vector(vec, self.he_public_key)
130-
# Serialize the encrypted list for transmission
131-
encrypted_gradients[param_name] = pickle.dumps(encrypted_list)
140+
# Serialize encrypted vector to JSON-safe payload
141+
encrypted_params[param_name] = temp_he.serialize_encrypted_vector(encrypted_list)
132142

133143
# Replace plaintext gradients with encrypted ones
134144
model_update_to_send = {
135145
"client_id": model_update["client_id"],
136-
"encrypted_gradients": encrypted_gradients,
146+
"model_params": encrypted_params,
137147
"num_samples": model_update["num_samples"]
138148
}
139149
else:
@@ -143,11 +153,15 @@ def secure_send_update(self, model_update: Dict[str, Any],
143153
signed_package = self.authenticator.sign_update(model_update_to_send, self.private_key)
144154

145155
# Step 3: Encrypt the Signed Package (Confidentiality)
146-
data_bytes = pickle.dumps(signed_package)
147-
encrypted_data = self.secure_channel.encrypt_message(data_bytes, current_session_key)
156+
encrypted_data = self.secure_channel.encrypt_json(
157+
signed_package,
158+
session_key=current_session_key,
159+
aad={"client_id": self.client_id, "counter": int(msg_counter), "type": "model_update"},
160+
)
148161

149162
# Final Payload
150163
payload_structure['client_id'] = self.client_id
164+
payload_structure['counter'] = int(msg_counter)
151165
payload_structure['encrypted_payload'] = encrypted_data
152166
payload_structure['session_key'] = current_session_key
153167

src/federated_server.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pickle
21
from typing import Dict, Any
32
import numpy as np
43

@@ -59,6 +58,7 @@ def register_client(self, client_id: str, client_public_key: str) -> Dict[str, A
5958
self.clients[client_id] = {
6059
"dilithium_pk": client_public_key,
6160
"session_key": None,
61+
"last_counter": -1,
6262
}
6363

6464
print(f"✓ Registered client '{client_id}'")
@@ -82,6 +82,9 @@ def receive_update(self, payload: Dict[str, Any]) -> Dict[str, Any]:
8282
return {"status": "error", "reason": "unregistered_client"}
8383

8484
client_meta = self.clients[client_id]
85+
msg_counter = int(payload.get("counter", -1))
86+
if msg_counter <= client_meta.get("last_counter", -1):
87+
return {"status": "error", "reason": "replay_detected"}
8588

8689
# 1) Establish/retrieve session key
8790
kyber_ciphertext = payload.get("kyber_ciphertext")
@@ -99,14 +102,15 @@ def receive_update(self, payload: Dict[str, Any]) -> Dict[str, Any]:
99102
# 2) Decrypt the signed update with AES-GCM
100103
encrypted_payload = payload.get("encrypted_payload")
101104
try:
102-
signed_bytes = self.secure_channel.decrypt_message(
103-
encrypted_payload, session_key
105+
signed_update = self.secure_channel.decrypt_json(
106+
encrypted_payload,
107+
session_key=session_key,
108+
aad={"client_id": client_id, "counter": msg_counter, "type": "model_update"},
104109
)
105110
except Exception as e:
106111
print(f"✗ Decryption failed for client '{client_id}': {e}")
107112
return {"status": "error", "reason": "decryption_failed"}
108-
109-
signed_update = pickle.loads(signed_bytes)
113+
client_meta["last_counter"] = msg_counter
110114

111115
# 3) Verify Dilithium signature
112116
public_key_hex = client_meta["dilithium_pk"]
@@ -118,17 +122,17 @@ def receive_update(self, payload: Dict[str, Any]) -> Dict[str, Any]:
118122

119123
# 4) Extract model update
120124
model_update = signed_update["model_update"]
121-
client_update = model_update.get("encrypted_gradients")
125+
client_update = model_update.get("model_params")
122126
num_samples = model_update.get("num_samples", 0)
123127

124128
# 5) Process based on HE mode
125129
if self.use_he:
126-
# HE Mode: client_update contains encrypted values (pickled)
127-
# Deserialize the encrypted values
130+
# HE Mode: client_update contains JSON-safe encrypted vectors
128131
deserialized_update = {}
129-
for param_name, encrypted_bytes in client_update.items():
130-
deserialized_update[param_name] = pickle.loads(encrypted_bytes)
131-
132+
for param_name, encrypted_payload_list in client_update.items():
133+
deserialized_update[param_name] = self.he_manager.deserialize_encrypted_vector(
134+
self.he_manager.public_key, encrypted_payload_list
135+
)
132136
self.he_aggregator.add_client_update(deserialized_update, num_samples)
133137
else:
134138
# Plaintext Mode: client_update contains raw arrays

0 commit comments

Comments
 (0)