@@ -185,6 +185,7 @@ def load_local_files(path):
185185 torch_path = path / HF_WEIGHTS_NAME
186186 if torch_path .is_file ():
187187 tensors = torch .load (torch_path , map_location = "cpu" )
188+ tensors = {k : v .numpy () for k , v in tensors .items ()}
188189
189190 # c-TF-IDF
190191 try :
@@ -196,6 +197,7 @@ def load_local_files(path):
196197 torch_path = path / CTFIDF_WEIGHTS_NAME
197198 if torch_path .is_file ():
198199 ctfidf_tensors = torch .load (torch_path , map_location = "cpu" )
200+ ctfidf_tensors = {k : v .numpy () for k , v in ctfidf_tensors .items ()}
199201 ctfidf_config = load_cfg_from_json (path / CTFIDF_CFG_NAME )
200202 except : # noqa: E722
201203 ctfidf_config , ctfidf_tensors = None , None
@@ -315,35 +317,43 @@ def generate_readme(model, repo_id: str):
315317
316318def save_hf (model , save_directory , serialization : str ):
317319 """Save topic embeddings, either safely (using safetensors) or using legacy pytorch."""
318- tensors = torch .from_numpy (np .array (model .topic_embeddings_ , dtype = np .float32 ))
319- tensors = {"topic_embeddings" : tensors }
320+ tensors = np .array (model .topic_embeddings_ , dtype = np .float32 )
320321
321322 if serialization == "safetensors" :
323+ tensors = {"topic_embeddings" : tensors }
322324 save_safetensors (save_directory / HF_SAFE_WEIGHTS_NAME , tensors )
323325 if serialization == "pytorch" :
324326 assert _has_torch , "`pip install pytorch` to save as bin"
327+ tensors = {"topic_embeddings" : torch .from_numpy (tensors )}
325328 torch .save (tensors , save_directory / HF_WEIGHTS_NAME )
326329
327330
328331def save_ctfidf (model , save_directory : str , serialization : str ):
329332 """Save c-TF-IDF sparse matrix."""
330- indptr = torch .from_numpy (model .c_tf_idf_ .indptr )
331- indices = torch .from_numpy (model .c_tf_idf_ .indices )
332- data = torch .from_numpy (model .c_tf_idf_ .data )
333- shape = torch .from_numpy (np .array (model .c_tf_idf_ .shape ))
334- diag = torch .from_numpy (np .array (model .ctfidf_model ._idf_diag .data ))
335- tensors = {
336- "indptr" : indptr ,
337- "indices" : indices ,
338- "data" : data ,
339- "shape" : shape ,
340- "diag" : diag ,
341- }
333+ indptr = model .c_tf_idf_ .indptr
334+ indices = model .c_tf_idf_ .indices
335+ data = model .c_tf_idf_ .data
336+ shape = np .array (model .c_tf_idf_ .shape )
337+ diag = np .array (model .ctfidf_model ._idf_diag .data )
342338
343339 if serialization == "safetensors" :
340+ tensors = {
341+ "indptr" : indptr ,
342+ "indices" : indices ,
343+ "data" : data ,
344+ "shape" : shape ,
345+ "diag" : diag ,
346+ }
344347 save_safetensors (save_directory / CTFIDF_SAFE_WEIGHTS_NAME , tensors )
345348 if serialization == "pytorch" :
346349 assert _has_torch , "`pip install pytorch` to save as .bin"
350+ tensors = {
351+ "indptr" : torch .from_numpy (indptr ),
352+ "indices" : torch .from_numpy (indices ),
353+ "data" : torch .from_numpy (data ),
354+ "shape" : torch .from_numpy (shape ),
355+ "diag" : torch .from_numpy (diag ),
356+ }
347357 torch .save (tensors , save_directory / CTFIDF_WEIGHTS_NAME )
348358
349359
@@ -511,20 +521,18 @@ def get_package_versions():
511521def load_safetensors (path ):
512522 """Load safetensors and check whether it is installed."""
513523 try :
514- import safetensors .torch
515- import safetensors
524+ import safetensors .numpy
516525
517- return safetensors .torch .load_file (path , device = "cpu" )
526+ return safetensors .numpy .load_file (path )
518527 except ImportError :
519528 raise ValueError ("`pip install safetensors` to load .safetensors" )
520529
521530
522531def save_safetensors (path , tensors ):
523532 """Save safetensors and check whether it is installed."""
524533 try :
525- import safetensors .torch
526- import safetensors
534+ import safetensors .numpy
527535
528- safetensors .torch .save_file (tensors , path )
536+ safetensors .numpy .save_file (tensors , path )
529537 except ImportError :
530538 raise ValueError ("`pip install safetensors` to save as .safetensors" )
0 commit comments