@@ -32,6 +32,7 @@ def __init__(self, metadata, enforce_rounding=True, enforce_min_max_values=True,
3232 self .encoders = {} # Stores encoders for categorical columns
3333 self .scalers = {} # Stores scalers for numerical columns
3434 self .original_columns = None # To restore column order
35+ self ._original_dtypes = None # Store original dtypes
3536
3637 def preprocess (self , data : pd .DataFrame ) -> pd .DataFrame :
3738 """Transform the raw data into numerical space."""
@@ -43,6 +44,7 @@ def preprocess(self, data: pd.DataFrame) -> pd.DataFrame:
4344
4445 self .validate (data )
4546 self .original_columns = data .columns # Store original column order
47+ self ._original_dtypes = data .dtypes # Store original dtypes
4648 processed_data = self ._preprocess (data )
4749
4850 return processed_data
@@ -88,6 +90,12 @@ def postprocess(self, synthetic_data: pd.DataFrame) -> pd.DataFrame:
8890 elif dtype == "numerical" and col in self .scalers :
8991 scaler = self .scalers [col ]
9092 synthetic_data [col ] = scaler .inverse_transform (synthetic_data [[col ]])
93+
94+ # Restore original dtype for numerical columns
95+ if self ._original_dtypes is not None :
96+ original_dtype = self ._original_dtypes [col ]
97+ if np .issubdtype (original_dtype , np .integer ):
98+ synthetic_data [col ] = synthetic_data [col ].round ().astype (original_dtype )
9199
92100 elif dtype == "boolean" :
93101 synthetic_data [col ] = synthetic_data [col ].round ().astype (bool )
@@ -123,13 +131,47 @@ def _encode_categorical(self, series: pd.Series, encoder):
123131 encoded_df = pd .DataFrame (encoded_array , columns = encoder .get_feature_names_out ([series .name ]))
124132 return encoded_df
125133
def _decode_categorical(self, encoded: "pd.Series | pd.DataFrame", encoder):
    """
    Decode an encoded categorical column back to its category labels.

    Synthetic values are not guaranteed to be clean integer codes, so this
    method is deliberately tolerant instead of delegating to
    ``encoder.inverse_transform`` (which raises on unseen codes).

    Args:
        encoded: The encoded values. With a ``LabelEncoder`` this is a column
            of numeric codes (floats are rounded to the nearest integer).
            With a ``OneHotEncoder`` this is a 2-D block of per-category
            scores, or a flat sequence that is reshaped to one row per
            ``len(encoder.categories_[0])`` values.
        encoder: The fitted ``LabelEncoder`` or ``OneHotEncoder``.

    Returns:
        pd.Series: The decoded labels. For a ``LabelEncoder``, any code that
        is NaN, infinite, or outside ``[0, len(classes_) - 1]`` decodes to
        NaN. The index of ``encoded`` is reused when it has one entry per
        decoded row; otherwise a default index is used.

    Raises:
        TypeError: If ``encoder`` is neither a ``LabelEncoder`` nor a
            ``OneHotEncoder``.
    """
    if isinstance(encoder, LabelEncoder):
        classes = encoder.classes_

        # Flatten so a single-column DataFrame behaves like a Series.
        raw_codes = np.asarray(encoded, dtype=float).reshape(-1)

        # NaN/inf must be filtered *before* the integer cast, otherwise the
        # cast itself fails and the sentinel logic below is never reached.
        with np.errstate(invalid="ignore"):
            rounded = np.rint(raw_codes)
            is_valid = (
                np.isfinite(rounded)
                & (rounded >= 0)
                & (rounded <= len(classes) - 1)
            )

        # Invalid positions get the -1 sentinel, which decodes to NaN.
        codes = np.where(is_valid, rounded, -1).astype(int)
        decoded = [classes[code] if code >= 0 else np.nan for code in codes]

    elif isinstance(encoder, OneHotEncoder):
        categories = encoder.categories_[0]

        scores = np.asarray(encoded)
        if scores.ndim == 1:
            # A flat input is interpreted as consecutive rows of scores.
            scores = scores.reshape(-1, len(categories))

        # The highest score wins; clip guards against a block that is wider
        # than the encoder's category list.
        positions = np.clip(np.argmax(scores, axis=1), 0, len(categories) - 1)
        decoded = categories[positions]

    else:
        raise TypeError(f"Unsupported encoder type: {type(encoder)}")

    # Reuse the caller's index only when it lines up with the decoded rows;
    # a reshaped flat input has fewer rows than its original index.
    index = getattr(encoded, "index", None)
    if index is not None and len(index) != len(decoded):
        index = None
    return pd.Series(decoded, index=index)
133175
134176 def _handle_missing_values (self , series : pd .Series ):
135177 """Handle missing values based on column type."""
0 commit comments