@@ -181,7 +181,18 @@ def __init__(
181181 if self .return_taxonomy_level :
182182 taxonomy_column = f"{ self .return_taxonomy_level } _name"
183183 if taxonomy_column in df .columns :
184- self .taxonomy_labels , self .taxonomy_label_set = pd .factorize (df [taxonomy_column ], sort = True )
184+ # Replace empty strings and NaN with 'UNKNOWN'
185+ taxonomy_col = df [taxonomy_column ].replace ("" , "UNKNOWN" ).fillna ("UNKNOWN" )
186+ self .taxonomy_labels , self .taxonomy_label_set = pd .factorize (taxonomy_col , sort = True )
187+ # Map 'UNKNOWN' samples to -1 so they're excluded from pair creation
188+ unknown_mask = taxonomy_col == "UNKNOWN"
189+ num_unknown = unknown_mask .sum ()
190+ self .taxonomy_labels = [
191+ - 1 if is_unknown else label for label , is_unknown in zip (self .taxonomy_labels , unknown_mask )
192+ ]
193+ print (f"Taxonomy labels: { len (self .taxonomy_labels )} total, { num_unknown } marked as UNKNOWN (-1)" )
194+ print (f"Unique taxonomy categories: { self .taxonomy_label_set } " )
195+
185196 else :
186197 print (f"Warning: Column '{ taxonomy_column } ' not found. Using dummy labels." )
187198 self .taxonomy_labels = [0 ] * len (self .labels )
@@ -195,7 +206,17 @@ def __init__(
195206 if self .return_taxonomy_level :
196207 taxonomy_column = f"{ self .return_taxonomy_level } _name"
197208 if taxonomy_column in df .columns :
198- self .taxonomy_labels , self .taxonomy_label_set = pd .factorize (df [taxonomy_column ], sort = True )
209+ # Replace empty strings and NaN with 'UNKNOWN'
210+ taxonomy_col = df [taxonomy_column ].replace ("" , "UNKNOWN" ).fillna ("UNKNOWN" )
211+ self .taxonomy_labels , self .taxonomy_label_set = pd .factorize (taxonomy_col , sort = True )
212+ # Map 'UNKNOWN' samples to -1 so they're excluded from pair creation
213+ unknown_mask = taxonomy_col == "UNKNOWN"
214+ num_unknown = unknown_mask .sum ()
215+ self .taxonomy_labels = [
216+ - 1 if is_unknown else label for label , is_unknown in zip (self .taxonomy_labels , unknown_mask )
217+ ]
218+ print (f"Taxonomy labels: { len (self .taxonomy_labels )} total, { num_unknown } marked as UNKNOWN (-1)" )
219+ print (f"Unique taxonomy categories: { self .taxonomy_label_set } " )
199220 else :
200221 print (f"Warning: Column '{ taxonomy_column } ' not found. Using dummy labels." )
201222 self .taxonomy_labels = [0 ] * len (self .labels )
0 commit comments