
Commit afed612

loading embeddings is fixed.
1 parent ec15ed9 · commit afed612

2 files changed: +85 -17 lines


barcodebert/datasets.py

Lines changed: 69 additions & 15 deletions
@@ -290,8 +290,42 @@ def __getitem__(self, idx):
         return processed_barcode, label, att_mask


-def representations_from_df(df, target_level, model, tokenizer, dataset_name, mode=None, mask_rate=None):
-
+def representations_from_df(
+    df, target_level, model, tokenizer, dataset_name, mode=None, mask_rate=None, representation_type="tokens"
+):
+    """
+    Extract representations from DNA sequences in a dataframe.
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        Dataframe containing DNA sequences
+    target_level : str
+        Taxonomic level to use as labels
+    model : torch.nn.Module
+        Pretrained model
+    tokenizer : Tokenizer
+        Tokenizer for DNA sequences
+    dataset_name : str
+        Dataset name (CANADA-1.5M or BIOSCAN-5M)
+    mode : str, optional
+        Mode (not currently used)
+    mask_rate : float, optional
+        Mask rate (not currently used)
+    representation_type : str, optional
+        Type of representation to extract:
+        - "tokens": Mean pooling of sequence tokens (default, backward compatible)
+        - "jumbo": Jumbo representation from jumbo CLS tokens (if model has jumbo tokens)
+
+    Returns
+    -------
+    latent : np.ndarray
+        Latent representations
+    y : np.ndarray
+        Labels
+    orders : np.ndarray
+        Order names
+    """
     orders = df["order_name"].to_numpy()
     if dataset_name == "CANADA-1.5M":
         _label_set, y = np.unique(df[target_level], return_inverse=True)
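The docstring's "jumbo" option presumes the model's output object exposes a jumbo_representation field of shape (batch_size, J*D). As a speculative illustration only (not the repository's model code), such a field could come from flattening the hidden states of J extra CLS-style tokens of width D:

import torch

def flatten_jumbo_tokens(cls_states: torch.Tensor) -> torch.Tensor:
    # cls_states: (batch, J, D) hidden states of the J jumbo CLS tokens
    batch, J, D = cls_states.shape
    return cls_states.reshape(batch, J * D)  # one (batch, J*D) vector per sequence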
@@ -309,25 +343,45 @@ def representations_from_df(df, target_level, model, tokenizer, dataset_name, mo

         x = x.unsqueeze(0).to(model.device)
         att_mask = att_mask.unsqueeze(0).to(model.device)
-        x = model(x, att_mask).hidden_states[-1]
-        # previous mean pooling
-        # x = x.mean(1)
-        # dna_embeddings.append(x.cpu().numpy())

-        # updated mean pooling to account for the attention mask and padding tokens
-        # sum the embeddings of the tokens (excluding padding tokens)
-        sum_embeddings = (x * att_mask.unsqueeze(-1)).sum(1)  # (batch_size, hidden_size)
-        # sum the attention mask (number of tokens in the sequence without considering the padding tokens)
-        sum_mask = att_mask.sum(1, keepdim=True)
-        # calculate the mean embeddings
-        mean_embeddings = sum_embeddings / sum_mask  # (batch_size, hidden_size)
+        # Get model output
+        output = model(x, att_mask)
+
+        # Extract representation based on type
+        if representation_type == "jumbo":
+            # Use jumbo representation if available
+            if hasattr(output, "jumbo_representation"):
+                embedding = output.jumbo_representation  # (batch_size, J*D)
+            else:
+                raise ValueError(
+                    "Model does not have jumbo_representation. "
+                    "Use representation_type='tokens' or use a Jumbo transformer model."
+                )
+        elif representation_type == "tokens":
+            # Use mean pooling of sequence tokens (default behavior)
+            if hasattr(output, "hidden_states"):
+                hidden_states = output.hidden_states[-1]  # last layer, as in the previous code
+            else:
+                # Fallback for models that return hidden states directly
+                hidden_states = output[-1] if isinstance(output, tuple) else output
+
+            # Mean pooling accounting for attention mask and padding tokens
+            # Sum the embeddings of the tokens (excluding padding tokens)
+            sum_embeddings = (hidden_states * att_mask.unsqueeze(-1)).sum(1)  # (batch_size, hidden_size)
+            # Sum the attention mask (number of tokens without padding)
+            sum_mask = att_mask.sum(1, keepdim=True)
+            # Calculate the mean embeddings
+            embedding = sum_embeddings / sum_mask  # (batch_size, hidden_size)
+        else:
+            raise ValueError(f"Invalid representation_type: {representation_type}. Must be 'tokens' or 'jumbo'.")

-        dna_embeddings.append(mean_embeddings.cpu().numpy())
+        dna_embeddings.append(embedding.cpu().numpy())

     print(f"There are {len(df)} points in the dataset")
+    print(f"Using representation type: {representation_type}")
     latent = np.array(dna_embeddings)
     latent = np.squeeze(latent, 1)
-    print(latent.shape)
+    print(f"Representation shape: {latent.shape}")
     return latent, y, orders


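For reference, the masked mean pooling used in the "tokens" branch can be written as a small standalone helper. This is a minimal sketch of the same computation, not code from the repository; the function name and shapes are illustrative:

import torch

def masked_mean_pool(hidden_states: torch.Tensor, att_mask: torch.Tensor) -> torch.Tensor:
    # hidden_states: (batch, seq_len, hidden) last-layer token embeddings
    # att_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    mask = att_mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(dim=1)             # (batch, hidden)
    counts = mask.sum(dim=1)                               # (batch, 1)
    return summed / counts.clamp(min=1)                    # guard against all-padding rows

Averaging only over unmasked positions keeps padded sequences from dragging the embedding toward the padding token's representation, which is what this pooling fixes relative to the commented-out plain x.mean(1).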
barcodebert/knn_probing.py

Lines changed: 16 additions & 2 deletions
@@ -145,7 +145,14 @@ def run(config):
     # Generate embeddings for the training and test sets
     print("Generating embeddings for test set", flush=True)
     X_unseen, y_unseen, orders = representations_from_df(
-        df_test, config.target_level, model, tokenizer, config.dataset_name, config.mode, config.mask_rate, config.jumbo
+        df_test,
+        config.target_level,
+        model,
+        tokenizer,
+        config.dataset_name,
+        config.mode,
+        config.mask_rate,
+        config.representation_type,
     )
     print("Generating embeddings for train set", flush=True)
     X, y, train_orders = representations_from_df(
@@ -156,7 +163,7 @@ def run(config):
         config.dataset_name,
         config.mode,
         config.mask_rate,
-        config.jumbo,
+        config.representation_type,
     )
     timing_stats["embed"] = time.time() - t_start_embed

@@ -331,6 +338,13 @@ def get_parser():
         help="Mask rate for masked language model. Default: %(default)s",
     )

+    group.add_argument(
+        "--representation_type",
+        default="tokens",
+        type=str,
+        help="Type of representation to extract: 'tokens' or 'jumbo'. Default: %(default)s",
+    )
+
     return parser


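Taken together, a minimal sketch of how the new argument reaches the embedding code, assuming a dataframe df_test, a pretrained model, and a tokenizer are already loaded (the import path follows the repository layout; the target level and dataset name below are illustrative placeholders):

from barcodebert.datasets import representations_from_df

# Hypothetical call; representation_type="jumbo" requires a model whose output
# exposes jumbo_representation, otherwise the function raises ValueError.
latent, y, orders = representations_from_df(
    df_test,
    "species_name",   # target_level (illustrative)
    model,
    tokenizer,
    "CANADA-1.5M",    # dataset_name
    representation_type="jumbo",  # or "tokens" for masked mean pooling
)

Because representation_type defaults to "tokens", existing callers that omit the argument keep the previous mean-pooled behavior.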