 from collections.abc import Iterable
+from typing import Any, Literal

 import numpy as np
 import pandas as pd

 def _column_entropy(labels: list | np.ndarray) -> np.number:
     """Compute the entropy of a single column."""
-    value, counts = np.unique(np.round(labels), return_counts=True)
+    _, counts = np.unique(np.round(labels), return_counts=True)
     return entropy(counts)
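
# Editor's note: an illustrative sketch, not part of this diff. Assuming `entropy`
# here is scipy.stats.entropy, near-constant columns score low, and the inverted
# weighting applied in evaluate() below then makes such identifying columns
# dominate the distance computation:
#
#   _column_entropy(np.array([0, 0, 0, 1]))  # ~0.56 (low entropy)
#   _column_entropy(np.array([0, 1, 2, 3]))  # ~1.39 (high entropy)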


 def batched_reference_knn(
     query_df: pd.DataFrame,
     reference_df: pd.DataFrame,
-    cat_cols: list[int],
-    nn_dist: str,
+    categorical_columns: list[int],
+    nn_distance_metric: Literal["gower", "euclid"],
     weights: np.ndarray,
     ref_batch_size: int = 128,
     show_progress: bool = True,
 ) -> np.ndarray:
     """
-    Compute k-nearest neighbor distances from query rows to reference rows in a memory-efficient way.
+    Compute nearest neighbor distances from the points in query_df to reference_df in a memory-efficient way.

     Instead of comparing all query rows to all reference rows at once, the reference DataFrame
     is split into batches. For each batch:
-    1. Compute the distances from all query rows to the current batch.
+    1. Compute the distances from all query rows to the current reference_df batch.
     2. Keep track of the smallest distance per query row across all batches.

     Args:
-        query_df: The data points for which kNN distances are computed.
+        query_df: The data points for which nearest neighbor distances are computed.
         reference_df: The data points used as the reference for computing distances.
-        cat_cols: Indices of categorical columns.
-        nn_dist: Distance metric to use for nearest neighbor computation.
+        categorical_columns: Indices of categorical columns.
+        nn_distance_metric: Distance metric to use for nearest neighbor distance computation. Possible values
+            are the Gower distance metric ('gower') and the Euclidean distance metric ('euclid').
         weights: Feature weights to apply when computing distances.
         ref_batch_size: Number of reference rows per batch.
         show_progress: Whether to display a progress bar over reference batches.

     Returns:
-        Array of minimum distances per query row after considering all reference batches.
+        Array of nearest neighbor distances, one per query row, after considering all reference batches.
     """
     n_query = len(query_df)

-    # best distances so far = +inf
-    best_d = np.full(n_query, np.inf, dtype=float)
+    # Initialize the array of best distances to np.inf so it can be replaced with the actual best distances later.
+    nearest_neighbor_distance = np.full(n_query, np.inf, dtype=float)

     iterator: Iterable[int]
     if show_progress:
         iterator = tqdm(
             range(0, len(reference_df), ref_batch_size),
             total=(len(reference_df) + ref_batch_size - 1) // ref_batch_size,
-            desc="Computing ref-batched kNN distances",
+            desc="Computing nearest neighbor distances from real/holdout dataset to synthetic dataset.",
         )
     else:
         iterator = range(0, len(reference_df), ref_batch_size)
@@ -62,35 +64,62 @@ def batched_reference_knn(
         end = min(start + ref_batch_size, len(reference_df))
         ref_batch = reference_df.iloc[start:end]

-        # compute distances to this reference batch (k=1 → index 0)
-        d_batch = _knn_distance(query_df, ref_batch, cat_cols, 1, nn_dist, weights)[0]
+        # Compute the distance from each query row to its closest neighbor in ref_batch.
+        # The hardcoded k=1 means only the nearest neighbor distance is needed, hence the [0] index.
+        batch_distances = _knn_distance(query_df, ref_batch, categorical_columns, 1, nn_distance_metric, weights)[0]

         # keep smallest per query row
-        best_d = np.minimum(best_d, d_batch)
+        nearest_neighbor_distance = np.minimum(nearest_neighbor_distance, batch_distances)

-    return best_d
+    return nearest_neighbor_distance


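# Editor's note: an illustrative usage sketch, not part of this diff; the frames,
# column indices, and weights below are made-up placeholders:
#
#   query = pd.DataFrame({"age": [34.0, 51.0], "smoker": [1, 0]})
#   reference = pd.DataFrame({"age": [33.0, 49.0, 62.0], "smoker": [1, 0, 0]})
#   nn_dists = batched_reference_knn(
#       query, reference, categorical_columns=[1], nn_distance_metric="gower",
#       weights=np.ones(2), ref_batch_size=2, show_progress=False,
#   )
#   # nn_dists[i] is the distance from query row i to its nearest reference row.
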
 class EpsilonIdentifiability(MetricClass):  # type: ignore[misc]
     def name(self) -> str:
-        """Return the name of the metric."""
+        """
+        Return the identifier of the metric.
+
+        Returns:
+            "eps_risk"
+        """
         return "eps_risk"

     def type(self) -> str:
-        """Return the type of the metric."""
+        """
+        Return the type of the evaluation metric.
+
+        Returns:
+            "privacy"
+        """
         return "privacy"

     def evaluate(self) -> dict:
-        """Compute the Epsilon Identifiability Risk and Privacy Loss."""
-        real = np.asarray(self.real_data)
-        no, x_dim = real.shape
+        """
+        Compute the epsilon-identifiability risk and privacy loss.
+
+        The epsilon-identifiability risk (eps_risk) is defined as the fraction of real
+        records whose nearest neighbor in the synthetic dataset is closer than their
+        nearest neighbor in the real dataset, using an entropy-weighted distance metric.
+
+        If holdout data is provided, the privacy loss (priv_loss) is computed as the
+        difference between the identifiability risk on the training data and the
+        identifiability risk on the holdout data.
+
+        Returns:
+            dict:
+                - 'eps_risk': Fraction of real records vulnerable to re-identification.
+                - 'priv_loss': Difference between training and holdout identifiability
+                  risks (only present if holdout data is not None).
+        """
+        np_real_data = np.asarray(self.real_data)
+        real_size, n_features = np_real_data.shape

         # Column entropies → weights (inverted)
-        weights = [_column_entropy(real[:, i]) for i in range(x_dim)]
+        weights = [_column_entropy(np_real_data[:, feature]) for feature in range(n_features)]
         weights_adjusted = 1 / (np.array(weights) + 1e-16)
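+        # Note: inverting the entropies gives low-entropy (more identifying) columns
+        # larger weights; the 1e-16 term guards against division by zero for
+        # constant columns.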

         # INTERNAL KNN: REAL → REAL
-        in_dists = _knn_distance(
+        internal_distances = _knn_distance(
             self.real_data,
             self.real_data,
             self.cat_cols,
@@ -100,37 +129,37 @@ def evaluate(self) -> dict:
         )[0]

         # EXTERNAL KNN: REAL → SYNTHETIC (safe to batch reference)
-        ext_dists = batched_reference_knn(
+        external_distances = batched_reference_knn(
             self.real_data,
             self.synt_data,
             self.cat_cols,
             self.nn_dist,
             weights_adjusted,
         )

-        r_diff = ext_dists - in_dists
-        identifiability = np.sum(r_diff < 0) / float(no)
-        self.results["eps_risk"] = identifiability
+        real_data_distance_differences = external_distances - internal_distances
+        identifiability_risk = np.sum(real_data_distance_differences < 0) / float(real_size)
+        self.results["eps_risk"] = identifiability_risk
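+        # A real row counts toward eps_risk when its nearest synthetic neighbor is
+        # closer than its nearest other real row (negative distance difference).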

         if self.hout_data is not None:
             # INTERNAL: HOUT → HOUT (original logic)
-            hout_in = _knn_distance(self.hout_data, self.hout_data, self.cat_cols, 1, self.nn_dist, weights_adjusted)[
-                0
-            ]
+            hout_internal_distances = _knn_distance(
+                self.hout_data, self.hout_data, self.cat_cols, 1, self.nn_dist, weights_adjusted
+            )[0]

             # EXTERNAL: HOUT → SYNTHETIC (batched)
-            hout_ext = batched_reference_knn(
+            hout_external_distances = batched_reference_knn(
                 self.hout_data,
                 self.synt_data,
                 self.cat_cols,
                 self.nn_dist,
                 weights_adjusted,
             )

-            hout_diff = hout_ext - hout_in
-            hout_val = np.sum(hout_diff < 0) / float(len(self.hout_data))
+            holdout_data_distance_differences = hout_external_distances - hout_internal_distances
+            hout_identifiability_risk = np.sum(holdout_data_distance_differences < 0) / float(len(self.hout_data))

-            self.results["priv_loss"] = self.results["eps_risk"] - hout_val
+            self.results["priv_loss"] = self.results["eps_risk"] - hout_identifiability_risk

         return self.results

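# Editor's note: a worked toy example with made-up values, not part of this diff.
# With internal distances [0.5, 0.4, 0.9] and external distances [0.2, 0.6, 0.8],
# two of the three differences are negative, so eps_risk = 2/3. If the same
# computation on the holdout data yields 1/3, then priv_loss = 2/3 - 1/3 = 1/3;
# a positive value suggests the synthetic data sits unusually close to the
# training records, i.e. potential memorization.
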
@@ -141,27 +170,41 @@ def format_output(self) -> str:
         string += f"\n| Privacy loss (diff. in eps. risk) : {self.results['priv_loss']:.4f}|"
         return string

-    def normalize_output(self) -> list | None:
-        """Standardize the output format."""
+    def normalize_output(self) -> list[dict[str, Any]] | None:
+        """
+        Convert the computed privacy metrics into a standardized list of dictionaries.
+
+        Each dictionary contains:
+        - 'metric': the metric identifier
+        - 'val': the raw metric value
+
+        The metrics included are:
+        - 'eps_identif_risk': the epsilon-identifiability risk of the real data
+        - 'priv_loss_eps': the difference in epsilon risk between training and holdout
+          data (only included if holdout data is provided)
+
+        If the evaluation has not been run yet (i.e., the results are empty), the
+        method returns None.
+
+        Returns:
+            A list of metric dictionaries if results are available; otherwise, None.
+        """
         if self.results == {}:
             return None

         output = [
             {
                 "metric": "eps_identif_risk",
-                "dim": "p",
                 "val": self.results["eps_risk"],
-                "n_val": 1 - self.results["eps_risk"],
             }
         ]

         if self.hout_data is not None:
             output.append(
                 {
                     "metric": "priv_loss_eps",
-                    "dim": "p",
                     "val": self.results["priv_loss"],
-                    "n_val": 1 - abs(self.results["priv_loss"]),
                 }
             )
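
# Editor's note: example of the normalized output when holdout data is present
# (values are made up):
#
#   [
#       {"metric": "eps_identif_risk", "val": 0.6667},
#       {"metric": "priv_loss_eps", "val": 0.3333},
#   ]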