Skip to content

Commit 07d43c3

Browse files
authored
Merge pull request #14 from flairNLP/encoder-decoder-support
Support encoder-decoder models (BART, T5)
2 parents 5f9fb20 + 00c3ced commit 07d43c3

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

transformer_ranker/embedder.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ def __init__(
2121
device: Optional[str] = None,
2222
):
2323
"""
24-
Generates word or text embeddings using a pre-trained model.
24+
Generates word or text embeddings using a pre-trained model.
2525
Does sub-word and sequence (sentence) pooling.
2626
2727
:param model: Model name or instance.
@@ -89,9 +89,13 @@ def _embed_batch(self, sentences, unpack_to_cpu: bool = True) -> list[torch.Tens
8989
tokenized = self._tokenize(sentences)
9090
word_ids = tokenized.pop("word_ids")
9191

92-
# Forward pass to get all hidden states
92+
# Do forward pass and get all hidden states
9393
with torch.no_grad():
94-
hidden_states = self.model(**tokenized, output_hidden_states=True).hidden_states
94+
outputs = self.model(**tokenized, output_hidden_states=True)
95+
if hasattr(outputs, "hidden_states"):
96+
hidden_states = outputs.hidden_states
97+
else:
98+
raise ValueError(f"Failed to get hidden states for model: {self.name}")
9599

96100
# Exclude the embedding layer (index 0)
97101
embeddings = torch.stack(hidden_states[1:], dim=0)
@@ -126,7 +130,11 @@ def _setup_model(self, model: Union[str, torch.nn.Module], local_files_only: boo
126130
if isinstance(model, torch.nn.Module)
127131
else AutoModel.from_pretrained(model, local_files_only=local_files_only)
128132
)
129-
self.name = self.model.config.name_or_path
133+
134+
if hasattr(self.model.config, 'is_encoder_decoder') and self.model.config.is_encoder_decoder:
135+
self.model = self.model.encoder # remove decoder
136+
137+
self.name = getattr(self.model.config, "name_or_path", "Unknown Model")
130138

131139
def _setup_tokenizer(self, tokenizer: Optional[Union[str, PreTrainedTokenizerFast]]) -> None:
132140
"""Initialize tokenizer using AutoTokenizer, support PreTokenizerFast."""

0 commit comments

Comments (0)