Skip to content

Commit a7abdb7

Browse files
authored
Merge models without pytorch (using safetensors) (#2329)
1 parent de250e9 commit a7abdb7

File tree

1 file changed

+21
-3
lines changed

1 file changed

+21
-3
lines changed

bertopic/_bertopic.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3488,14 +3488,32 @@ def merge_models(cls, models, min_similarity: float = 0.7, embedding_model=None)
34883488
merged_model = BERTopic.merge_models([topic_model_1, topic_model_2, topic_model_3])
34893489
```
34903490
"""
3491-
import torch
3491+
3492+
def choose_backend():
    """Choose the serialization backend to use for saving the model.

    ``torch`` is tried first; ``safetensors`` is only used as a fallback
    when ``torch`` cannot be imported.

    Returns:
        ``"pytorch"`` if ``torch`` can be imported, otherwise
        ``"safetensors"`` if ``safetensors`` can be imported.

    Raises:
        ImportError: If neither ``torch`` nor ``safetensors`` can be imported.
    """
    try:
        import torch  # noqa: F401

        return "pytorch"
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so catching
        # ImportError alone also covers a package that is not installed.
        pass

    try:
        import safetensors  # noqa: F401

        return "safetensors"
    except ImportError as err:
        # Chain explicitly so the underlying import failure stays visible.
        raise ImportError(
            "Neither pytorch nor safetensors is installed. "
            "Please install at least one of these packages:\n"
            " pip install torch\n"
            " pip install safetensors"
        ) from err
34923510

34933511
# Temporarily save each model to disk and load its files back
34943512
with TemporaryDirectory() as tmpdir:
34953513
# Save model weights and config.
34963514
all_topics, all_params, all_tensors = [], [], []
34973515
for index, model in enumerate(models):
3498-
model.save(tmpdir, serialization="pytorch")
3516+
model.save(tmpdir, serialization=choose_backend())
34993517
topics, params, tensors, _, _, _ = save_utils.load_local_files(Path(tmpdir))
35003518
all_topics.append(topics)
35013519
all_params.append(params)
@@ -3570,7 +3588,7 @@ def merge_models(cls, models, min_similarity: float = 0.7, embedding_model=None)
35703588
merged_topics["topic_sizes"] = dict(Counter(merged_topics["topics"]))
35713589

35723590
# Create a new model from the merged parameters
3573-
merged_tensors = {"topic_embeddings": torch.from_numpy(merged_tensors)}
3591+
merged_tensors = {"topic_embeddings": merged_tensors}
35743592
merged_model = _create_model_from_files(
35753593
merged_topics,
35763594
merged_params,

0 commit comments

Comments
 (0)