Skip to content

Commit 70e6d01

Browse files
committed
Update data_processor.py
Modify post_process function to deal with out of bounds generated values.
1 parent b8a278a commit 70e6d01

File tree

1 file changed

+47
-5
lines changed

1 file changed

+47
-5
lines changed

synthpop/processor/data_processor.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __init__(self, metadata, enforce_rounding=True, enforce_min_max_values=True,
3232
self.encoders = {} # Stores encoders for categorical columns
3333
self.scalers = {} # Stores scalers for numerical columns
3434
self.original_columns = None # To restore column order
35+
self._original_dtypes = None # Store original dtypes
3536

3637
def preprocess(self, data: pd.DataFrame) -> pd.DataFrame:
3738
"""Transform the raw data into numerical space."""
@@ -43,6 +44,7 @@ def preprocess(self, data: pd.DataFrame) -> pd.DataFrame:
4344

4445
self.validate(data)
4546
self.original_columns = data.columns # Store original column order
47+
self._original_dtypes = data.dtypes # Store original dtypes
4648
processed_data = self._preprocess(data)
4749

4850
return processed_data
@@ -88,6 +90,12 @@ def postprocess(self, synthetic_data: pd.DataFrame) -> pd.DataFrame:
8890
elif dtype == "numerical" and col in self.scalers:
8991
scaler = self.scalers[col]
9092
synthetic_data[col] = scaler.inverse_transform(synthetic_data[[col]])
93+
94+
# Restore original dtype for numerical columns
95+
if self._original_dtypes is not None:
96+
original_dtype = self._original_dtypes[col]
97+
if np.issubdtype(original_dtype, np.integer):
98+
synthetic_data[col] = synthetic_data[col].round().astype(original_dtype)
9199

92100
elif dtype == "boolean":
93101
synthetic_data[col] = synthetic_data[col].round().astype(bool)
@@ -123,13 +131,47 @@ def _encode_categorical(self, series: pd.Series, encoder):
123131
encoded_df = pd.DataFrame(encoded_array, columns=encoder.get_feature_names_out([series.name]))
124132
return encoded_df
125133

126-
def _decode_categorical(self, series: pd.Series, encoder):
127-
"""Decode categorical columns."""
134+
def _decode_categorical(self, encoded: pd.Series or pd.DataFrame, encoder):
135+
"""
136+
Decode categorical columns, snapping any out‐of‐range codes back to the nearest
137+
valid category (or to NaN), so novel copula values won't blow up.
138+
"""
139+
# LABEL ENCODER CASE
128140
if isinstance(encoder, LabelEncoder):
129-
return encoder.inverse_transform(series.astype(int))
141+
# Pull out the raw numeric codes (may be floats from copula)
142+
codes = np.rint(encoded.astype(float)).astype(int)
143+
max_idx = len(encoder.classes_) - 1
144+
145+
# Any code outside [0, max_idx] → -1 sentinel
146+
safe_codes = np.where((codes >= 0) & (codes <= max_idx), codes, -1)
147+
148+
# Map valid codes back to labels, sentinel→NaN
149+
decoded = [
150+
encoder.classes_[c] if c >= 0 else np.nan
151+
for c in safe_codes
152+
]
153+
return pd.Series(decoded, index=getattr(encoded, "index", None))
154+
155+
# ONE-HOT ENCODER CASE
130156
elif isinstance(encoder, OneHotEncoder):
131-
category_index = np.argmax(series.values, axis=1)
132-
return encoder.categories_[0][category_index]
157+
# Ensure a 2D array of one-hot "scores"
158+
arr = encoded.values if isinstance(encoded, pd.DataFrame) else np.asarray(encoded)
159+
if arr.ndim == 1:
160+
# If someone passed a flat Series, assume the first category axis:
161+
n_cat = len(encoder.categories_[0])
162+
arr = arr.reshape(-1, n_cat)
163+
164+
# Argmax and clip into [0, n_cat-1]
165+
idx = np.argmax(arr, axis=1)
166+
max_idx = len(encoder.categories_[0]) - 1
167+
idx = np.clip(idx, 0, max_idx)
168+
169+
# Look up the category labels
170+
cats = encoder.categories_[0]
171+
return pd.Series(cats[idx], index=getattr(encoded, "index", None))
172+
173+
else:
174+
raise TypeError(f"Unsupported encoder type: {type(encoder)}")
133175

134176
def _handle_missing_values(self, series: pd.Series):
135177
"""Handle missing values based on column type."""

0 commit comments

Comments
 (0)