@@ -79,31 +79,24 @@ def __init__(
7979 self ._hf_model_id = hf_model_id
8080 self ._framework = None
8181 self ._folder_path .mkdir (parents = True , exist_ok = True )
82+ self ._model_zip = None
83+ self ._model_config = None
8284
83- def zip_model (self , framework : str = "pt" ) -> Path :
85+ def zip_model (self , framework : str = "pt" , zip_fname : str = "model.zip" ) -> Path :
8486 """
85- Compiles and zips the model to {self._folder_path}/model.zip
87+ Compiles and zips the model to {self._folder_path}/{zip_fname}
8688
8789 :param framework: one of "pt", "onnx". The framework to zip the model as.
8890 default: "pt"
8991 :type framework: str
92+ :param zip_fname: file name of the resulting zip file inside of self._folder_path.
93+ Example: if folder_path is "/tmp/models" and zip_fname is "zipped_up.zip" then
94+ the file can be found at "/tmp/models/zipped_up.zip"
95+ Default: "model.zip"
96+ :type zip_fname: str
9097 :return: the path with the zipped model
9198 :rtype: Path
9299 """
93- if framework == "pt" :
94- self ._framework = "pt"
95- return self ._zip_model_pytorch ()
96- if framework == "onnx" :
97- self ._framework = "onnx"
98- return self ._zip_model_onnx ()
99- raise Exception (
100- f"Unrecognized framework { framework } . Accepted values are `pt`, `onnx`"
101- )
102-
103- def _zip_model_pytorch (self ) -> Path :
104- """
105- Compiles the model to TORCHSCRIPT format.
106- """
107100 tk = AutoTokenizer .from_pretrained (self ._hf_model_id )
108101 model = AutoModelForSequenceClassification .from_pretrained (self ._hf_model_id )
109102 features = tk ([["dummy sentence 1" , "dummy sentence 2" ]], return_tensors = "pt" )
@@ -113,60 +106,74 @@ def _zip_model_pytorch(self) -> Path:
113106 if mname .startswith ("bge" ):
114107 features ["token_type_ids" ] = torch .zeros_like (features ["input_ids" ])
115108
116- # compile
117- compiled = torch .jit .trace (
118- model ,
119- example_kwarg_inputs = {
120- "input_ids" : features ["input_ids" ],
121- "attention_mask" : features ["attention_mask" ],
122- "token_type_ids" : features ["token_type_ids" ],
123- },
124- strict = False ,
125- )
126- torch .jit .save (compiled , f"/tmp/{ mname } .pt" )
109+ if framework == "pt" :
110+ self ._framework = "pt"
111+ model_loc = CrossEncoderModel ._trace_pytorch (model , features , mname )
112+ elif framework == "onnx" :
113+ self ._framework = "onnx"
114+ model_loc = CrossEncoderModel ._trace_onnx (model , features , mname )
115+ else :
116+ raise Exception (
117+ f"Unrecognized framework { framework } . Accepted values are `pt`, `onnx`"
118+ )
127119
128120 # save tokenizer file
129- tk_path = f"/tmp/{ mname } -tokenizer"
121+ tk_path = Path ( f"/tmp/{ mname } -tokenizer" )
130122 tk .save_pretrained (tk_path )
131123 _fix_tokenizer (tk .model_max_length , tk_path )
132124
133125 # get apache license
134126 r = requests .get (
135127 "https://github.com/opensearch-project/opensearch-py-ml/raw/main/LICENSE"
136128 )
137- with ZipFile (self ._folder_path / "model.zip" , "w" ) as f :
138- f .write (f"/tmp/{ mname } .pt" , arcname = f"{ mname } .pt" )
139- f .write (tk_path + "/tokenizer.json" , arcname = "tokenizer.json" )
129+ self ._model_zip = self ._folder_path / zip_fname
130+ with ZipFile (self ._model_zip , "w" ) as f :
131+ f .write (model_loc , arcname = model_loc .name )
132+ f .write (tk_path / "tokenizer.json" , arcname = "tokenizer.json" )
140133 f .writestr ("LICENSE" , r .content )
141134
142135 # clean up temp files
143- shutil .rmtree (f"/tmp/ { mname } -tokenizer" )
144- os .remove (f"/tmp/ { mname } .pt" )
145- return self ._folder_path / "model.zip"
136+ shutil .rmtree (tk_path )
137+ os .remove (model_loc )
138+ return self ._model_zip
146139
147- def _zip_model_onnx (self ):
140+ @staticmethod
141+ def _trace_pytorch (model , features , mname ) -> Path :
148142 """
149- Compiles the model to ONNX format.
150- """
151- tk = AutoTokenizer .from_pretrained (self ._hf_model_id )
152- model = AutoModelForSequenceClassification .from_pretrained (self ._hf_model_id )
153- features = tk ([["dummy sentence 1" , "dummy sentence 2" ]], return_tensors = "pt" )
154- mname = Path (self ._hf_model_id ).name
143+ Compiles the model to TORCHSCRIPT format.
155144
156- # bge models don't generate token type ids
157- if mname .startswith ("bge" ):
158- features ["token_type_ids" ] = torch .zeros_like (features ["input_ids" ])
145+ :param features: Model input features
146+ :return: Path to the traced model
147+ """
148+ # compile
149+ compiled = torch .jit .trace (
150+ model ,
151+ example_kwarg_inputs = {
152+ "input_ids" : features ["input_ids" ],
153+ "attention_mask" : features ["attention_mask" ],
154+ "token_type_ids" : features ["token_type_ids" ],
155+ },
156+ strict = False ,
157+ )
158+ save_loc = Path (f"/tmp/{ mname } .pt" )
159+ torch .jit .save (compiled , f"/tmp/{ mname } .pt" )
160+ return save_loc
159161
162+ @staticmethod
163+ def _trace_onnx (model , features , mname ):
164+ """
165+ Compiles the model to ONNX format.
166+ """
160167 # export to onnx
161- onnx_model_path = f"/tmp/{ mname } .onnx"
168+ save_loc = Path ( f"/tmp/{ mname } .onnx" )
162169 torch .onnx .export (
163170 model = model ,
164171 args = (
165172 features ["input_ids" ],
166173 features ["attention_mask" ],
167174 features ["token_type_ids" ],
168175 ),
169- f = onnx_model_path ,
176+ f = str ( save_loc ) ,
170177 input_names = ["input_ids" , "attention_mask" , "token_type_ids" ],
171178 output_names = ["output" ],
172179 dynamic_axes = {
@@ -177,28 +184,11 @@ def _zip_model_onnx(self):
177184 },
178185 verbose = True ,
179186 )
180-
181- # save tokenizer file
182- tk_path = f"/tmp/{ mname } -tokenizer"
183- tk .save_pretrained (tk_path )
184- _fix_tokenizer (tk .model_max_length , tk_path )
185-
186- # get apache license
187- r = requests .get (
188- "https://github.com/opensearch-project/opensearch-py-ml/raw/main/LICENSE"
189- )
190- with ZipFile (self ._folder_path / "model.zip" , "w" ) as f :
191- f .write (onnx_model_path , arcname = f"{ mname } .pt" )
192- f .write (tk_path + "/tokenizer.json" , arcname = "tokenizer.json" )
193- f .writestr ("LICENSE" , r .content )
194-
195- # clean up temp files
196- shutil .rmtree (f"/tmp/{ mname } -tokenizer" )
197- os .remove (onnx_model_path )
198- return self ._folder_path / "model.zip"
187+ return save_loc
199188
200189 def make_model_config_json (
201190 self ,
191+ config_fname : str = "config.json" ,
202192 model_name : str = None ,
203193 version_number : str = 1 ,
204194 description : str = None ,
@@ -210,6 +200,11 @@ def make_model_config_json(
210200 Parse from config.json file of pre-trained hugging-face model to generate a ml-commons_model_config.json file.
211201 If all required fields are given by users, use the given parameters and will skip reading the config.json
212202
203+ :param config_fname:
204+ Optional, File name of model json config file. Default is "config.json".
205+ Controls where the config file generated by this function will appear -
206+ "{self._folder_path}/{config_fname}"
207+ :type config_fname: str
213208 :param model_name:
214209 Optional, The name of the model. If None, default is model id, for example,
215210 'sentence-transformers/msmarco-distilbert-base-tas-b'
@@ -234,11 +229,13 @@ def make_model_config_json(
234229 :return: model config file path. The file path where the model config file is being saved
235230 :rtype: string
236231 """
237- if not (self ._folder_path / "model.zip" ).exists ():
238- raise Exception ("Generate the model zip before generating the config" )
239- hash_value = _generate_model_content_hash_value (
240- str (self ._folder_path / "model.zip" )
241- )
232+ if self ._model_zip is None :
233+ raise Exception (
234+ "No model zip file. Generate the model zip file before generating the config."
235+ )
236+ if not self ._model_zip .exists ():
237+ raise Exception (f"Model zip file { self ._model_zip } could not be found" )
238+ hash_value = _generate_model_content_hash_value (str (self ._model_zip ))
242239 if model_name is None :
243240 model_name = Path (self ._hf_model_id ).name
244241 if description is None :
@@ -269,11 +266,12 @@ def make_model_config_json(
269266 "all_config" : all_config ,
270267 },
271268 }
269+ self ._model_config = self ._folder_path / config_fname
272270 if verbose :
273271 print (json .dumps (model_config_content , indent = 2 ))
274- with open (self ._folder_path / "config.json" , "w" ) as f :
272+ with open (self ._model_config , "w" ) as f :
275273 json .dump (model_config_content , f )
276- return self ._folder_path / "config.json"
274+ return self ._model_config
277275
278276 def upload (
279277 self ,
@@ -294,15 +292,17 @@ def upload(
294292 :param verbose: log a bunch or not
295293 :type verbose: bool
296294 """
297- config_path = self ._folder_path / "config.json"
298- model_path = self ._folder_path / "model.zip"
299295 gen_cfg = False
300- if not model_path .exists () or self ._framework != framework :
296+ if (
297+ self ._model_zip is None
298+ or not self ._model_zip .exists ()
299+ or self ._framework != framework
300+ ):
301301 gen_cfg = True
302302 self .zip_model (framework )
303- if not config_path .exists () or gen_cfg :
303+ if self . _model_config is None or not self . _model_config .exists () or gen_cfg :
304304 self .make_model_config_json ()
305305 uploader = ModelUploader (client )
306306 uploader ._register_model (
307- str (model_path ), str (config_path ), model_group_id , verbose
307+ str (self . _model_zip ), str (self . _model_config ), model_group_id , verbose
308308 )
0 commit comments