Skip to content

Commit b41c3be

Browse files
feat: Add optional embedding normalization to StaticModel loading (#164)
1 parent 9e55123 commit b41c3be

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

model2vec/model.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(
3535
:param vectors: The vectors to use.
3636
:param tokenizer: The Transformers tokenizer to use.
3737
:param config: Any metadata config.
38-
:param normalize: Whether to normalize.
38+
:param normalize: Whether to normalize the embeddings.
3939
:param base_model_name: The used base model name. Used for creating a model card.
4040
:param language: The language of the model. Used for creating a model card.
4141
:raises: ValueError if the number of tokens does not match the number of vectors.
@@ -149,6 +149,7 @@ def from_pretrained(
149149
cls: type[StaticModel],
150150
path: PathLike,
151151
token: str | None = None,
152+
normalize: bool | None = None,
152153
) -> StaticModel:
153154
"""
154155
Load a StaticModel from a local path or huggingface hub path.
@@ -157,21 +158,28 @@ def from_pretrained(
157158
158159
:param path: The path to load your static model from.
159160
:param token: The huggingface token to use.
161+
:param normalize: Whether to normalize the embeddings.
160162
:return: A StaticModel
161163
"""
162164
from model2vec.hf_utils import load_pretrained
163165

164166
embeddings, tokenizer, config, metadata = load_pretrained(path, token=token, from_sentence_transformers=False)
165167

166168
return cls(
167-
embeddings, tokenizer, config, base_model_name=metadata.get("base_model"), language=metadata.get("language")
169+
embeddings,
170+
tokenizer,
171+
config,
172+
normalize=normalize,
173+
base_model_name=metadata.get("base_model"),
174+
language=metadata.get("language"),
168175
)
169176

170177
@classmethod
171178
def from_sentence_transformers(
172179
cls: type[StaticModel],
173180
path: PathLike,
174181
token: str | None = None,
182+
normalize: bool | None = None,
175183
) -> StaticModel:
176184
"""
177185
Load a StaticModel trained with sentence transformers from a local path or huggingface hub path.
@@ -180,13 +188,14 @@ def from_sentence_transformers(
180188
181189
:param path: The path to load your static model from.
182190
:param token: The huggingface token to use.
191+
:param normalize: Whether to normalize the embeddings.
183192
:return: A StaticModel
184193
"""
185194
from model2vec.hf_utils import load_pretrained
186195

187196
embeddings, tokenizer, config, _ = load_pretrained(path, token=token, from_sentence_transformers=True)
188197

189-
return cls(embeddings, tokenizer, config, base_model_name=None, language=None)
198+
return cls(embeddings, tokenizer, config, normalize=normalize, base_model_name=None, language=None)
190199

191200
def encode_as_sequence(
192201
self,

0 commit comments

Comments
 (0)