diff --git a/transformer_ranker/embedder.py b/transformer_ranker/embedder.py
index 4290f87..d6eca31 100644
--- a/transformer_ranker/embedder.py
+++ b/transformer_ranker/embedder.py
@@ -21,7 +21,7 @@ def __init__(
         device: Optional[str] = None,
     ):
         """
-        Generates word or text embeddings using a pre-trained model.
+        Generates word or text embeddings using a pre-trained model, with sub-word and sequence (sentence) pooling.
 
         :param model: Model name or instance.
@@ -89,9 +89,13 @@ def _embed_batch(self, sentences, unpack_to_cpu: bool = True) -> list[torch.Tens
         tokenized = self._tokenize(sentences)
         word_ids = tokenized.pop("word_ids")
 
-        # Forward pass to get all hidden states
+        # Run a forward pass and collect all hidden states
         with torch.no_grad():
-            hidden_states = self.model(**tokenized, output_hidden_states=True).hidden_states
+            outputs = self.model(**tokenized, output_hidden_states=True)
+            if getattr(outputs, "hidden_states", None) is not None:
+                hidden_states = outputs.hidden_states
+            else:
+                raise ValueError(f"Failed to get hidden states for model: {self.name}")
 
         # Exclude the embedding layer (index 0)
         embeddings = torch.stack(hidden_states[1:], dim=0)
@@ -126,7 +130,11 @@ def _setup_model(self, model: Union[str, torch.nn.Module], local_files_only: boo
             if isinstance(model, torch.nn.Module)
             else AutoModel.from_pretrained(model, local_files_only=local_files_only)
         )
-        self.name = self.model.config.name_or_path
+
+        if getattr(self.model.config, "is_encoder_decoder", False):
+            self.model = self.model.encoder  # keep only the encoder, drop the decoder
+
+        self.name = getattr(self.model.config, "name_or_path", "Unknown Model")
 
     def _setup_tokenizer(self, tokenizer: Optional[Union[str, PreTrainedTokenizerFast]]) -> None:
         """Initialize tokenizer using AutoTokenizer, support PreTokenizerFast."""
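
For reviewers, here is a minimal standalone sketch of what the two patched paths do, mirroring the new `_setup_model` and `_embed_batch` logic outside the class. It is an illustration only: the checkpoint name `t5-small` and all local variable names are assumptions for the example, not part of this patch.

```python
# Standalone sketch of the patched behaviour; "t5-small" and the variable
# names below are illustrative assumptions, not part of the diff.
import torch
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# Encoder-decoder checkpoints keep only their encoder, as in _setup_model
if getattr(model.config, "is_encoder_decoder", False):
    model = model.encoder  # drop the decoder

tokenized = tokenizer(["A short example sentence."], return_tensors="pt")

# Guarded hidden-state extraction, as in _embed_batch
with torch.no_grad():
    outputs = model(**tokenized, output_hidden_states=True)

if getattr(outputs, "hidden_states", None) is not None:
    hidden_states = outputs.hidden_states
else:
    raise ValueError("Failed to get hidden states for model: t5-small")

# Exclude the embedding layer (index 0) and stack the remaining layers
embeddings = torch.stack(hidden_states[1:], dim=0)
print(embeddings.shape)  # (num_layers, batch_size, seq_len, hidden_dim)
```

Without the `is_encoder_decoder` branch, calling `AutoModel` on a seq2seq checkpoint such as T5 would run the full encoder-decoder stack; keeping only the encoder is what the layer-wise embedding extraction here needs.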