diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index c497dadbc..17b6bd38b 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -526,6 +526,7 @@ def encode( print(embeddings.shape) # (3, 768) """ + kwargs["normalize_embeddings"] = normalize_embeddings if self.device.type == "hpu" and not self.is_hpu_graph_enabled: import habana_frameworks.torch as ht @@ -685,6 +686,8 @@ def forward(self, input: dict[str, Tensor], **kwargs) -> dict[str, Tensor]: return super().forward(input) for module_name, module in self.named_children(): + if not kwargs.get("normalize_embeddings", False) and isinstance(module, Normalize): + return input module_kwarg_keys = self.module_kwargs.get(module_name, []) module_kwargs = {key: value for key, value in kwargs.items() if key in module_kwarg_keys} input = module(input, **module_kwargs)