
Commit 4bc23b3

committed: addressed Marcelo's comments and added improved readability

1 parent: 16b489b

File tree

1 file changed: +84 −41 lines changed

src/midst_toolkit/evaluation/privacy/batched_eir.py

Lines changed: 84 additions & 41 deletions
```diff
@@ -1,4 +1,5 @@
 from collections.abc import Iterable
+from typing import Any, Literal
 
 import numpy as np
 import pandas as pd
```
```diff
@@ -10,50 +11,51 @@
 
 def _column_entropy(labels: list | np.ndarray) -> np.number:
     """Compute the entropy of a single column."""
-    value, counts = np.unique(np.round(labels), return_counts=True)
+    _, counts = np.unique(np.round(labels), return_counts=True)
     return entropy(counts)
 
 
 def batched_reference_knn(
     query_df: pd.DataFrame,
     reference_df: pd.DataFrame,
-    cat_cols: list[int],
-    nn_dist: str,
+    categorical_columns: list[int],
+    nn_distance_metric: Literal["gower", "euclid"],
     weights: np.ndarray,
     ref_batch_size: int = 128,
     show_progress: bool = True,
 ) -> np.ndarray:
     """
-    Compute k-nearest neighbor distances from query rows to reference rows in a memory-efficient way.
+    Compute nearest neighbor distances from the points in query_df to reference_df in a memory-efficient way.
 
     Instead of comparing all query rows to all reference rows at once, the reference DataFrame
     is split into batches. For each batch:
-    1. Compute the distances from all query rows to the current batch.
+    1. Compute the distances from all query rows to the current reference_df batch.
     2. Keep track of the smallest distance per query row across all batches.
 
     Args:
-        query_df : The data points for which kNN distances are computed.
+        query_df : The data points for which nearest neighbor distances are computed.
         reference_df : The data points used as the reference for computing distances.
-        cat_cols : Indices of categorical columns.
-        nn_dist : Distance metric to use for nearest neighbor computation.
+        categorical_columns : Indices of categorical columns.
+        nn_distance_metric : Distance metric to use for nearest neighbor distance computation. Possible values are the
+            Gower distance metric ('gower') and the Euclidean distance metric ('euclid').
         weights : Feature weights to apply when computing distances.
         ref_batch_size : Number of reference rows per batch.
         show_progress : Whether to display a progress bar over reference batches.
 
-    Returns :
-        Array of minimum distances per query row after considering all reference batches.
+    Returns:
+        Array of nearest neighbor distances per query row after considering all reference batches.
     """
     n_query = len(query_df)
 
-    # best distances so far = +inf
-    best_d = np.full(n_query, np.inf, dtype=float)
+    # Initializing the array of best distances with np.inf so they can be replaced with the actual best distances later.
+    nearest_neighbor_distance = np.full(n_query, np.inf, dtype=float)
 
     iterator: Iterable[int]
     if show_progress:
         iterator = tqdm(
             range(0, len(reference_df), ref_batch_size),
             total=(len(reference_df) + ref_batch_size - 1) // ref_batch_size,
-            desc="Computing ref-batched kNN distances",
+            desc="Computing nearest neighbor distances from real/holdout dataset to synthetic dataset.",
         )
     else:
         iterator = range(0, len(reference_df), ref_batch_size)
```
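The docstring above describes the batching scheme. For readers outside the toolkit, here is a minimal standalone sketch of the same pattern, assuming plain unweighted Euclidean distances via scipy's `cdist`; the toolkit's own `_knn_distance` additionally handles Gower distance, categorical columns, and feature weights.

```python
# Minimal sketch of the reference-batching pattern, assuming unweighted
# Euclidean distances only. The toolkit's _knn_distance also supports Gower
# distance, categorical columns, and feature weights.
import numpy as np
from scipy.spatial.distance import cdist


def batched_min_distances(query: np.ndarray, reference: np.ndarray, batch_size: int = 128) -> np.ndarray:
    # Start at +inf so the first batch always overwrites the initial values.
    best = np.full(len(query), np.inf)
    for start in range(0, len(reference), batch_size):
        batch = reference[start : start + batch_size]
        # Distance matrix of shape (n_query, len(batch)): only one
        # batch-sized matrix is ever held in memory at a time.
        distances = cdist(query, batch)
        best = np.minimum(best, distances.min(axis=1))
    return best
```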
```diff
@@ -62,35 +64,62 @@ def batched_reference_knn(
         end = min(start + ref_batch_size, len(reference_df))
         ref_batch = reference_df.iloc[start:end]
 
-        # compute distances to this reference batch (k=1 → index 0)
-        d_batch = _knn_distance(query_df, ref_batch, cat_cols, 1, nn_dist, weights)[0]
+        # compute the distance from each query row to its closest neighbour in ref_batch;
+        # hardcoding k=1 because only the distance to the closest neighbor is needed.
+        batch_distances = _knn_distance(query_df, ref_batch, categorical_columns, 1, nn_distance_metric, weights)[0]
 
         # keep smallest per query row
-        best_d = np.minimum(best_d, d_batch)
+        nearest_neighbor_distance = np.minimum(nearest_neighbor_distance, batch_distances)
 
-    return best_d
+    return nearest_neighbor_distance
 
 
 class EpsilonIdentifiability(MetricClass):  # type: ignore[misc]
     def name(self) -> str:
-        """Return the name of the metric."""
+        """
+        Returns the identifier of the metric.
+
+        Returns:
+            "eps_risk"
+        """
         return "eps_risk"
 
     def type(self) -> str:
-        """Return the type of the metric."""
+        """
+        Returns the type of the evaluation metric.
+
+        Returns:
+            "privacy"
+        """
         return "privacy"
 
     def evaluate(self) -> dict:
-        """Compute the Epsilon Identifiability Risk and Privacy Loss."""
-        real = np.asarray(self.real_data)
-        no, x_dim = real.shape
+        """
+        Compute epsilon-identifiability risk and privacy loss.
+
+        The epsilon-identifiability risk (eps_risk) is defined as the fraction of real
+        records whose nearest neighbor in the synthetic dataset is closer than their
+        nearest neighbor in the real dataset, using an entropy-weighted distance metric.
+
+        If holdout data is provided, the privacy loss (priv_loss) is computed as the
+        difference between the identifiability risk on the training data and the
+        identifiability risk on the holdout data.
+
+        Returns:
+            dict:
+                - 'eps_risk': Fraction of real records vulnerable to re-identification.
+                - 'priv_loss': Difference between training and holdout identifiability risks
+                  (only present if holdout data is not None).
+        """
+        np_real_data = np.asarray(self.real_data)
+        real_size, n_features = np_real_data.shape
 
         # Column entropies → weights (inverted)
-        weights = [_column_entropy(real[:, i]) for i in range(x_dim)]
+        weights = [_column_entropy(np_real_data[:, feature]) for feature in range(n_features)]
         weights_adjusted = 1 / (np.array(weights) + 1e-16)
 
         # INTERNAL KNN: REAL → REAL
-        in_dists = _knn_distance(
+        internal_distances = _knn_distance(
             self.real_data,
             self.real_data,
             self.cat_cols,
```
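The entropy-based feature weighting at the end of this hunk can be illustrated in isolation. Below is a small sketch of how `_column_entropy` feeds the inverse weights; the toy array is made up purely for illustration.

```python
# Toy illustration of the inverse-entropy feature weighting in evaluate().
import numpy as np
from scipy.stats import entropy

X = np.array([
    [1.0, 0.0],
    [2.0, 0.0],
    [3.0, 1.0],
])

# Per-column entropy from value counts, mirroring _column_entropy.
weights = []
for j in range(X.shape[1]):
    _, counts = np.unique(np.round(X[:, j]), return_counts=True)
    weights.append(entropy(counts))

# Invert: low-entropy (more predictable) columns get larger weight.
# The 1e-16 term guards against division by zero for constant columns.
weights_adjusted = 1 / (np.array(weights) + 1e-16)
print(weights_adjusted)  # approx. [0.91, 1.57]
```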
```diff
@@ -100,37 +129,37 @@ def evaluate(self) -> dict:
         )[0]
 
         # EXTERNAL KNN: REAL → SYNTHETIC (safe to batch reference)
-        ext_dists = batched_reference_knn(
+        external_distances = batched_reference_knn(
             self.real_data,
             self.synt_data,
             self.cat_cols,
             self.nn_dist,
             weights_adjusted,
         )
 
-        r_diff = ext_dists - in_dists
-        identifiability = np.sum(r_diff < 0) / float(no)
-        self.results["eps_risk"] = identifiability
+        real_data_distance_differences = external_distances - internal_distances
+        identifiability_risk = np.sum(real_data_distance_differences < 0) / float(real_size)
+        self.results["eps_risk"] = identifiability_risk
 
         if self.hout_data is not None:
             # INTERNAL: HOUT → HOUT (original logic)
-            hout_in = _knn_distance(self.hout_data, self.hout_data, self.cat_cols, 1, self.nn_dist, weights_adjusted)[
-                0
-            ]
+            hout_internal_distances = _knn_distance(
+                self.hout_data, self.hout_data, self.cat_cols, 1, self.nn_dist, weights_adjusted
+            )[0]
 
             # EXTERNAL: HOUT → SYNTHETIC (batched)
-            hout_ext = batched_reference_knn(
+            hout_external_distances = batched_reference_knn(
                 self.hout_data,
                 self.synt_data,
                 self.cat_cols,
                 self.nn_dist,
                 weights_adjusted,
             )
 
-            hout_diff = hout_ext - hout_in
-            hout_val = np.sum(hout_diff < 0) / float(len(self.hout_data))
+            holdout_data_distance_differences = hout_external_distances - hout_internal_distances
+            hout_identifiability_risk = np.sum(holdout_data_distance_differences < 0) / float(len(self.hout_data))
 
-            self.results["priv_loss"] = self.results["eps_risk"] - hout_val
+            self.results["priv_loss"] = self.results["eps_risk"] - hout_identifiability_risk
 
         return self.results
```
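To make the eps_risk arithmetic concrete, here is a toy example with hypothetical distance arrays; the numbers are invented for illustration only.

```python
# Toy numbers, purely illustrative: distance from each real record to its
# nearest synthetic neighbor vs. its nearest (other) real neighbor.
import numpy as np

external = np.array([0.2, 0.9, 0.1, 0.7])  # real -> synthetic
internal = np.array([0.5, 0.4, 0.3, 0.8])  # real -> real

# A record counts as identifiable when a synthetic point sits closer
# than any real one.
eps_risk = np.sum(external - internal < 0) / float(len(external))
print(eps_risk)  # 0.75: records 0, 2, and 3 have a closer synthetic neighbor

# With a holdout risk of, say, 0.5, the privacy loss would be 0.75 - 0.5 = 0.25.
```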

```diff
@@ -141,27 +170,41 @@ def format_output(self) -> str:
         string += f"\n| Privacy loss (diff. in eps. risk) : {self.results['priv_loss']:.4f} |"
         return string
 
-    def normalize_output(self) -> list | None:
-        """Standardize the output format."""
+    def normalize_output(self) -> list[dict[str, Any]] | None:
+        """
+        Convert computed privacy metrics into a standardized list of dictionaries.
+
+        Each dictionary contains:
+            - 'metric': The metric identifier
+            - 'val': The raw metric value
+
+        The metrics included are:
+            - 'eps_identif_risk': The epsilon-identifiability risk of the real data
+            - 'priv_loss_eps': The difference in epsilon risk between training and holdout
+              data (only included if holdout data is provided)
+
+        If the evaluation has not been run yet (i.e., results are empty),
+        the method returns None.
+
+        Returns:
+            A list of metric dictionaries if results are available;
+            otherwise, None.
+        """
         if self.results == {}:
             return None
 
         output = [
             {
                 "metric": "eps_identif_risk",
-                "dim": "p",
                 "val": self.results["eps_risk"],
-                "n_val": 1 - self.results["eps_risk"],
             }
         ]
 
         if self.hout_data is not None:
             output.append(
                 {
                     "metric": "priv_loss_eps",
-                    "dim": "p",
                     "val": self.results["priv_loss"],
-                    "n_val": 1 - abs(self.results["priv_loss"]),
                 }
             )
```
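For reference, the normalized output for a run with holdout data would be shaped like this; the values are hypothetical, not from an actual evaluation.

```python
# Hypothetical normalize_output() result after evaluate() has run with
# holdout data; the numbers are made up for illustration.
expected = [
    {"metric": "eps_identif_risk", "val": 0.31},
    {"metric": "priv_loss_eps", "val": 0.04},
]
```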
