@@ -3488,14 +3488,32 @@ def merge_models(cls, models, min_similarity: float = 0.7, embedding_model=None)
         merged_model = BERTopic.merge_models([topic_model_1, topic_model_2, topic_model_3])
         ```
         """
-        import torch
+
+        def choose_backend():
+            """Choose the backend to use for saving the model."""
+            try:
+                import torch  # noqa: F401
+
+                return "pytorch"
+            except (ModuleNotFoundError, ImportError):
+                try:
+                    import safetensors  # noqa: F401
+
+                    return "safetensors"
+                except (ModuleNotFoundError, ImportError):
+                    raise ImportError(
+                        "Neither pytorch nor safetensors is installed. "
+                        "Please install at least one of these packages:\n"
+                        " pip install torch\n"
+                        " pip install safetensors"
+                    )

         # Temporarily save model and push to HF
         with TemporaryDirectory() as tmpdir:
             # Save model weights and config.
             all_topics, all_params, all_tensors = [], [], []
             for index, model in enumerate(models):
-                model.save(tmpdir, serialization="pytorch")
+                model.save(tmpdir, serialization=choose_backend())
                 topics, params, tensors, _, _, _ = save_utils.load_local_files(Path(tmpdir))
                 all_topics.append(topics)
                 all_params.append(params)
@@ -3570,7 +3588,7 @@ def merge_models(cls, models, min_similarity: float = 0.7, embedding_model=None)
             merged_topics["topic_sizes"] = dict(Counter(merged_topics["topics"]))

             # Create a new model from the merged parameters
-            merged_tensors = {"topic_embeddings": torch.from_numpy(merged_tensors)}
+            merged_tensors = {"topic_embeddings": merged_tensors}
             merged_model = _create_model_from_files(
                 merged_topics,
                 merged_params,
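A minimal usage sketch of the behaviour this patch enables: merging fitted models in an environment where only `safetensors` (and not `torch`) is installed. `docs_part_1` and `docs_part_2` are hypothetical placeholder corpora; the `BERTopic().fit(...)` and `BERTopic.merge_models(...)` calls follow the docstring example above.

```python
from bertopic import BERTopic

# Hypothetical placeholder corpora; in practice these would be two
# sufficiently large lists of documents.
docs_part_1 = ["document one", "document two", "..."]
docs_part_2 = ["document three", "document four", "..."]

topic_model_1 = BERTopic().fit(docs_part_1)
topic_model_2 = BERTopic().fit(docs_part_2)

# With this patch, the intermediate save inside merge_models uses
# serialization="pytorch" when torch is importable, falls back to
# "safetensors" otherwise, and raises ImportError only when neither
# backend is available.
merged_model = BERTopic.merge_models([topic_model_1, topic_model_2], min_similarity=0.7)
```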