11import sys
2- from typing import Any
32from io import BytesIO
43from pathlib import Path
54import srsly
1514from transformers import AutoModel , AutoConfig , AutoTokenizer
1615
1716
def override_hf_shims_to_bytes():
    """Monkeypatch ``hf_shim.HFShim.to_bytes`` with the weight-skipping custom version.

    Returns:
        The original ``HFShim.to_bytes`` function, so the caller can later
        restore it with ``recover_hf_shims_to_bytes``.

    Raises:
        AssertionError: if the patch is already in place. An ``assert``
            statement is deliberately not used here: under ``python -O``
            asserts are stripped, and a silent double-patch would capture
            the custom function as the "origin", making the real
            implementation unrecoverable.
    """
    if hf_shim.HFShim.to_bytes is HFShimCustom.to_bytes:
        raise AssertionError("HFShim.to_bytes is already overridden")
    origin = hf_shim.HFShim.to_bytes
    hf_shim.HFShim.to_bytes = HFShimCustom.to_bytes
    return origin
22+
def recover_hf_shims_to_bytes(origin):
    """Restore the original ``hf_shim.HFShim.to_bytes`` saved by the override.

    Args:
        origin: the function previously returned by
            ``override_hf_shims_to_bytes``.

    Raises:
        AssertionError: if the custom patch is not currently installed,
            i.e. the override/recover calls are unbalanced. Raised
            explicitly (not via ``assert``) so the check survives
            ``python -O``.
    """
    if hf_shim.HFShim.to_bytes is not HFShimCustom.to_bytes:
        raise AssertionError("HFShim.to_bytes is not currently overridden")
    hf_shim.HFShim.to_bytes = origin
26+
27+
1828def override_hf_shims_from_bytes ():
1929 assert hf_shim .HFShim .from_bytes is not HFShimCustom .from_bytes
2030 origin = hf_shim .HFShim .from_bytes
@@ -28,6 +38,44 @@ def recover_hf_shims_from_bytes(origin):
2838
2939class HFShimCustom (HFShim ):
3040
41+ def to_bytes (self ):
42+ config = {}
43+ tok_dict = {}
44+ # weights_bytes = {}
45+ tok_cfg = {}
46+ trf_cfg = {}
47+ hf_model = self ._hfmodel
48+ if hf_model .transformer is not None :
49+ tok_dict = {}
50+ config = hf_model .transformer .config .to_dict ()
51+ tokenizer = hf_model .tokenizer
52+ with make_tempdir () as temp_dir :
53+ if hasattr (tokenizer , "vocab_file" ):
54+ vocab_file_name = tokenizer .vocab_files_names ["vocab_file" ]
55+ vocab_file_path = str ((temp_dir / vocab_file_name ).absolute ())
56+ with open (vocab_file_path , "wb" ) as fileh :
57+ fileh .write (hf_model .vocab_file_contents )
58+ tokenizer .vocab_file = vocab_file_path
59+ tokenizer .save_pretrained (str (temp_dir .absolute ()))
60+ for x in temp_dir .glob ("**/*" ):
61+ if x .is_file ():
62+ tok_dict [x .name ] = x .read_bytes ()
63+ filelike = BytesIO ()
64+ torch .save (self ._model .state_dict (), filelike )
65+ filelike .seek (0 )
66+ # weights_bytes = filelike.getvalue()
67+ else :
68+ tok_cfg = hf_model ._init_tokenizer_config
69+ trf_cfg = hf_model ._init_transformer_config
70+ msg = {
71+ "config" : config ,
72+ # "state": weights_bytes,
73+ "tokenizer" : tok_dict ,
74+ "_init_tokenizer_config" : tok_cfg ,
75+ "_init_transformer_config" : trf_cfg ,
76+ }
77+ return srsly .msgpack_dumps (msg )
78+
3179 def from_bytes (self , bytes_data ):
3280 msg = srsly .msgpack_loads (bytes_data )
3381 config_dict = msg ["config" ]
@@ -62,34 +110,35 @@ def from_bytes(self, bytes_data):
62110 with open (vocab_file_path , "rb" ) as fileh :
63111 vocab_file_contents = fileh .read ()
64112
65- try :
113+ ops = get_current_ops ()
114+ if ops .device_type == "cpu" :
115+ map_location = "cpu"
116+ else : # pragma: no cover
117+ device_id = torch .cuda .current_device ()
118+ map_location = f"cuda:{ device_id } "
119+
120+ if "state" in msg :
66121 transformer = AutoModel .from_config (config )
67- except OSError as e :
122+ filelike = BytesIO (msg ["state" ])
123+ filelike .seek (0 )
124+ transformer .load_state_dict (torch .load (filelike , map_location = map_location ))
125+ else :
68126 try :
69- transformer = AutoModel .from_pretrained (config [ " _name_or_path" ] , local_files_only = True )
127+ transformer = AutoModel .from_pretrained (config . _name_or_path , local_files_only = True )
70128 except OSError as e2 :
71- print ("trying to download model from huggingface hub:" , config [ " _name_or_path" ] , "..." , file = sys .stderr )
72- transformer = AutoModel .from_pretrained (config [ " _name_or_path" ] )
129+ print ("trying to download model from huggingface hub:" , config . _name_or_path , "..." , file = sys .stderr )
130+ transformer = AutoModel .from_pretrained (config . _name_or_path )
73131 print ("succeded" , file = sys .stderr )
74132
133+ transformer .to (map_location )
134+ self ._model = transformer
75135 self ._hfmodel = HFObjects (
76136 tokenizer ,
77137 transformer ,
78138 vocab_file_contents ,
79139 SimpleFrozenDict (),
80140 SimpleFrozenDict (),
81141 )
82- self ._model = transformer
83- filelike = BytesIO (msg ["state" ])
84- filelike .seek (0 )
85- ops = get_current_ops ()
86- if ops .device_type == "cpu" :
87- map_location = "cpu"
88- else : # pragma: no cover
89- device_id = torch .cuda .current_device ()
90- map_location = f"cuda:{ device_id } "
91- self ._model .load_state_dict (torch .load (filelike , map_location = map_location ))
92- self ._model .to (map_location )
93142 else :
94143 self ._hfmodel = HFObjects (
95144 None ,
0 commit comments