@@ -59,7 +59,8 @@ def __init__(
5959 private samples is large. It requires the installation of `faiss-gpu` or `faiss-cpu` package. See
6060 https://faiss.ai/
6161 :type backend: str, optional
62- :param vote_normalization_level: The level of normalization for the votes. It should be one of the following:
62+ :param vote_normalization_level: The level of normalization for the votes. This corresponds to the granularity
63+ of the neighboring definition in differential privacy (DP). It should be one of the following:
6364 "sample" (normalize the votes from each private sample to have l2 norm = 1), "client" (normalize the votes
6465 from all private samples of the same client to have l2 norm = 1). Defaults to "sample"
6566 :type vote_normalization_level: str, optional
@@ -197,23 +198,25 @@ def compute_histogram(self, priv_data, syn_data):
197198 )
198199 self ._log_voting_details (priv_data = priv_data , syn_data = syn_data , ids = ids )
199200
200- priv_data = priv_data .reset_index (drop = True )
201+ count = np .zeros (shape = syn_embedding .shape [0 ], dtype = np .float32 )
202+
201203 if self ._vote_normalization_level == "client" :
204+ priv_data = priv_data .reset_index (drop = True )
202205 priv_data_list = priv_data .split_by_client ()
206+ for sub_priv_data in priv_data_list :
207+ sub_count = np .zeros (shape = syn_embedding .shape [0 ], dtype = np .float32 )
208+ sub_ids = ids [sub_priv_data .data_frame .index ]
209+ counter = Counter (list (sub_ids .flatten ()))
210+ sub_count [list (counter .keys ())] = list (counter .values ())
211+ sub_count /= np .linalg .norm (sub_count )
212+ count += sub_count
203213 elif self ._vote_normalization_level == "sample" :
204- priv_data_list = priv_data .split_by_index ()
214+ counter = Counter (list (ids .flatten ()))
215+ count [list (counter .keys ())] = list (counter .values ())
216+ count /= np .sqrt (self ._num_nearest_neighbors )
205217 else :
206218 raise ValueError (f"Unknown vote normalization level: { self ._vote_normalization_level } " )
207219
208- count = np .zeros (shape = syn_embedding .shape [0 ], dtype = np .float32 )
209- for sub_priv_data in priv_data_list :
210- sub_count = np .zeros (shape = syn_embedding .shape [0 ], dtype = np .float32 )
211- sub_ids = ids [sub_priv_data .data_frame .index ]
212- counter = Counter (list (sub_ids .flatten ()))
213- sub_count [list (counter .keys ())] = list (counter .values ())
214- sub_count /= np .linalg .norm (sub_count )
215- count += sub_count
216-
217220 syn_data .data_frame [CLEAN_HISTOGRAM_COLUMN_NAME ] = count
218221
219222 execution_logger .info (
0 commit comments