Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions transformer_ranker/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(
device: Optional[str] = None,
):
"""
Generates word or text embeddings using a pre-trained model.
Generates word or text embeddings using a pre-trained model.
Does sub-word and sequence (sentence) pooling.
:param model: Model name or instance.
Expand Down Expand Up @@ -89,9 +89,13 @@ def _embed_batch(self, sentences, unpack_to_cpu: bool = True) -> list[torch.Tens
tokenized = self._tokenize(sentences)
word_ids = tokenized.pop("word_ids")

# Forward pass to get all hidden states
# Do forward pass and get all hidden states
with torch.no_grad():
hidden_states = self.model(**tokenized, output_hidden_states=True).hidden_states
outputs = self.model(**tokenized, output_hidden_states=True)
if hasattr(outputs, "hidden_states"):
hidden_states = outputs.hidden_states
else:
raise ValueError(f"Failed to get hidden states for model: {self.name}")

# Exclude the embedding layer (index 0)
embeddings = torch.stack(hidden_states[1:], dim=0)
Expand Down Expand Up @@ -126,7 +130,11 @@ def _setup_model(self, model: Union[str, torch.nn.Module], local_files_only: boo
if isinstance(model, torch.nn.Module)
else AutoModel.from_pretrained(model, local_files_only=local_files_only)
)
self.name = self.model.config.name_or_path

if hasattr(self.model.config, 'is_encoder_decoder') and self.model.config.is_encoder_decoder:
self.model = self.model.encoder # remove decoder

self.name = getattr(self.model.config, "name_or_path", "Unknown Model")

def _setup_tokenizer(self, tokenizer: Optional[Union[str, PreTrainedTokenizerFast]]) -> None:
"""Initialize tokenizer using AutoTokenizer, support PreTokenizerFast."""
Expand Down