@@ -55,23 +55,34 @@ def _preprocess(self, data: pd.DataFrame) -> pd.DataFrame:
5555
5656 for col , dtype in self .metadata .items ():
5757 if dtype == "categorical" :
58- # Use Label Encoding for small categories, OneHot for larger
59- encoder = LabelEncoder () if len (data [col ].unique ()) < 10 else OneHotEncoder (sparse = False , drop = "first" )
58+ # Choose encoder based on cardinality
59+ n_unique = len (data [col ].unique ())
60+ if n_unique < 10 :
61+ encoder = LabelEncoder ()
62+ elif n_unique < 50 :
63+ encoder = OneHotEncoder (sparse = False , drop = "first" )
64+ else :
65+ # Frequency encoding
66+ value_counts = data [col ].value_counts (normalize = True )
67+ encoder = {'type' : 'frequency' , 'mapping' : value_counts .to_dict ()}
68+
6069 transformed_data = self ._encode_categorical (data [col ], encoder )
6170 self .encoders [col ] = encoder
6271 data .drop (columns = [col ], inplace = True )
6372 data = pd .concat ([data , transformed_data ], axis = 1 )
6473
6574 elif dtype == "numerical" :
66- scaler = StandardScaler (with_mean = False , with_std = False )
75+ scaler = StandardScaler (with_mean = False , with_std = False )
6776 data [col ] = scaler .fit_transform (data [[col ]])
6877 self .scalers [col ] = scaler
6978
7079 elif dtype == "boolean" :
7180 data [col ] = data [col ].astype (int ) # Convert True/False to 1/0
7281
7382 elif dtype == "datetime" :
74- data [col ] = data [col ].apply (lambda x : x .timestamp () if pd .notnull (x ) else np .nan ) # Convert to Unix timestamp
83+ data [col ] = data [col ].apply (
84+ lambda x : x .timestamp () if pd .notnull (x ) else np .nan
85+ ) # Convert to Unix timestamp
7586
7687 elif dtype == "timedelta" :
7788 data [col ] = pd .to_timedelta (data [col ]).dt .total_seconds ()
@@ -125,11 +136,23 @@ def validate(self, data: pd.DataFrame):
125136 def _encode_categorical (self , series : pd .Series , encoder ):
126137 """Encode categorical columns."""
127138 if isinstance (encoder , LabelEncoder ):
128- return pd .DataFrame (encoder .fit_transform (series ), columns = [series .name ])
139+ return pd .DataFrame (
140+ encoder .fit_transform (series ),
141+ columns = [series .name ]
142+ )
129143 elif isinstance (encoder , OneHotEncoder ):
130144 encoded_array = encoder .fit_transform (series .values .reshape (- 1 , 1 ))
131- encoded_df = pd .DataFrame (encoded_array , columns = encoder .get_feature_names_out ([series .name ]))
145+ encoded_df = pd .DataFrame (
146+ encoded_array ,
147+ columns = encoder .get_feature_names_out ([series .name ])
148+ )
132149 return encoded_df
150+ elif isinstance (encoder , dict ) and encoder ['type' ] == 'frequency' :
151+ # Frequency encoding
152+ encoded_values = series .map (encoder ['mapping' ])
153+ return pd .DataFrame (encoded_values , columns = [series .name ])
154+ else :
155+ raise TypeError (f"Unsupported encoder type: { type (encoder )} " )
133156
134157 def _decode_categorical (self , encoded : pd .Series or pd .DataFrame , encoder ):
135158 """
@@ -170,6 +193,23 @@ def _decode_categorical(self, encoded: pd.Series or pd.DataFrame, encoder):
170193 cats = encoder .categories_ [0 ]
171194 return pd .Series (cats [idx ], index = getattr (encoded , "index" , None ))
172195
196+ # FREQUENCY ENCODER CASE
197+ elif isinstance (encoder , dict ) and encoder ['type' ] == 'frequency' :
198+ # For frequency encoding, we need to map the encoded values back to categories
199+ # We'll use the inverse mapping (frequency -> category)
200+ inverse_mapping = {v : k for k , v in encoder ['mapping' ].items ()}
201+ # Find the closest frequency for each encoded value
202+ encoded_values = encoded .values .flatten ()
203+ decoded = []
204+ for val in encoded_values :
205+ if pd .isna (val ):
206+ decoded .append (np .nan )
207+ else :
208+ # Find the category with the closest frequency
209+ closest_freq = min (inverse_mapping .keys (), key = lambda x : abs (x - val ))
210+ decoded .append (inverse_mapping [closest_freq ])
211+ return pd .Series (decoded , index = getattr (encoded , "index" , None ))
212+
173213 else :
174214 raise TypeError (f"Unsupported encoder type: { type (encoder )} " )
175215
0 commit comments