1- import pickle
21import numpy as np
32from typing import Dict , Any , Union
43from pqc_auth import PQCAuthenticator
@@ -81,15 +80,26 @@ def apply_differential_privacy(self,
8180
8281 # 3. Add Noise using the DP Engine
8382 # The engine handles the noise generation based on epsilon
84- noisy_delta = self .dp_engine .add_noise (delta )
83+ # Tie DP sensitivity to clipping bound and avoid double-accounting:
84+ # - We add noise to multiple tensors (W and b), but account once per "client update".
85+ self .dp_engine .set_sensitivity (clipping_norm )
86+ noisy_delta = self .dp_engine .add_noise (delta , account = False )
8587
8688 # 4. Return differentially private weights
8789 return global_weights + noisy_delta
8890
89- def secure_send_update (self , model_update : Dict [str , Any ],
90- server_kyber_pk : str ,
91+ def account_privacy_step (self ) -> None :
92+ """
93+ Account for one DP mechanism application per client update.
94+ """
95+ if self .use_dp :
96+ self .dp_engine .account_step ()
97+
98+ def secure_send_update (self , model_update : Dict [str , Any ],
99+ server_kyber_pk : str ,
91100 session_key : bytes = None ,
92- use_he : bool = True ) -> Dict [str , Any ]:
101+ use_he : bool = True ,
102+ msg_counter : int = 0 ) -> Dict [str , Any ]:
93103 """
94104 Encrypt and sign the model update.
95105
@@ -116,8 +126,8 @@ def secure_send_update(self, model_update: Dict[str, Any],
116126 temp_he .public_key = self .he_public_key
117127
118128 # Encrypt gradients with Paillier
119- encrypted_gradients = {}
120- for param_name , param_value in model_update ['encrypted_gradients ' ].items ():
129+ encrypted_params = {}
130+ for param_name , param_value in model_update ['model_params ' ].items ():
121131
122132 # --- FIX: Ensure input is always a NumPy array ---
123133 if isinstance (param_value , np .ndarray ):
@@ -127,13 +137,13 @@ def secure_send_update(self, model_update: Dict[str, Any],
127137 # -----------------------------------------------
128138
129139 encrypted_list = temp_he .encrypt_vector (vec , self .he_public_key )
130- # Serialize the encrypted list for transmission
131- encrypted_gradients [param_name ] = pickle . dumps (encrypted_list )
140+ # Serialize encrypted vector to JSON-safe payload
141+ encrypted_params [param_name ] = temp_he . serialize_encrypted_vector (encrypted_list )
132142
133143 # Replace plaintext gradients with encrypted ones
134144 model_update_to_send = {
135145 "client_id" : model_update ["client_id" ],
136- "encrypted_gradients " : encrypted_gradients ,
146+ "model_params " : encrypted_params ,
137147 "num_samples" : model_update ["num_samples" ]
138148 }
139149 else :
@@ -143,11 +153,15 @@ def secure_send_update(self, model_update: Dict[str, Any],
143153 signed_package = self .authenticator .sign_update (model_update_to_send , self .private_key )
144154
145155 # Step 3: Encrypt the Signed Package (Confidentiality)
146- data_bytes = pickle .dumps (signed_package )
147- encrypted_data = self .secure_channel .encrypt_message (data_bytes , current_session_key )
156+ encrypted_data = self .secure_channel .encrypt_json (
157+ signed_package ,
158+ session_key = current_session_key ,
159+ aad = {"client_id" : self .client_id , "counter" : int (msg_counter ), "type" : "model_update" },
160+ )
148161
149162 # Final Payload
150163 payload_structure ['client_id' ] = self .client_id
164+ payload_structure ['counter' ] = int (msg_counter )
151165 payload_structure ['encrypted_payload' ] = encrypted_data
152166 payload_structure ['session_key' ] = current_session_key
153167
0 commit comments