diff --git a/docs/source/inference_tutorials/sentence_transformers.mdx b/docs/source/inference_tutorials/sentence_transformers.mdx
index 940430a87..fb26b104f 100644
--- a/docs/source/inference_tutorials/sentence_transformers.mdx
+++ b/docs/source/inference_tutorials/sentence_transformers.mdx
@@ -24,29 +24,29 @@ This guide explains how to compile, load, and use [Sentence Transformers (SBERT)

### Convert Sentence Transformers model to AWS Inferentia2

-First, you need to convert your Sentence Transformers model to a format compatible with AWS Inferentia2. You can compile Sentence Transformers models with Optimum Neuron using the `optimum-cli` or `NeuronModelForSentenceTransformers` class. Below you will find an example for both approaches. We have to make sure `sentence-transformers` is installed. That's only needed for exporting the model.
+First, you need to convert your Sentence Transformers model to a format compatible with AWS Inferentia2. You can compile Sentence Transformers models with Optimum Neuron using the `optimum-cli` or the `NeuronSentenceTransformers` class. Below you will find an example for both approaches. Make sure `sentence-transformers` is installed; it is only needed for exporting the model.

```bash
pip install sentence-transformers
```

-Here we will use the `NeuronModelForSentenceTransformers`, which can be used to convert any Sentence Transformers model to a format compatible with AWS Inferentia2 or load already converted models. When exporting models with the `NeuronModelForSentenceTransformers` you need to set `export=True` and define the input shape and batch size. The input shape is defined by the `sequence_length` and the batch size by `batch_size`.
+Here we will use the `NeuronSentenceTransformers` class, which can convert any Sentence Transformers model to a format compatible with AWS Inferentia2, or load already converted models. When exporting models with `NeuronSentenceTransformers` you need to set `export=True` and define the input shape and batch size. The input shape is defined by `sequence_length` and the batch size by `batch_size`.

```python
-from optimum.neuron import NeuronModelForSentenceTransformers
+from optimum.neuron import NeuronSentenceTransformers

# Sentence Transformers model from HuggingFace
model_id = "BAAI/bge-small-en-v1.5"
input_shapes = {"batch_size": 1, "sequence_length": 384} # mandatory shapes

# Load Transformers model and export it to AWS Inferentia2
-model = NeuronModelForSentenceTransformers.from_pretrained(model_id, export=True, **input_shapes)
+model = NeuronSentenceTransformers.from_pretrained(model_id, export=True, **input_shapes)

# Save model to disk
model.save_pretrained("bge_emb_inf2/")
```

-Here we will use the `optimum-cli` to convert the model. Similar to the `NeuronModelForSentenceTransformers` we need to define our input shape and batch size. The input shape is defined by the `sequence_length` and the batch size by `batch_size`. The `optimum-cli` will automatically convert the model to a format compatible with AWS Inferentia2 and save it to the specified output directory.
+Here we will use the `optimum-cli` to convert the model. As with `NeuronSentenceTransformers`, we need to define the input shape and batch size. The input shape is defined by `sequence_length` and the batch size by `batch_size`. The `optimum-cli` will automatically convert the model to a format compatible with AWS Inferentia2 and save it to the specified output directory.
```bash
optimum-cli export neuron -m BAAI/bge-small-en-v1.5 --sequence_length 384 --batch_size 1 --task feature-extraction bge_emb_inf2/
@@ -54,29 +54,19 @@ optimum-cli export neuron -m BAAI/bge-small-en-v1.5 --sequence_length 384 --batc

### Load compiled Sentence Transformers model and run inference

-Once we have a compiled Sentence Transformers model, which we either exported ourselves or is available on the Hugging Face Hub, we can load it and run inference. For loading the model we can use the `NeuronModelForSentenceTransformers` class, which is an abstraction layer for the `SentenceTransformer` class. The `NeuronModelForSentenceTransformers` class will automatically pad the input to the specified `sequence_length` and run inference on AWS Inferentia2.
+Once we have a compiled Sentence Transformers model, which we either exported ourselves or downloaded from the Hugging Face Hub, we can load it and run inference. To load the model we can use the `NeuronSentenceTransformers` class, which is an abstraction layer over the `SentenceTransformer` class. The `NeuronSentenceTransformers` class will automatically pad the input to the specified `sequence_length` and run inference on AWS Inferentia2.

```python
-from optimum.neuron import NeuronModelForSentenceTransformers
-from transformers import AutoTokenizer
+from optimum.neuron import NeuronSentenceTransformers

model_id_or_path = "bge_emb_inf2/"
-tokenizer_id = "BAAI/bge-small-en-v1.5"

-# Load model and tokenizer
-model = NeuronModelForSentenceTransformers.from_pretrained(model_id_or_path)
-tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
+# Load model
+model = NeuronSentenceTransformers.from_pretrained(model_id_or_path)

# Run inference
-prompt = "I like to eat apples"
-encoded_input = tokenizer(prompt, return_tensors='pt')
-outputs = model(**encoded_input)
-
-token_embeddings = outputs.token_embeddings
-sentence_embedding = outputs.sentence_embedding
-
-print(f"token embeddings: {token_embeddings.shape}") # torch.Size([1, 7, 384])
-print(f"sentence_embedding: {sentence_embedding.shape}") # torch.Size([1, 384])
+sentences = ["I like to eat apples"]
+token_embeddings = model.encode(sentences, output_value="token_embeddings")
+sentence_embedding = model.encode(sentences, output_value="sentence_embedding")
```
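+
+You can also encode batches of sentences and score them with `similarity`, mirroring the standard Sentence Transformers API. A minimal sketch (`similarity` defaults to cosine unless `similarity_fn_name` is set):
+
+```python
+sentences_1 = ["Life is pain au chocolat", "Life is galette des rois"]
+sentences_2 = ["Life is eclaire au cafe", "Life is mille feuille"]
+
+embeddings_1 = model.encode(sentences_1, normalize_embeddings=True)
+embeddings_2 = model.encode(sentences_2, normalize_embeddings=True)
+similarity = model.similarity(embeddings_1, embeddings_2)  # 2x2 tensor of cosine scores
+```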

### Production Usage
@@ -89,7 +79,7 @@ For deploying these models in a production environment, refer to the [Amazon Sag

### Compile CLIP for AWS Inferentia2

-You can compile CLIP models with Optimum Neuron either by using the `optimum-cli` or `NeuronModelForSentenceTransformers` class. Adopt one approach that you prefer:
+You can compile CLIP models with Optimum Neuron either by using the `optimum-cli` or the `NeuronSentenceTransformers` class. Choose whichever approach you prefer:

* With the Optimum CLI

@@ -97,10 +87,10 @@ You can compile CLIP models with Optimum Neuron either by using the `optimum-cli
optimum-cli export neuron -m sentence-transformers/clip-ViT-B-32 --sequence_length 64 --text_batch_size 3 --image_batch_size 1 --num_channels 3 --height 224 --width 224 --task feature-extraction --subfolder 0_CLIPModel clip_emb/
```

-* With the `NeuronModelForSentenceTransformers` class
+* With the `NeuronSentenceTransformers` class

```python
-from optimum.neuron import NeuronModelForSentenceTransformers
+from optimum.neuron import NeuronSentenceTransformers

model_id = "sentence-transformers/clip-ViT-B-32"

@@ -114,7 +104,7 @@ input_shapes = {
    "sequence_length": 64,
}

-emb_model = NeuronModelForSentenceTransformers.from_pretrained(
+emb_model = NeuronSentenceTransformers.from_pretrained(
    model_id, subfolder="0_CLIPModel", export=True, library_name="sentence_transformers", dynamic_batch_size=False, **input_shapes
)

@@ -130,10 +120,10 @@ from PIL import Image
from sentence_transformers import util
from transformers import CLIPProcessor

-from optimum.neuron import NeuronModelForSentenceTransformers
+from optimum.neuron import NeuronSentenceTransformers

save_directory = "clip_emb"
-emb_model = NeuronModelForSentenceTransformers.from_pretrained(save_directory)
+emb_model = NeuronSentenceTransformers.from_pretrained(save_directory)
processor = CLIPProcessor.from_pretrained(save_directory)

inputs = processor(
@@ -154,7 +144,7 @@ print(cos_scores)

**Caveat**

-Since compiled models with dynamic batching enabled only accept input tensors with the same batch size, we cannot set `dynamic_batch_size=True` if the input texts and images have different batch sizes. And as `NeuronModelForSentenceTransformers` class pads the inputs to the batch sizes (`text_batch_size` and `image_batch_size`) used during the compilation, you could use relatively larger batch sizes during the compilation for flexibility with the trade-off of compute.
+Since compiled models with dynamic batching enabled only accept input tensors with the same batch size, we cannot set `dynamic_batch_size=True` if the input texts and images have different batch sizes. And since the `NeuronSentenceTransformers` class pads the inputs to the batch sizes (`text_batch_size` and `image_batch_size`) used during the compilation, you could use relatively large batch sizes during compilation for flexibility, at the cost of extra compute.

-eg. if you want to encode 3 or 4 or 5 texts and 1 image, you could set `text_batch_size = 5 = max(3, 4, 5)` and `image_batch_size = 1` during the compilation.
+E.g., if you want to encode 3, 4, or 5 texts and 1 image, you could set `text_batch_size = max(3, 4, 5) = 5` and `image_batch_size = 1` during compilation.
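+
+As a minimal sketch of that padding behavior (reusing the `clip_emb` model compiled above with `text_batch_size=3` and `image_batch_size=1`, and assuming a local `image.jpg`):
+
+```python
+from PIL import Image
+from transformers import CLIPProcessor
+
+from optimum.neuron import NeuronSentenceTransformers
+
+emb_model = NeuronSentenceTransformers.from_pretrained("clip_emb")
+processor = CLIPProcessor.from_pretrained("clip_emb")
+
+# 3 texts and 1 image: the texts fill the compiled `text_batch_size` exactly,
+# while fewer texts would simply be padded up to it at the cost of wasted compute.
+inputs = processor(
+    text=["a cat", "a dog", "a bird"], images=Image.open("image.jpg"), return_tensors="pt", padding=True
+)
+outputs = emb_model(**inputs)
+print(outputs.text_embeds.shape)   # torch.Size([3, 512]) for clip-ViT-B-32
+print(outputs.image_embeds.shape)  # torch.Size([1, 512])
+```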
diff --git a/docs/source/model_doc/modeling_auto.mdx b/docs/source/model_doc/modeling_auto.mdx index 39be879f9..609cb0b17 100644 --- a/docs/source/model_doc/modeling_auto.mdx +++ b/docs/source/model_doc/modeling_auto.mdx @@ -33,9 +33,9 @@ The following Neuron model classes are available for natural language processing [[autodoc]] modeling.NeuronModelForFeatureExtraction -### NeuronModelForSentenceTransformers +### NeuronSentenceTransformers -[[autodoc]] modeling.NeuronModelForSentenceTransformers +[[autodoc]] modeling_sentence_transformers.NeuronSentenceTransformers ### NeuronModelForMaskedLM diff --git a/docs/source/model_doc/sentence_transformers/overview.mdx b/docs/source/model_doc/sentence_transformers/overview.mdx index 5ff302ec4..e451e02e3 100644 --- a/docs/source/model_doc/sentence_transformers/overview.mdx +++ b/docs/source/model_doc/sentence_transformers/overview.mdx @@ -39,16 +39,16 @@ optimum-cli export neuron -m sentence-transformers/clip-ViT-B-32 --sequence_leng * Example - Text embeddings ```python -from optimum.neuron import NeuronModelForSentenceTransformers +from optimum.neuron import NeuronSentenceTransformers # configs for compiling model input_shapes = { "batch_size": 1, - "sequence_length": 384, + "sequence_length": 512, } compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"} -neuron_model = NeuronModelForSentenceTransformers.from_pretrained( +neuron_model = NeuronSentenceTransformers.from_pretrained( "BAAI/bge-large-en-v1.5", export=True, **input_shapes, @@ -63,12 +63,17 @@ neuron_model.push_to_hub( "bge_emb_neuron/", repository_id="optimum/bge-base-en-v1.5-neuronx" # Replace with your HF Hub repo id ) +sentences_1 = ["Life is pain au chocolat", "Life is galette des rois"] +sentences_2 = ["Life is eclaire au cafe", "Life is mille feuille"] +embeddings_1 = neuron_model.encode(sentences_1, normalize_embeddings=True) +embeddings_2 = neuron_model.encode(sentences_2, normalize_embeddings=True) +similarity = neuron_model.similarity(embeddings_1, embeddings_2) ``` * Example - Image Search ```python -from optimum.neuron import NeuronModelForSentenceTransformers +from optimum.neuron import NeuronSentenceTransformers # configs for compiling model input_shapes = { @@ -81,7 +86,7 @@ input_shapes = { } compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"} -neuron_model = NeuronModelForSentenceTransformers.from_pretrained( +neuron_model = NeuronSentenceTransformers.from_pretrained( "sentence-transformers/clip-ViT-B-32", subfolder="0_CLIPModel", export=True, @@ -98,7 +103,3 @@ neuron_model.push_to_hub( "clip_emb_neuron/", repository_id="optimum/clip_vit_emb_neuronx" # Replace with your HF Hub repo id ) ``` - -## NeuronModelForSentenceTransformers - -[[autodoc]] modeling.NeuronModelForSentenceTransformers diff --git a/optimum/commands/neuron/cache.py b/optimum/commands/neuron/cache.py index cb437c347..e3e90c0b2 100644 --- a/optimum/commands/neuron/cache.py +++ b/optimum/commands/neuron/cache.py @@ -147,7 +147,7 @@ def _list_entries(self): str(entry["batch_size"]), str(entry["sequence_length"]), str(entry.get("tp_degree", entry.get("tensor_parallel_size"))), - str(entry["torch_dtype"]), + str(entry.get("torch_dtype", entry.get("dtype"))), str(entry["target"]), ) ) diff --git a/optimum/neuron/__init__.py b/optimum/neuron/__init__.py index 5511f8802..260ad04b6 100644 --- a/optimum/neuron/__init__.py +++ b/optimum/neuron/__init__.py @@ -43,7 +43,6 @@ "modeling_traced": ["NeuronTracedModel"], "modeling": [ "NeuronModelForFeatureExtraction", - 
"NeuronModelForSentenceTransformers", "NeuronModelForMaskedLM", "NeuronModelForQuestionAnswering", "NeuronModelForSequenceClassification", @@ -78,6 +77,7 @@ "modeling_seq2seq": [ "NeuronModelForSeq2SeqLM", ], + "modeling_sentence_transformers": ["NeuronSentenceTransformers"], "models": [], "accelerate": [ "NeuronAccelerator", @@ -115,7 +115,6 @@ NeuronModelForObjectDetection, NeuronModelForQuestionAnswering, NeuronModelForSemanticSegmentation, - NeuronModelForSentenceTransformers, NeuronModelForSequenceClassification, NeuronModelForTokenClassification, NeuronModelForXVector, @@ -138,6 +137,7 @@ NeuronStableDiffusionXLInpaintPipeline, NeuronStableDiffusionXLPipeline, ) + from .modeling_sentence_transformers import NeuronSentenceTransformers from .modeling_seq2seq import NeuronModelForSeq2SeqLM from .modeling_traced import NeuronTracedModel diff --git a/optimum/neuron/cache/hub_cache.py b/optimum/neuron/cache/hub_cache.py index 73605e4a3..cab833426 100644 --- a/optimum/neuron/cache/hub_cache.py +++ b/optimum/neuron/cache/hub_cache.py @@ -427,7 +427,7 @@ def select_hub_cached_entries( continue if torch_dtype is not None: target_value = DTYPE_MAPPER.pt(torch_dtype) if isinstance(torch_dtype, str) else torch_dtype - entry_value = DTYPE_MAPPER.pt(entry.get("torch_dtype")) + entry_value = DTYPE_MAPPER.pt(entry.get("torch_dtype", entry.get("dtype"))) if target_value != entry_value: continue selected.append(entry) diff --git a/optimum/neuron/modeling.py b/optimum/neuron/modeling.py index 1ce4756ed..9ce39008f 100644 --- a/optimum/neuron/modeling.py +++ b/optimum/neuron/modeling.py @@ -15,7 +15,6 @@ """NeuronModelForXXX classes for inference on neuron devices using the same API as Transformers.""" import logging -from typing import TYPE_CHECKING import torch from transformers import ( @@ -66,8 +65,6 @@ NEURON_OBJECT_DETECTION_EXAMPLE, NEURON_QUESTION_ANSWERING_EXAMPLE, NEURON_SEMANTIC_SEGMENTATION_EXAMPLE, - NEURON_SENTENCE_TRANSFORMERS_IMAGE_EXAMPLE, - NEURON_SENTENCE_TRANSFORMERS_TEXT_EXAMPLE, NEURON_SEQUENCE_CLASSIFICATION_EXAMPLE, NEURON_TEXT_INPUTS_DOCSTRING, NEURON_TOKEN_CLASSIFICATION_EXAMPLE, @@ -76,10 +73,6 @@ ) -if TYPE_CHECKING: - pass - - logger = logging.getLogger(__name__) @@ -135,72 +128,6 @@ def forward( return BaseModelOutputWithPooling(last_hidden_state=last_hidden_state, pooler_output=pooler_output) -@add_start_docstrings( - """ - Neuron Model for Sentence Transformers. - """, - NEURON_MODEL_START_DOCSTRING, -) -class NeuronModelForSentenceTransformers(NeuronTracedModel): - """ - Sentence Transformers model on Neuron devices. 
- """ - - auto_model_class = AutoModel - library_name = "sentence_transformers" - - @add_start_docstrings_to_model_forward( - NEURON_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length") - + NEURON_SENTENCE_TRANSFORMERS_TEXT_EXAMPLE.format( - processor_class=_TOKENIZER_FOR_DOC, - model_class="NeuronModelForSentenceTransformers", - checkpoint="optimum/bge-base-en-v1.5-neuronx", - ) - + NEURON_SENTENCE_TRANSFORMERS_IMAGE_EXAMPLE.format( - processor_class=_GENERIC_PROCESSOR, - model_class="NeuronModelForSentenceTransformers", - checkpoint="optimum/clip_vit_emb_neuronx", - ) - ) - def forward( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - pixel_values: torch.Tensor | None = None, - token_type_ids: torch.Tensor | None = None, - **kwargs, - ): - model_type = self.config.neuron["model_type"] - neuron_inputs = {"input_ids": input_ids} - if pixel_values is not None: - neuron_inputs["pixel_values"] = pixel_values - neuron_inputs["attention_mask"] = ( - attention_mask # The input order for clip is: input_ids, pixel_values, attention_mask. - ) - - with self.neuron_padding_manager(neuron_inputs) as inputs: - outputs = self.model(*inputs) - if "clip" in model_type: - text_embeds = self.remove_padding([outputs[0]], dims=[0], indices=[input_ids.shape[0]])[ - 0 - ] # Remove padding on batch_size(0) - image_embeds = self.remove_padding([outputs[1]], dims=[0], indices=[pixel_values.shape[0]])[ - 0 - ] # Remove padding on batch_size(0) - return ModelOutput(text_embeds=text_embeds, image_embeds=image_embeds) - else: - # token_embeddings -> (batch_size, sequencen_len, hidden_size) - token_embeddings = self.remove_padding( - [outputs[0]], dims=[0, 1], indices=[input_ids.shape[0], input_ids.shape[1]] - )[0] # Remove padding on batch_size(0), and sequence_length(1) - # sentence_embedding -> (batch_size, hidden_size) - sentence_embedding = self.remove_padding([outputs[1]], dims=[0], indices=[input_ids.shape[0]])[ - 0 - ] # Remove padding on batch_size(0) - - return ModelOutput(token_embeddings=token_embeddings, sentence_embedding=sentence_embedding) - - @add_start_docstrings( """ Neuron Model with a MaskedLMOutput for masked language modeling tasks. diff --git a/optimum/neuron/modeling_sentence_transformers.py b/optimum/neuron/modeling_sentence_transformers.py new file mode 100644 index 000000000..8573618db --- /dev/null +++ b/optimum/neuron/modeling_sentence_transformers.py @@ -0,0 +1,280 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""NeuronSentenceTransformers class for inference on neuron devices of sentence transformers.""" + +import logging +from typing import Literal + +import torch +from sentence_transformers.similarity_functions import SimilarityFunction +from tqdm.autonotebook import trange +from transformers import AutoModel +from transformers.modeling_outputs import ModelOutput + +from .modeling_traced import NeuronTracedModel + + +logger = logging.getLogger(__name__) + + +class NeuronSentenceTransformers(NeuronTracedModel): + """ + Sentence Transformers model on Neuron devices. + """ + + auto_model_class = AutoModel + library_name = "sentence_transformers" + + def __init__( + self, + **kwargs, + ): + self.prompts = {"query": "", "document": ""} + prompts = kwargs.pop("prompts", None) + if prompts: + self.prompts.update(prompts) + self.default_prompt_name = kwargs.pop("default_prompt_nam", None) + self._prompt_length_mapping = {} + self.truncate_dim = kwargs.pop("truncate_dim", None) + self.similarity_fn_name = kwargs.pop("similarity_fn_name", None) + + super().__init__(**kwargs) + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + pixel_values: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, + token_embeddings: torch.Tensor | None = None, + sentence_embedding: torch.Tensor | None = None, + ): + model_type = self.config.neuron["model_type"] + neuron_inputs = {"input_ids": input_ids} + if pixel_values is not None: + neuron_inputs["pixel_values"] = pixel_values + neuron_inputs["attention_mask"] = ( + attention_mask # The input order for clip is: input_ids, pixel_values, attention_mask. + ) + + with self.neuron_padding_manager(neuron_inputs) as inputs: + outputs = self.model(*inputs) + if "clip" in model_type: + text_embeds = self.remove_padding([outputs[0]], dims=[0], indices=[input_ids.shape[0]])[ + 0 + ] # Remove padding on batch_size(0) + image_embeds = self.remove_padding([outputs[1]], dims=[0], indices=[pixel_values.shape[0]])[ + 0 + ] # Remove padding on batch_size(0) + return ModelOutput(text_embeds=text_embeds, image_embeds=image_embeds) + else: + # token_embeddings -> (batch_size, sequencen_len, hidden_size) + token_embeddings = self.remove_padding( + [outputs[0]], dims=[0, 1], indices=[input_ids.shape[0], input_ids.shape[1]] + )[0] # Remove padding on batch_size(0), and sequence_length(1) + # sentence_embedding -> (batch_size, hidden_size) + sentence_embedding = self.remove_padding([outputs[1]], dims=[0], indices=[input_ids.shape[0]])[ + 0 + ] # Remove padding on batch_size(0) + + return ModelOutput( + token_embeddings=token_embeddings, + sentence_embedding=sentence_embedding, + attention_mask=attention_mask, + ) + + def tokenize(self, texts: list[str] | list[dict] | list[tuple[str, str]], **kwargs) -> dict[str, torch.Tensor]: + """ + Tokenizes the texts. + + Args: + texts (list[str] | list[dict] | list[tuple[str, str]]]): A list of texts to be tokenized. + + Returns: + dict[str, torch.Tensor]: A dictionary of tensors with the tokenized texts. Common keys are "input_ids", + "attention_mask", and "token_type_ids". 
+        """
+        return self.preprocessors[0](texts, **kwargs)
+
+    def _get_prompt_length(self, prompt: str, **kwargs) -> int | None:
+        """
+        Return the length of the prompt in tokens, including the BOS token.
+        """
+        if (prompt, *kwargs.values()) in self._prompt_length_mapping:
+            return self._prompt_length_mapping[(prompt, *kwargs.values())]
+
+        tokenized_prompt = self.tokenize([prompt], return_tensors="pt", **kwargs)
+        if "input_ids" not in tokenized_prompt:
+            # If the tokenizer does not return input_ids, we cannot determine the prompt length.
+            # This can happen with some tokenizers that do not use input_ids.
+            return None
+        prompt_length = tokenized_prompt["input_ids"].shape[-1]
+        # If the tokenizer appends a special EOS token, do not count it as part of the prompt length.
+        last_token = tokenized_prompt["input_ids"][..., -1].item()
+        if hasattr(self.tokenizer, "all_special_ids") and last_token in self.tokenizer.all_special_ids:
+            prompt_length -= 1
+        self._prompt_length_mapping[(prompt, *kwargs.values())] = prompt_length
+        return prompt_length
+
+    def _text_length(self, text: list[int] | list[list[int]] | dict) -> int:
+        """
+        Helper function to get the length of the input text. Text can be either a list of ints (which means a
+        single tokenized text as input), or a list of lists of ints (representing several text inputs to the model).
+        """
+
+        if isinstance(text, dict):  # {key: value} case
+            return len(next(iter(text.values())))
+        elif not hasattr(text, "__len__"):  # Object has no len() method
+            return 1
+        elif len(text) == 0 or isinstance(text[0], int):  # Empty string or list of ints
+            return len(text)
+        else:
+            return sum([len(t) for t in text])  # Sum of length of individual strings
+
+    @property
+    def similarity_fn_name(self) -> Literal["cosine", "dot", "euclidean", "manhattan"]:
+        """Return the name of the similarity function used by :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise`.
+
+        Returns:
+            Optional[str]: The name of the similarity function. Can be None if not set, in which case it will
+            default to "cosine" when first called.
+ + Example: + >>> model = SentenceTransformer("multi-qa-mpnet-base-dot-v1") + >>> model.similarity_fn_name + 'dot' + """ + if self._similarity_fn_name is None: + self.similarity_fn_name = SimilarityFunction.COSINE + return self._similarity_fn_name + + @similarity_fn_name.setter + def similarity_fn_name( + self, value: Literal["cosine", "dot", "euclidean", "manhattan"] | SimilarityFunction | None + ) -> None: + if isinstance(value, SimilarityFunction): + value = value.value + self._similarity_fn_name = value + + if value is not None: + self._similarity = SimilarityFunction.to_similarity_fn(value) + self._similarity_pairwise = SimilarityFunction.to_similarity_pairwise_fn(value) + + def encode( + self, + sentences: str | list[str], + prompt_name: str | None = None, + prompt: str | None = None, + show_progress_bar: bool | None = None, + output_value: Literal["sentence_embedding", "token_embeddings"] | None = "sentence_embedding", + normalize_embeddings: bool = False, + **kwargs, + ): + if show_progress_bar is None: + show_progress_bar = logger.getEffectiveLevel() in (logging.INFO, logging.DEBUG) + + # Cast an individual input to a list with length 1 + input_was_string = False + if isinstance(sentences, str) or not hasattr(sentences, "__len__"): + sentences = [sentences] + input_was_string = True + + if prompt is None: + if prompt_name is not None: + try: + prompt = self.prompts[prompt_name] + except KeyError: + raise ValueError( + f"Prompt name '{prompt_name}' not found in the configured prompts dictionary with keys {list(self.prompts.keys())!r}." + ) + elif self.default_prompt_name is not None: + prompt = self.prompts.get(self.default_prompt_name, None) + else: + if prompt_name is not None: + logger.warning( + "Encode with either a `prompt`, a `prompt_name`, or neither, but not both. " + "Ignoring the `prompt_name` in favor of `prompt`." + ) + + extra_features = {} + if prompt is not None and len(prompt) > 0: + sentences = [prompt + sentence for sentence in sentences] + + # Some models (e.g. 
INSTRUCTOR, GRIT) require removing the prompt before pooling
+            # Tracking the prompt length allows us to remove the prompt during pooling
+            length = self._get_prompt_length(prompt, **kwargs)
+            if length is not None:
+                extra_features["prompt_length"] = length
+
+        all_embeddings = []
+        lengths = torch.tensor([self._text_length(sen) for sen in sentences])
+        length_sorted_idx = torch.argsort(-lengths)
+        sentences_sorted = [sentences[int(idx)] for idx in length_sorted_idx]
+
+        batch_size = self.neuron_config.batch_size
+        for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
+            sentences_batch = sentences_sorted[start_index : start_index + batch_size]
+            features = self.tokenize(sentences_batch, return_tensors="pt", **kwargs)
+
+            features.update(extra_features)
+            out_features = self.forward(**features)
+
+            if output_value == "token_embeddings":
+                embeddings = []
+                for token_emb, attention in zip(out_features[output_value], out_features["attention_mask"]):
+                    last_mask_id = len(attention) - 1
+                    while last_mask_id > 0 and attention[last_mask_id].item() == 0:
+                        last_mask_id -= 1
+
+                    embeddings.append(token_emb[0 : last_mask_id + 1])
+            elif output_value is None:  # Return all outputs
+                embeddings = []
+                for idx in range(len(out_features["sentence_embedding"])):
+                    batch_item = {}
+                    for name, value in out_features.items():
+                        try:
+                            batch_item[name] = value[idx]
+                        except TypeError:
+                            # Handle non-indexable values (like prompt_length)
+                            batch_item[name] = value
+                    embeddings.append(batch_item)
+            else:  # Sentence embeddings
+                embeddings = out_features[output_value]
+                embeddings = embeddings.detach()
+                if normalize_embeddings:
+                    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+
+            all_embeddings.extend(embeddings)
+
+        # Restore the original input order: argsort of the length-sorted permutation gives its inverse.
+        all_embeddings = [all_embeddings[i] for i in torch.argsort(length_sorted_idx).tolist()]
+
+        if len(all_embeddings):
+            all_embeddings = torch.stack(all_embeddings)
+        else:
+            all_embeddings = torch.tensor([], device=self.device)
+
+        if input_was_string:
+            all_embeddings = all_embeddings[0]
+
+        return all_embeddings
+
+    @property
+    def similarity(self):
+        if self.similarity_fn_name is None:
+            self.similarity_fn_name = SimilarityFunction.COSINE
+
+        return self._similarity
diff --git a/optimum/neuron/pipelines/transformers/base.py b/optimum/neuron/pipelines/transformers/base.py
index 14f3f9468..e17c6bfa1 100644
--- a/optimum/neuron/pipelines/transformers/base.py
+++ b/optimum/neuron/pipelines/transformers/base.py
@@ -56,10 +56,10 @@
    NeuronModelForMaskedLM,
    NeuronModelForQuestionAnswering,
    NeuronModelForSemanticSegmentation,
-    NeuronModelForSentenceTransformers,
    NeuronModelForSequenceClassification,
    NeuronModelForTokenClassification,
)
+from ...modeling_sentence_transformers import NeuronSentenceTransformers
from ...models.inference.modeling_utils import NeuronModelForCausalLM


@@ -171,7 +171,7 @@ def load_pipeline(
        model, token=token, revision=revision
    ):
        logger.info("Using Sentence Transformers compatible Feature extraction pipeline")
-        neuronx_model_class = NeuronModelForSentenceTransformers
+        neuronx_model_class = NeuronSentenceTransformers

    if issubclass(neuronx_model_class, NeuronModelForCausalLM):
        if export:
diff --git a/tests/inference/transformers/test_modeling.py b/tests/inference/transformers/test_modeling.py
index 9f5f91379..8125f6247 100644
--- a/tests/inference/transformers/test_modeling.py
+++ b/tests/inference/transformers/test_modeling.py
@@ -60,10 +60,10 @@
    NeuronModelForObjectDetection,
    NeuronModelForQuestionAnswering,
NeuronModelForSemanticSegmentation, - NeuronModelForSentenceTransformers, NeuronModelForSequenceClassification, NeuronModelForTokenClassification, NeuronModelForXVector, + NeuronSentenceTransformers, NeuronTracedModel, pipeline, ) @@ -302,8 +302,8 @@ def test_pipeline_model(self): @is_inferentia_test -class NeuronModelForSentenceTransformersIntegrationTest(NeuronModelTestMixin): - NEURON_MODEL_CLASS = NeuronModelForSentenceTransformers +class NeuronSentenceTransformersIntegrationTest(NeuronModelTestMixin): + NEURON_MODEL_CLASS = NeuronSentenceTransformers TASK = "feature-extraction" ATOL_FOR_VALIDATION = 1e-2 SUPPORTED_ARCHITECTURES = ["transformer", "clip"] @@ -327,10 +327,9 @@ def test_sentence_transformers_dyn_bs(self, model_arch): set_seed(SEED) sentence_transformers_model = SentenceTransformer(model_id) - tokenizer = AutoTokenizer.from_pretrained(model_id) text = ["This is a sample output"] * 2 - tokens = tokenizer(text, return_tensors="pt") + tokens = neuron_model_dyn.tokenize(text, return_tensors="pt") with torch.no_grad(): sentence_transformers_outputs = sentence_transformers_model(tokens) @@ -359,6 +358,14 @@ def test_sentence_transformers_dyn_bs(self, model_arch): ) ) + # Encode + Similarity + sentences_1 = ["Life is pain au chocolat", "Life is galette des rois"] + sentences_2 = ["Life is eclaire au cafe", "Life is mille feuille"] + embeddings_1 = neuron_model_dyn.encode(sentences_1, normalize_embeddings=True) + embeddings_2 = neuron_model_dyn.encode(sentences_2, normalize_embeddings=True) + similarity = neuron_model_dyn.similarity(embeddings_1, embeddings_2) + self.assertIsInstance(similarity, torch.Tensor) + gc.collect() @parameterized.expand(["clip"], skip_on_empty=True)