Skip to content

Commit 62343a0

Browse files
committed
bug fixed.
1 parent 9811b4d commit 62343a0

File tree

5 files changed

+49
-10
lines changed

5 files changed

+49
-10
lines changed

barcodebert/datasets.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,18 @@ def __init__(
181181
if self.return_taxonomy_level:
182182
taxonomy_column = f"{self.return_taxonomy_level}_name"
183183
if taxonomy_column in df.columns:
184-
self.taxonomy_labels, self.taxonomy_label_set = pd.factorize(df[taxonomy_column], sort=True)
184+
# Replace empty strings and NaN with 'UNKNOWN'
185+
taxonomy_col = df[taxonomy_column].replace("", "UNKNOWN").fillna("UNKNOWN")
186+
self.taxonomy_labels, self.taxonomy_label_set = pd.factorize(taxonomy_col, sort=True)
187+
# Map 'UNKNOWN' samples to -1 so they're excluded from pair creation
188+
unknown_mask = taxonomy_col == "UNKNOWN"
189+
num_unknown = unknown_mask.sum()
190+
self.taxonomy_labels = [
191+
-1 if is_unknown else label for label, is_unknown in zip(self.taxonomy_labels, unknown_mask)
192+
]
193+
print(f"Taxonomy labels: {len(self.taxonomy_labels)} total, {num_unknown} marked as UNKNOWN (-1)")
194+
print(f"Unique taxonomy categories: {self.taxonomy_label_set}")
195+
185196
else:
186197
print(f"Warning: Column '{taxonomy_column}' not found. Using dummy labels.")
187198
self.taxonomy_labels = [0] * len(self.labels)
@@ -195,7 +206,17 @@ def __init__(
195206
if self.return_taxonomy_level:
196207
taxonomy_column = f"{self.return_taxonomy_level}_name"
197208
if taxonomy_column in df.columns:
198-
self.taxonomy_labels, self.taxonomy_label_set = pd.factorize(df[taxonomy_column], sort=True)
209+
# Replace empty strings and NaN with 'UNKNOWN'
210+
taxonomy_col = df[taxonomy_column].replace("", "UNKNOWN").fillna("UNKNOWN")
211+
self.taxonomy_labels, self.taxonomy_label_set = pd.factorize(taxonomy_col, sort=True)
212+
# Map 'UNKNOWN' samples to -1 so they're excluded from pair creation
213+
unknown_mask = taxonomy_col == "UNKNOWN"
214+
num_unknown = unknown_mask.sum()
215+
self.taxonomy_labels = [
216+
-1 if is_unknown else label for label, is_unknown in zip(self.taxonomy_labels, unknown_mask)
217+
]
218+
print(f"Taxonomy labels: {len(self.taxonomy_labels)} total, {num_unknown} marked as UNKNOWN (-1)")
219+
print(f"Unique taxonomy categories: {self.taxonomy_label_set}")
199220
else:
200221
print(f"Warning: Column '{taxonomy_column}' not found. Using dummy labels.")
201222
self.taxonomy_labels = [0] * len(self.labels)

barcodebert/jumbo_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __init__(self, embed_dim: int, jumbo_multiplier: int = 6, dropout: float = 0
2727

2828
self.jumbo_mlp = nn.Sequential(
2929
nn.LayerNorm(self.jumbo_width),
30-
nn.Linear(self.jumbo_width, self.jumbo_width * 2), # Wide hidden layer X4
30+
nn.Linear(self.jumbo_width, self.jumbo_width * 2), # Wide hidden layer X2
3131
nn.GELU(),
3232
nn.Dropout(dropout),
3333
nn.Linear(self.jumbo_width * 2, self.jumbo_width),

barcodebert/jumbo_transformer_with_taxonomy.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,8 @@ def __init__(
3737
if self.enable_taxonomy_classification:
3838
jumbo_dim = bert_config.hidden_size * jumbo_multiplier
3939
self.taxonomy_classifier = JumboTaxonomyClassifier(jumbo_dim=jumbo_dim)
40-
# Alias for backward compatibility
41-
self.genus_classifier = self.taxonomy_classifier
4240
else:
4341
self.taxonomy_classifier = None
44-
self.genus_classifier = None
4542

4643
def forward(
4744
self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, inputs_embeds=None, **kwargs
@@ -88,6 +85,11 @@ def bert(self):
8885
"""Return the BERT model from the underlying transformer for compatibility."""
8986
return self.transformer.bert
9087

88+
@property
89+
def genus_classifier(self):
90+
"""Backward compatibility alias for taxonomy_classifier."""
91+
return self.taxonomy_classifier
92+
9193

9294
def create_jumbo_transformer_with_taxonomy(
9395
bert_config, jumbo_multiplier=6, share_jumbo_mlp_across_layers=False, enable_taxonomy_classification=True

barcodebert/maelm_model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,8 @@ def __init__(
6464
jumbo_dim = decoder_config.hidden_size * jumbo_multiplier
6565

6666
self.taxonomy_classifier = JumboTaxonomyClassifier(jumbo_dim=jumbo_dim)
67-
# Keep genus_classifier as alias for backward compatibility
68-
self.genus_classifier = self.taxonomy_classifier
6967
else:
7068
self.taxonomy_classifier = None
71-
self.genus_classifier = None
7269

7370
def forward(self, input_ids, attention_mask, mask_positions, model_type="maelm_v2"):
7471
if model_type == "maelm_v2":
@@ -243,3 +240,8 @@ def forward_baseline(self, input_ids, attention_mask):
243240
)
244241

245242
return outputs
243+
244+
@property
245+
def genus_classifier(self):
246+
"""Backward compatibility alias for taxonomy_classifier."""
247+
return self.taxonomy_classifier

barcodebert/pretraining.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,14 @@ def print_pass(*args, **kwargs):
675675
print(f" Throughput .........{train_stats['throughput']:11.2f} samples/sec")
676676
print(f" Loss ...............{train_stats['loss']:14.5f}")
677677
print(f" Accuracy ...........{train_stats['accuracy']:11.2f} %")
678+
679+
# Print taxonomy classification metrics if available
680+
taxonomy_level_display = taxonomy_level.capitalize()
681+
if f"{taxonomy_level}_loss" in train_stats:
682+
print(f" {taxonomy_level_display} Loss .........{train_stats[f'{taxonomy_level}_loss']:14.5f}")
683+
print(f" {taxonomy_level_display} Accuracy .....{train_stats[f'{taxonomy_level}_accuracy'] * 100:11.2f} %")
684+
print(f" {taxonomy_level_display} Pairs ........{train_stats[f'{taxonomy_level}_pairs']:8d}")
685+
678686
print(flush=True)
679687

680688
# Validate ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -716,6 +724,12 @@ def print_pass(*args, **kwargs):
716724
print(f" Loss ...............{eval_stats['loss']:14.5f}")
717725
print(f" Accuracy ...........{eval_stats['accuracy']:11.2f} %")
718726

727+
# Print taxonomy classification metrics if available
728+
if f"{taxonomy_level}_loss" in eval_stats:
729+
print(f" {taxonomy_level_display} Loss .........{eval_stats[f'{taxonomy_level}_loss']:14.5f}")
730+
print(f" {taxonomy_level_display} Accuracy .....{eval_stats[f'{taxonomy_level}_accuracy'] * 100:11.2f} %")
731+
print(f" {taxonomy_level_display} Pairs ........{eval_stats[f'{taxonomy_level}_pairs']:8d}")
732+
719733
# Save model ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
720734
t_start_save = time.time()
721735

@@ -2092,7 +2106,7 @@ def get_parser():
20922106
"--jumbo_source",
20932107
dest="jumbo_source",
20942108
type=str,
2095-
default="encoder",
2109+
default="decoder",
20962110
choices=["encoder", "decoder"],
20972111
help=(
20982112
"Source of jumbo tokens for taxonomy classification: 'encoder' (direct from encoder) or "

0 commit comments

Comments
 (0)