Skip to content

Commit 8976e64

Browse files
authored
Lightweight installation: use safetensors without torch (#2306)
1 parent ba2cce8 commit 8976e64

File tree

2 files changed

+32
-25
lines changed

2 files changed

+32
-25
lines changed

bertopic/_bertopic.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4849,19 +4849,18 @@ def _create_model_from_files(
48494849
images: The images per topic
48504850
warn_no_backend: Whether to warn the user if no backend is given
48514851
"""
4852-
from sentence_transformers import SentenceTransformer
4853-
48544852
params["n_gram_range"] = tuple(params["n_gram_range"])
48554853

48564854
if ctfidf_config is not None:
48574855
ngram_range = ctfidf_config["vectorizer_model"]["params"]["ngram_range"]
48584856
ctfidf_config["vectorizer_model"]["params"]["ngram_range"] = tuple(ngram_range)
48594857

48604858
params["n_gram_range"] = tuple(params["n_gram_range"])
4861-
ctfidf_config
48624859

48634860
# Select HF model through SentenceTransformers
48644861
try:
4862+
from sentence_transformers import SentenceTransformer
4863+
48654864
embedding_model = select_backend(SentenceTransformer(params["embedding_model"]))
48664865
except: # noqa: E722
48674866
embedding_model = BaseEmbedder()
@@ -4887,7 +4886,7 @@ def _create_model_from_files(
48874886
hdbscan_model=empty_cluster_model,
48884887
**params,
48894888
)
4890-
topic_model.topic_embeddings_ = tensors["topic_embeddings"].numpy()
4889+
topic_model.topic_embeddings_ = tensors["topic_embeddings"]
48914890
topic_model.topic_representations_ = {int(key): val for key, val in topics["topic_representations"].items()}
48924891
topic_model.topics_ = topics["topics"]
48934892
topic_model.topic_sizes_ = {int(key): val for key, val in topics["topic_sizes"].items()}
@@ -4924,7 +4923,7 @@ def _create_model_from_files(
49244923
# ClassTfidfTransformer
49254924
topic_model.ctfidf_model.reduce_frequent_words = ctfidf_config["ctfidf_model"]["reduce_frequent_words"]
49264925
topic_model.ctfidf_model.bm25_weighting = ctfidf_config["ctfidf_model"]["bm25_weighting"]
4927-
idf = ctfidf_tensors["diag"].numpy()
4926+
idf = ctfidf_tensors["diag"]
49284927
topic_model.ctfidf_model._idf_diag = sp.diags(
49294928
idf, offsets=0, shape=(len(idf), len(idf)), format="csr", dtype=np.float64
49304929
)

bertopic/_save_utils.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def load_local_files(path):
185185
torch_path = path / HF_WEIGHTS_NAME
186186
if torch_path.is_file():
187187
tensors = torch.load(torch_path, map_location="cpu")
188+
tensors = {k: v.numpy() for k, v in tensors.items()}
188189

189190
# c-TF-IDF
190191
try:
@@ -196,6 +197,7 @@ def load_local_files(path):
196197
torch_path = path / CTFIDF_WEIGHTS_NAME
197198
if torch_path.is_file():
198199
ctfidf_tensors = torch.load(torch_path, map_location="cpu")
200+
ctfidf_tensors = {k: v.numpy() for k, v in ctfidf_tensors.items()}
199201
ctfidf_config = load_cfg_from_json(path / CTFIDF_CFG_NAME)
200202
except: # noqa: E722
201203
ctfidf_config, ctfidf_tensors = None, None
@@ -315,35 +317,43 @@ def generate_readme(model, repo_id: str):
315317

316318
def save_hf(model, save_directory, serialization: str):
317319
"""Save topic embeddings, either safely (using safetensors) or using legacy pytorch."""
318-
tensors = torch.from_numpy(np.array(model.topic_embeddings_, dtype=np.float32))
319-
tensors = {"topic_embeddings": tensors}
320+
tensors = np.array(model.topic_embeddings_, dtype=np.float32)
320321

321322
if serialization == "safetensors":
323+
tensors = {"topic_embeddings": tensors}
322324
save_safetensors(save_directory / HF_SAFE_WEIGHTS_NAME, tensors)
323325
if serialization == "pytorch":
324326
assert _has_torch, "`pip install pytorch` to save as bin"
327+
tensors = {"topic_embeddings": torch.from_numpy(tensors)}
325328
torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
326329

327330

328331
def save_ctfidf(model, save_directory: str, serialization: str):
329332
"""Save c-TF-IDF sparse matrix."""
330-
indptr = torch.from_numpy(model.c_tf_idf_.indptr)
331-
indices = torch.from_numpy(model.c_tf_idf_.indices)
332-
data = torch.from_numpy(model.c_tf_idf_.data)
333-
shape = torch.from_numpy(np.array(model.c_tf_idf_.shape))
334-
diag = torch.from_numpy(np.array(model.ctfidf_model._idf_diag.data))
335-
tensors = {
336-
"indptr": indptr,
337-
"indices": indices,
338-
"data": data,
339-
"shape": shape,
340-
"diag": diag,
341-
}
333+
indptr = model.c_tf_idf_.indptr
334+
indices = model.c_tf_idf_.indices
335+
data = model.c_tf_idf_.data
336+
shape = np.array(model.c_tf_idf_.shape)
337+
diag = np.array(model.ctfidf_model._idf_diag.data)
342338

343339
if serialization == "safetensors":
340+
tensors = {
341+
"indptr": indptr,
342+
"indices": indices,
343+
"data": data,
344+
"shape": shape,
345+
"diag": diag,
346+
}
344347
save_safetensors(save_directory / CTFIDF_SAFE_WEIGHTS_NAME, tensors)
345348
if serialization == "pytorch":
346349
assert _has_torch, "`pip install pytorch` to save as .bin"
350+
tensors = {
351+
"indptr": torch.from_numpy(indptr),
352+
"indices": torch.from_numpy(indices),
353+
"data": torch.from_numpy(data),
354+
"shape": torch.from_numpy(shape),
355+
"diag": torch.from_numpy(diag),
356+
}
347357
torch.save(tensors, save_directory / CTFIDF_WEIGHTS_NAME)
348358

349359

@@ -511,20 +521,18 @@ def get_package_versions():
511521
def load_safetensors(path):
512522
"""Load safetensors and check whether it is installed."""
513523
try:
514-
import safetensors.torch
515-
import safetensors
524+
import safetensors.numpy
516525

517-
return safetensors.torch.load_file(path, device="cpu")
526+
return safetensors.numpy.load_file(path)
518527
except ImportError:
519528
raise ValueError("`pip install safetensors` to load .safetensors")
520529

521530

522531
def save_safetensors(path, tensors):
523532
"""Save safetensors and check whether it is installed."""
524533
try:
525-
import safetensors.torch
526-
import safetensors
534+
import safetensors.numpy
527535

528-
safetensors.torch.save_file(tensors, path)
536+
safetensors.numpy.save_file(tensors, path)
529537
except ImportError:
530538
raise ValueError("`pip install safetensors` to save as .safetensors")

0 commit comments

Comments
 (0)