Skip to content

Commit 00c3ced

Browse files
committed
Detect model type via config
1 parent 717e941 commit 00c3ced

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

transformer_ranker/embedder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def _embed_batch(self, sentences, unpack_to_cpu: bool = True) -> list[torch.Tens
9595
if hasattr(outputs, "hidden_states"):
9696
hidden_states = outputs.hidden_states
9797
else:
98-
raise ValueError(f"Failed to get hidden states from model: {self.name}")
98+
raise ValueError(f"Failed to get hidden states for model: {self.name}")
9999

100100
# Exclude the embedding layer (index 0)
101101
embeddings = torch.stack(hidden_states[1:], dim=0)
@@ -131,7 +131,9 @@ def _setup_model(self, model: Union[str, torch.nn.Module], local_files_only: boo
131131
else AutoModel.from_pretrained(model, local_files_only=local_files_only)
132132
)
133133

134-
self.model = getattr(self.model, "encoder", self.model) # set to encoder for encoder-decoder models
134+
if hasattr(self.model.config, 'is_encoder_decoder') and self.model.config.is_encoder_decoder:
135+
self.model = self.model.encoder # remove decoder
136+
135137
self.name = getattr(self.model.config, "name_or_path", "Unknown Model")
136138

137139
def _setup_tokenizer(self, tokenizer: Optional[Union[str, PreTrainedTokenizerFast]]) -> None:

0 commit comments

Comments
 (0)