Commit 61af8ca

HenryL27 committed
condense common zipping logic; allow more configurable file names
Signed-off-by: HenryL27 <[email protected]>
1 parent f6551ef

1 file changed
opensearch_py_ml/ml_models/crossencodermodel.py

Lines changed: 77 additions & 77 deletions
@@ -79,31 +79,24 @@ def __init__(
         self._hf_model_id = hf_model_id
         self._framework = None
         self._folder_path.mkdir(parents=True, exist_ok=True)
+        self._model_zip = None
+        self._model_config = None
 
-    def zip_model(self, framework: str = "pt") -> Path:
+    def zip_model(self, framework: str = "pt", zip_fname: str = "model.zip") -> Path:
         """
-        Compiles and zips the model to {self._folder_path}/model.zip
+        Compiles and zips the model to {self._folder_path}/{zip_fname}
 
         :param framework: one of "pt", "onnx". The framework to zip the model as.
             default: "pt"
         :type framework: str
+        :param zip_fname: file name of the resulting zip file inside of self._folder_path.
+            Example: if folder_path is "/tmp/models" and zip_fname is "zipped_up.zip" then
+            the file can be found at "/tmp/models/zipped_up.zip"
+            Default: "model.zip"
+        :type zip_fname: str
         :return: the path with the zipped model
         :rtype: Path
         """
-        if framework == "pt":
-            self._framework = "pt"
-            return self._zip_model_pytorch()
-        if framework == "onnx":
-            self._framework = "onnx"
-            return self._zip_model_onnx()
-        raise Exception(
-            f"Unrecognized framework {framework}. Accepted values are `pt`, `onnx`"
-        )
-
-    def _zip_model_pytorch(self) -> Path:
-        """
-        Compiles the model to TORCHSCRIPT format.
-        """
         tk = AutoTokenizer.from_pretrained(self._hf_model_id)
         model = AutoModelForSequenceClassification.from_pretrained(self._hf_model_id)
         features = tk([["dummy sentence 1", "dummy sentence 2"]], return_tensors="pt")
@@ -113,60 +106,74 @@ def _zip_model_pytorch(self) -> Path:
         if mname.startswith("bge"):
             features["token_type_ids"] = torch.zeros_like(features["input_ids"])
 
-        # compile
-        compiled = torch.jit.trace(
-            model,
-            example_kwarg_inputs={
-                "input_ids": features["input_ids"],
-                "attention_mask": features["attention_mask"],
-                "token_type_ids": features["token_type_ids"],
-            },
-            strict=False,
-        )
-        torch.jit.save(compiled, f"/tmp/{mname}.pt")
+        if framework == "pt":
+            self._framework = "pt"
+            model_loc = CrossEncoderModel._trace_pytorch(model, features, mname)
+        elif framework == "onnx":
+            self._framework = "onnx"
+            model_loc = CrossEncoderModel._trace_onnx(model, features, mname)
+        else:
+            raise Exception(
+                f"Unrecognized framework {framework}. Accepted values are `pt`, `onnx`"
+            )
 
         # save tokenizer file
-        tk_path = f"/tmp/{mname}-tokenizer"
+        tk_path = Path(f"/tmp/{mname}-tokenizer")
         tk.save_pretrained(tk_path)
         _fix_tokenizer(tk.model_max_length, tk_path)
 
         # get apache license
         r = requests.get(
             "https://github.com/opensearch-project/opensearch-py-ml/raw/main/LICENSE"
         )
-        with ZipFile(self._folder_path / "model.zip", "w") as f:
-            f.write(f"/tmp/{mname}.pt", arcname=f"{mname}.pt")
-            f.write(tk_path + "/tokenizer.json", arcname="tokenizer.json")
+        self._model_zip = self._folder_path / zip_fname
+        with ZipFile(self._model_zip, "w") as f:
+            f.write(model_loc, arcname=model_loc.name)
+            f.write(tk_path / "tokenizer.json", arcname="tokenizer.json")
             f.writestr("LICENSE", r.content)
 
         # clean up temp files
-        shutil.rmtree(f"/tmp/{mname}-tokenizer")
-        os.remove(f"/tmp/{mname}.pt")
-        return self._folder_path / "model.zip"
+        shutil.rmtree(tk_path)
+        os.remove(model_loc)
+        return self._model_zip
 
-    def _zip_model_onnx(self):
+    @staticmethod
+    def _trace_pytorch(model, features, mname) -> Path:
         """
-        Compiles the model to ONNX format.
-        """
-        tk = AutoTokenizer.from_pretrained(self._hf_model_id)
-        model = AutoModelForSequenceClassification.from_pretrained(self._hf_model_id)
-        features = tk([["dummy sentence 1", "dummy sentence 2"]], return_tensors="pt")
-        mname = Path(self._hf_model_id).name
+        Compiles the model to TORCHSCRIPT format.
 
-        # bge models don't generate token type ids
-        if mname.startswith("bge"):
-            features["token_type_ids"] = torch.zeros_like(features["input_ids"])
+        :param features: Model input features
+        :return: Path to the traced model
+        """
+        # compile
+        compiled = torch.jit.trace(
+            model,
+            example_kwarg_inputs={
+                "input_ids": features["input_ids"],
+                "attention_mask": features["attention_mask"],
+                "token_type_ids": features["token_type_ids"],
+            },
+            strict=False,
+        )
+        save_loc = Path(f"/tmp/{mname}.pt")
+        torch.jit.save(compiled, f"/tmp/{mname}.pt")
+        return save_loc
 
+    @staticmethod
+    def _trace_onnx(model, features, mname):
+        """
+        Compiles the model to ONNX format.
+        """
         # export to onnx
-        onnx_model_path = f"/tmp/{mname}.onnx"
+        save_loc = Path(f"/tmp/{mname}.onnx")
         torch.onnx.export(
             model=model,
             args=(
                 features["input_ids"],
                 features["attention_mask"],
                 features["token_type_ids"],
             ),
-            f=onnx_model_path,
+            f=str(save_loc),
             input_names=["input_ids", "attention_mask", "token_type_ids"],
             output_names=["output"],
             dynamic_axes={
@@ -177,28 +184,11 @@ def _zip_model_onnx(self):
             },
             verbose=True,
         )
-
-        # save tokenizer file
-        tk_path = f"/tmp/{mname}-tokenizer"
-        tk.save_pretrained(tk_path)
-        _fix_tokenizer(tk.model_max_length, tk_path)
-
-        # get apache license
-        r = requests.get(
-            "https://github.com/opensearch-project/opensearch-py-ml/raw/main/LICENSE"
-        )
-        with ZipFile(self._folder_path / "model.zip", "w") as f:
-            f.write(onnx_model_path, arcname=f"{mname}.pt")
-            f.write(tk_path + "/tokenizer.json", arcname="tokenizer.json")
-            f.writestr("LICENSE", r.content)
-
-        # clean up temp files
-        shutil.rmtree(f"/tmp/{mname}-tokenizer")
-        os.remove(onnx_model_path)
-        return self._folder_path / "model.zip"
+        return save_loc
 
     def make_model_config_json(
         self,
+        config_fname: str = "config.json",
         model_name: str = None,
         version_number: str = 1,
         description: str = None,
@@ -210,6 +200,11 @@
         Parse from config.json file of pre-trained hugging-face model to generate a ml-commons_model_config.json file.
         If all required fields are given by users, use the given parameters and will skip reading the config.json
 
+        :param config_fname:
+            Optional, file name of the model json config file. Default is "config.json".
+            Controls where the config file generated by this function will appear -
+            "{self._folder_path}/{config_fname}"
+        :type config_fname: str
         :param model_name:
             Optional, The name of the model. If None, default is model id, for example,
             'sentence-transformers/msmarco-distilbert-base-tas-b'
@@ -234,11 +229,13 @@
         :return: model config file path. The file path where the model config file is being saved
         :rtype: string
         """
-        if not (self._folder_path / "model.zip").exists():
-            raise Exception("Generate the model zip before generating the config")
-        hash_value = _generate_model_content_hash_value(
-            str(self._folder_path / "model.zip")
-        )
+        if self._model_zip is None:
+            raise Exception(
+                "No model zip file. Generate the model zip file before generating the config."
+            )
+        if not self._model_zip.exists():
+            raise Exception(f"Model zip file {self._model_zip} could not be found")
+        hash_value = _generate_model_content_hash_value(str(self._model_zip))
         if model_name is None:
             model_name = Path(self._hf_model_id).name
         if description is None:
@@ -269,11 +266,12 @@
                 "all_config": all_config,
             },
         }
+        self._model_config = self._folder_path / config_fname
         if verbose:
             print(json.dumps(model_config_content, indent=2))
-        with open(self._folder_path / "config.json", "w") as f:
+        with open(self._model_config, "w") as f:
             json.dump(model_config_content, f)
-        return self._folder_path / "config.json"
+        return self._model_config
 
     def upload(
         self,
@@ -294,15 +292,17 @@
         :param verbose: log a bunch or not
         :type verbose: bool
         """
-        config_path = self._folder_path / "config.json"
-        model_path = self._folder_path / "model.zip"
         gen_cfg = False
-        if not model_path.exists() or self._framework != framework:
+        if (
+            self._model_zip is None
+            or not self._model_zip.exists()
+            or self._framework != framework
+        ):
             gen_cfg = True
             self.zip_model(framework)
-        if not config_path.exists() or gen_cfg:
+        if self._model_config is None or not self._model_config.exists() or gen_cfg:
            self.make_model_config_json()
         uploader = ModelUploader(client)
         uploader._register_model(
-            str(model_path), str(config_path), model_group_id, verbose
+            str(self._model_zip), str(self._model_config), model_group_id, verbose
         )
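
To make the reworked flow concrete, a minimal usage sketch of the newly configurable file names follows. The constructor argument names (hf_model_id, folder_path) are assumptions inferred from the attributes the diff touches; they are not shown in the commit, and the model id and file names below are placeholders.

from opensearch_py_ml.ml_models.crossencodermodel import CrossEncoderModel

# Assumed constructor signature (inferred from self._hf_model_id and
# self._folder_path in the diff); verify against the full source.
model = CrossEncoderModel(
    hf_model_id="cross-encoder/ms-marco-MiniLM-L-6-v2",  # example model id
    folder_path="/tmp/models",
)

# New in this commit: the zip file name is configurable, and the resulting
# path is cached on self._model_zip.
zip_path = model.zip_model(framework="pt", zip_fname="zipped_up.zip")
print(zip_path)  # /tmp/models/zipped_up.zip

# Also new: the config file name is configurable, and the content hash is
# computed from the cached zip path instead of a hard-coded "model.zip".
config_path = model.make_model_config_json(config_fname="my_config.json")
print(config_path)  # /tmp/models/my_config.json

# upload() now reuses the cached zip/config paths, regenerating them only
# when they are missing or were built for a different framework, e.g.:
# model.upload(client, model_group_id="...", framework="pt")
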
