@@ -290,8 +290,42 @@ def __getitem__(self, idx):
         return processed_barcode, label, att_mask


-def representations_from_df(df, target_level, model, tokenizer, dataset_name, mode=None, mask_rate=None):
-
+def representations_from_df(
+    df, target_level, model, tokenizer, dataset_name, mode=None, mask_rate=None, representation_type="tokens"
+):
+    """
+    Extract representations from DNA sequences in a dataframe.
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        Dataframe containing DNA sequences
+    target_level : str
+        Taxonomic level to use as labels
+    model : torch.nn.Module
+        Pretrained model
+    tokenizer : Tokenizer
+        Tokenizer for DNA sequences
+    dataset_name : str
+        Dataset name (CANADA-1.5M or BIOSCAN-5M)
+    mode : str, optional
+        Mode (not currently used)
+    mask_rate : float, optional
+        Mask rate (not currently used)
+    representation_type : str, optional
+        Type of representation to extract:
+        - "tokens": Mean pooling of sequence tokens (default, backward compatible)
+        - "jumbo": Jumbo representation from jumbo CLS tokens (if model has jumbo tokens)
+
+    Returns
+    -------
+    latent : np.ndarray
+        Latent representations
+    y : np.ndarray
+        Labels
+    orders : np.ndarray
+        Order names
+    """
     orders = df["order_name"].to_numpy()
     if dataset_name == "CANADA-1.5M":
         _label_set, y = np.unique(df[target_level], return_inverse=True)
@@ -309,25 +343,45 @@ def representations_from_df(df, target_level, model, tokenizer, dataset_name, mo

         x = x.unsqueeze(0).to(model.device)
         att_mask = att_mask.unsqueeze(0).to(model.device)
-        x = model(x, att_mask).hidden_states[-1]
-        # previous mean pooling
-        # x = x.mean(1)
-        # dna_embeddings.append(x.cpu().numpy())

-        # updated mean pooling to account for the attention mask and padding tokens
-        # sum the embeddings of the tokens (excluding padding tokens)
-        sum_embeddings = (x * att_mask.unsqueeze(-1)).sum(1)  # (batch_size, hidden_size)
-        # sum the attention mask (number of tokens in the sequence without considering the padding tokens)
-        sum_mask = att_mask.sum(1, keepdim=True)
-        # calculate the mean embeddings
-        mean_embeddings = sum_embeddings / sum_mask  # (batch_size, hidden_size)
+        # Get model output
+        output = model(x, att_mask)
+
+        # Extract representation based on type
+        if representation_type == "jumbo":
+            # Use jumbo representation if available
+            if hasattr(output, "jumbo_representation"):
+                embedding = output.jumbo_representation  # (batch_size, J*D)
+            else:
+                raise ValueError(
+                    "Model does not have jumbo_representation. "
+                    "Use representation_type='tokens' or use a Jumbo transformer model."
+                )
+        elif representation_type == "tokens":
+            # Use mean pooling of sequence tokens (default behavior)
+            if hasattr(output, "hidden_states"):
+                hidden_states = output.hidden_states[-1]  # last hidden layer, as before
+            else:
+                # Fallback for models that return hidden states directly
+                hidden_states = output[-1] if isinstance(output, tuple) else output
+
+            # Mean pooling accounting for attention mask and padding tokens
+            # Sum the embeddings of the tokens (excluding padding tokens)
+            sum_embeddings = (hidden_states * att_mask.unsqueeze(-1)).sum(1)  # (batch_size, hidden_size)
+            # Sum the attention mask (number of tokens without padding)
+            sum_mask = att_mask.sum(1, keepdim=True)
+            # Calculate the mean embeddings
+            embedding = sum_embeddings / sum_mask  # (batch_size, hidden_size)
+        else:
+            raise ValueError(f"Invalid representation_type: {representation_type}. Must be 'tokens' or 'jumbo'.")

-        dna_embeddings.append(mean_embeddings.cpu().numpy())
+        dna_embeddings.append(embedding.cpu().numpy())

     print(f"There are {len(df)} points in the dataset")
+    print(f"Using representation type: {representation_type}")
     latent = np.array(dna_embeddings)
     latent = np.squeeze(latent, 1)
-    print(latent.shape)
+    print(f"Representation shape: {latent.shape}")
     return latent, y, orders


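For reference, here is a small self-contained sketch of the masked mean pooling used in the "tokens" branch above, run on toy tensors; the shapes and values are illustrative only and are not part of this commit.

import torch

# Toy batch: one sequence of 8 tokens, hidden size 4, last 3 positions are padding.
hidden_states = torch.randn(1, 8, 4)  # (batch_size, seq_len, hidden_size)
att_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]], dtype=torch.float)  # (batch_size, seq_len)

# Masked mean pooling as in the "tokens" branch: zero out padded positions,
# then divide by the number of real tokens instead of the full sequence length.
sum_embeddings = (hidden_states * att_mask.unsqueeze(-1)).sum(1)  # (batch_size, hidden_size)
sum_mask = att_mask.sum(1, keepdim=True)  # (batch_size, 1)
mean_embeddings = sum_embeddings / sum_mask  # (batch_size, hidden_size)

# Reference: averaging only over the 5 non-padded tokens gives the same result.
reference = hidden_states[0, :5].mean(0, keepdim=True)
assert torch.allclose(mean_embeddings, reference, atol=1e-6)

# A plain mean over all positions (the old behaviour) is diluted by the padding vectors.
print((mean_embeddings - hidden_states.mean(1)).abs().max())

Dividing by sum_mask rather than by the full sequence length keeps padding positions from diluting the embedding, which is why this replaced the earlier plain x.mean(1). Calling representations_from_df with representation_type="jumbo" bypasses this pooling entirely and returns the model's jumbo CLS representation of shape (batch_size, J*D) instead.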