Skip to content

Commit 8b43bc3

Browse files
committed
feat: Add safetensors module for loading model file
1 parent cffa036 commit 8b43bc3

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

examples/research_projects/anytext/recognizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import torch.nn.functional as F
1414
from easydict import EasyDict as edict
1515
from ocr_recog.RecModel import RecModel
16+
from safetensors.torch import load_file
1617
from skimage.transform._geometric import _umeyama as get_sym_mat
1718

1819

@@ -105,7 +106,7 @@ def create_predictor(model_dir=None, model_lang="ch", is_onnx=False):
105106

106107
rec_model = RecModel(rec_config)
107108
if model_file_path is not None:
108-
rec_model.load_state_dict(torch.load(model_file_path, map_location="cpu"))
109+
rec_model.load_state_dict(load_file(model_file_path))
109110
rec_model.eval()
110111
return rec_model.eval()
111112

examples/research_projects/anytext/text_embedding_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(self, use_fp16):
3535
# TODO: Understand the reason of param.requires_grad = True
3636
for param in self.embedding_manager.embedding_parameters():
3737
param.requires_grad = True
38-
rec_model_dir = "ppv3_rec.pth"
38+
rec_model_dir = "ppv3_rec.safetensors"
3939
self.text_predictor = create_predictor(rec_model_dir).eval()
4040
args = {}
4141
args["rec_image_shape"] = "3, 48, 320"

0 commit comments

Comments
 (0)