@@ -21,7 +21,7 @@ def __init__(
2121 device : Optional [str ] = None ,
2222 ):
2323 """
24- Generates word or text embeddings using a pre-trained model.
24+ Generates word or text embeddings using a pre-trained model.
2525 Does sub-word and sequence (sentence) pooling.
2626
2727 :param model: Model name or instance.
@@ -89,9 +89,13 @@ def _embed_batch(self, sentences, unpack_to_cpu: bool = True) -> list[torch.Tens
8989 tokenized = self ._tokenize (sentences )
9090 word_ids = tokenized .pop ("word_ids" )
9191
92- # Forward pass to get all hidden states
92+ # Do forward pass and get all hidden states
9393 with torch .no_grad ():
94- hidden_states = self .model (** tokenized , output_hidden_states = True ).hidden_states
94+ outputs = self .model (** tokenized , output_hidden_states = True )
95+ if hasattr (outputs , "hidden_states" ):
96+ hidden_states = outputs .hidden_states
97+ else :
98+ raise ValueError (f"Failed to get hidden states for model: { self .name } " )
9599
96100 # Exclude the embedding layer (index 0)
97101 embeddings = torch .stack (hidden_states [1 :], dim = 0 )
@@ -126,7 +130,11 @@ def _setup_model(self, model: Union[str, torch.nn.Module], local_files_only: boo
126130 if isinstance (model , torch .nn .Module )
127131 else AutoModel .from_pretrained (model , local_files_only = local_files_only )
128132 )
129- self .name = self .model .config .name_or_path
133+
134+ if hasattr (self .model .config , 'is_encoder_decoder' ) and self .model .config .is_encoder_decoder :
135+ self .model = self .model .encoder # remove decoder
136+
137+ self .name = getattr (self .model .config , "name_or_path" , "Unknown Model" )
130138
131139 def _setup_tokenizer (self , tokenizer : Optional [Union [str , PreTrainedTokenizerFast ]]) -> None :
132140 """Initialize tokenizer using AutoTokenizer, support PreTokenizerFast."""
0 commit comments