Skip to content

Commit 4428ee5

Browse files
committed
improve the efficiency of sample-level voting
1 parent d9e8c62 commit 4428ee5

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

pe/histogram/nearest_neighbors.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ def __init__(
5959
private samples is large. It requires the installation of `faiss-gpu` or `faiss-cpu` package. See
6060
https://faiss.ai/
6161
:type backend: str, optional
62-
:param vote_normalization_level: The level of normalization for the votes. It should be one of the following:
62+
:param vote_normalization_level: The level of normalization for the votes. This corresponds to the granularity
63+
of the neighboring definition in differential privacy (DP). It should be one of the following:
6364
"sample" (normalize the votes from each private sample to have l2 norm = 1), "client" (normalize the votes
6465
from all private samples of the same client to have l2 norm = 1). Defaults to "sample"
6566
:type vote_normalization_level: str, optional
@@ -197,23 +198,26 @@ def compute_histogram(self, priv_data, syn_data):
197198
)
198199
self._log_voting_details(priv_data=priv_data, syn_data=syn_data, ids=ids)
199200

200-
priv_data = priv_data.reset_index(drop=True)
201+
count = np.zeros(shape=syn_embedding.shape[0], dtype=np.float32)
202+
201203
if self._vote_normalization_level == "client":
204+
priv_data = priv_data.reset_index(drop=True)
202205
priv_data_list = priv_data.split_by_client()
206+
for sub_priv_data in priv_data_list:
207+
sub_count = np.zeros(shape=syn_embedding.shape[0], dtype=np.float32)
208+
sub_ids = ids[sub_priv_data.data_frame.index]
209+
counter = Counter(list(sub_ids.flatten()))
210+
sub_count[list(counter.keys())] = list(counter.values())
211+
sub_count /= np.linalg.norm(sub_count)
212+
count += sub_count
203213
elif self._vote_normalization_level == "sample":
204-
priv_data_list = priv_data.split_by_index()
214+
counter = Counter(list(ids.flatten()))
215+
count = np.zeros(shape=syn_embedding.shape[0], dtype=np.float32)
216+
count[list(counter.keys())] = list(counter.values())
217+
count /= np.sqrt(self._num_nearest_neighbors)
205218
else:
206219
raise ValueError(f"Unknown vote normalization level: {self._vote_normalization_level}")
207220

208-
count = np.zeros(shape=syn_embedding.shape[0], dtype=np.float32)
209-
for sub_priv_data in priv_data_list:
210-
sub_count = np.zeros(shape=syn_embedding.shape[0], dtype=np.float32)
211-
sub_ids = ids[sub_priv_data.data_frame.index]
212-
counter = Counter(list(sub_ids.flatten()))
213-
sub_count[list(counter.keys())] = list(counter.values())
214-
sub_count /= np.linalg.norm(sub_count)
215-
count += sub_count
216-
217221
syn_data.data_frame[CLEAN_HISTOGRAM_COLUMN_NAME] = count
218222

219223
execution_logger.info(

0 commit comments

Comments
 (0)