Skip to content

Commit 26cb1c0

Browse files
committed
Update data_processor.py
Update processing step to deal with high cardinality categorical columns. Now use frequency encoding if more than 50 categories.
1 parent 9e28ed6 commit 26cb1c0

File tree

1 file changed

+46
-6
lines changed

1 file changed

+46
-6
lines changed

synthpop/processor/data_processor.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,23 +55,34 @@ def _preprocess(self, data: pd.DataFrame) -> pd.DataFrame:
5555

5656
for col, dtype in self.metadata.items():
5757
if dtype == "categorical":
58-
# Use Label Encoding for small categories, OneHot for larger
59-
encoder = LabelEncoder() if len(data[col].unique()) < 10 else OneHotEncoder(sparse=False, drop="first")
58+
# Choose encoder based on cardinality
59+
n_unique = len(data[col].unique())
60+
if n_unique < 10:
61+
encoder = LabelEncoder()
62+
elif n_unique < 50:
63+
encoder = OneHotEncoder(sparse=False, drop="first")
64+
else:
65+
# Frequency encoding
66+
value_counts = data[col].value_counts(normalize=True)
67+
encoder = {'type': 'frequency', 'mapping': value_counts.to_dict()}
68+
6069
transformed_data = self._encode_categorical(data[col], encoder)
6170
self.encoders[col] = encoder
6271
data.drop(columns=[col], inplace=True)
6372
data = pd.concat([data, transformed_data], axis=1)
6473

6574
elif dtype == "numerical":
66-
scaler = StandardScaler(with_mean= False, with_std= False)
75+
scaler = StandardScaler(with_mean=False, with_std=False)
6776
data[col] = scaler.fit_transform(data[[col]])
6877
self.scalers[col] = scaler
6978

7079
elif dtype == "boolean":
7180
data[col] = data[col].astype(int) # Convert True/False to 1/0
7281

7382
elif dtype == "datetime":
74-
data[col] = data[col].apply(lambda x: x.timestamp() if pd.notnull(x) else np.nan) # Convert to Unix timestamp
83+
data[col] = data[col].apply(
84+
lambda x: x.timestamp() if pd.notnull(x) else np.nan
85+
) # Convert to Unix timestamp
7586

7687
elif dtype == "timedelta":
7788
data[col] = pd.to_timedelta(data[col]).dt.total_seconds()
@@ -125,11 +136,23 @@ def validate(self, data: pd.DataFrame):
125136
def _encode_categorical(self, series: pd.Series, encoder):
126137
"""Encode categorical columns."""
127138
if isinstance(encoder, LabelEncoder):
128-
return pd.DataFrame(encoder.fit_transform(series), columns=[series.name])
139+
return pd.DataFrame(
140+
encoder.fit_transform(series),
141+
columns=[series.name]
142+
)
129143
elif isinstance(encoder, OneHotEncoder):
130144
encoded_array = encoder.fit_transform(series.values.reshape(-1, 1))
131-
encoded_df = pd.DataFrame(encoded_array, columns=encoder.get_feature_names_out([series.name]))
145+
encoded_df = pd.DataFrame(
146+
encoded_array,
147+
columns=encoder.get_feature_names_out([series.name])
148+
)
132149
return encoded_df
150+
elif isinstance(encoder, dict) and encoder['type'] == 'frequency':
151+
# Frequency encoding
152+
encoded_values = series.map(encoder['mapping'])
153+
return pd.DataFrame(encoded_values, columns=[series.name])
154+
else:
155+
raise TypeError(f"Unsupported encoder type: {type(encoder)}")
133156

134157
def _decode_categorical(self, encoded: pd.Series or pd.DataFrame, encoder):
135158
"""
@@ -170,6 +193,23 @@ def _decode_categorical(self, encoded: pd.Series or pd.DataFrame, encoder):
170193
cats = encoder.categories_[0]
171194
return pd.Series(cats[idx], index=getattr(encoded, "index", None))
172195

196+
# FREQUENCY ENCODER CASE
197+
elif isinstance(encoder, dict) and encoder['type'] == 'frequency':
198+
# For frequency encoding, we need to map the encoded values back to categories
199+
# We'll use the inverse mapping (frequency -> category)
200+
inverse_mapping = {v: k for k, v in encoder['mapping'].items()}
201+
# Find the closest frequency for each encoded value
202+
encoded_values = encoded.values.flatten()
203+
decoded = []
204+
for val in encoded_values:
205+
if pd.isna(val):
206+
decoded.append(np.nan)
207+
else:
208+
# Find the category with the closest frequency
209+
closest_freq = min(inverse_mapping.keys(), key=lambda x: abs(x - val))
210+
decoded.append(inverse_mapping[closest_freq])
211+
return pd.Series(decoded, index=getattr(encoded, "index", None))
212+
173213
else:
174214
raise TypeError(f"Unsupported encoder type: {type(encoder)}")
175215

0 commit comments

Comments
 (0)