Skip to content

Commit 097f31c

Browse files
committed
run nox format
Signed-off-by: HenryL27 <[email protected]>
1 parent f5ee769 commit 097f31c

File tree

2 files changed

+72 -53 lines changed

2 files changed

+72 -53 lines changed

opensearch_py_ml/ml_models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,8 @@
55
# Any modifications Copyright OpenSearch Contributors. See
66
# GitHub history for details.
77

8+
from .crossencodermodel import CrossEncoderModel
89
from .metrics_correlation.mcorr import MCorr
910
from .sentencetransformermodel import SentenceTransformerModel
10-
from .crossencodermodel import CrossEncoderModel
1111

1212
__all__ = ["SentenceTransformerModel", "MCorr", "CrossEncoderModel"]

opensearch_py_ml/ml_models/crossencodermodel.py

Lines changed: 71 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -6,18 +6,20 @@
66
# GitHub history for details.
77

88
import json
9-
from opensearch_py_ml.ml_commons import ModelUploader
10-
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
9+
import os
10+
import shutil
1111
from pathlib import Path
1212
from zipfile import ZipFile
13-
import shutil
14-
import os
13+
1514
import requests
1615
import torch
16+
from opensearchpy import OpenSearch
17+
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
18+
19+
from opensearch_py_ml.ml_commons import ModelUploader
1720
from opensearch_py_ml.ml_commons.ml_common_utils import (
1821
_generate_model_content_hash_value,
1922
)
20-
from opensearchpy import OpenSearch
2123

2224

2325
def _fix_tokenizer(max_len: int, path: Path):
@@ -31,8 +33,8 @@ def _fix_tokenizer(max_len: int, path: Path):
3133
"""
3234
with open(Path(path) / "tokenizer.json", "r") as f:
3335
parsed = json.load(f)
34-
if "truncation" not in parsed or parsed['truncation'] is None:
35-
parsed['truncation'] = {
36+
if "truncation" not in parsed or parsed["truncation"] is None:
37+
parsed["truncation"] = {
3638
"direction": "Right",
3739
"max_length": max_len,
3840
"strategy": "LongestFirst",
@@ -46,11 +48,9 @@ class CrossEncoderModel:
4648
"""
4749
Class for configuring and uploading cross encoder models for opensearch
4850
"""
51+
4952
def __init__(
50-
self,
51-
hf_model_id: str,
52-
folder_path: str = None,
53-
overwrite: bool = False
53+
self, hf_model_id: str, folder_path: str = None, overwrite: bool = False
5454
) -> None:
5555
"""
5656
Initialize a new CrossEncoder model from a huggingface id
@@ -72,13 +72,14 @@ def __init__(
7272
self._folder_path = Path(folder_path)
7373

7474
if self._folder_path.exists() and not overwrite:
75-
raise Exception(f"Folder {self._folder_path} already exists. To overwrite it, set `overwrite=True`.")
75+
raise Exception(
76+
f"Folder {self._folder_path} already exists. To overwrite it, set `overwrite=True`."
77+
)
7678

7779
self._hf_model_id = hf_model_id
7880
self._framework = None
7981
self._folder_path.mkdir(parents=True, exist_ok=True)
8082

81-
8283
def zip_model(self, framework: str = "pt") -> Path:
8384
"""
8485
Compiles and zips the model to {self._folder_path}/model.zip
@@ -95,8 +96,9 @@ def zip_model(self, framework: str = "pt") -> Path:
9596
if framework == "onnx":
9697
self._framework = "onnx"
9798
return self._zip_model_onnx()
98-
raise Exception(f"Unrecognized framework {framework}. Accepted values are `pt`, `onnx`")
99-
99+
raise Exception(
100+
f"Unrecognized framework {framework}. Accepted values are `pt`, `onnx`"
101+
)
100102

101103
def _zip_model_pytorch(self) -> Path:
102104
"""
@@ -109,14 +111,18 @@ def _zip_model_pytorch(self) -> Path:
109111

110112
# bge models don't generate token type ids
111113
if mname.startswith("bge"):
112-
features['token_type_ids'] = torch.zeros_like(features['input_ids'])
114+
features["token_type_ids"] = torch.zeros_like(features["input_ids"])
113115

114116
# compile
115-
compiled = torch.jit.trace(model, example_kwarg_inputs={
116-
'input_ids': features['input_ids'],
117-
'attention_mask': features['attention_mask'],
118-
'token_type_ids': features['token_type_ids']
119-
}, strict=False)
117+
compiled = torch.jit.trace(
118+
model,
119+
example_kwarg_inputs={
120+
"input_ids": features["input_ids"],
121+
"attention_mask": features["attention_mask"],
122+
"token_type_ids": features["token_type_ids"],
123+
},
124+
strict=False,
125+
)
120126
torch.jit.save(compiled, f"/tmp/{mname}.pt")
121127

122128
# save tokenizer file
@@ -125,7 +131,9 @@ def _zip_model_pytorch(self) -> Path:
125131
_fix_tokenizer(tk.model_max_length, tk_path)
126132

127133
# get apache license
128-
r = requests.get("https://github.com/opensearch-project/opensearch-py-ml/raw/main/LICENSE")
134+
r = requests.get(
135+
"https://github.com/opensearch-project/opensearch-py-ml/raw/main/LICENSE"
136+
)
129137
with ZipFile(self._folder_path / "model.zip", "w") as f:
130138
f.write(f"/tmp/{mname}.pt", arcname=f"{mname}.pt")
131139
f.write(tk_path + "/tokenizer.json", arcname="tokenizer.json")
@@ -147,23 +155,27 @@ def _zip_model_onnx(self):
147155

148156
# bge models don't generate token type ids
149157
if mname.startswith("bge"):
150-
features['token_type_ids'] = torch.zeros_like(features['input_ids'])
158+
features["token_type_ids"] = torch.zeros_like(features["input_ids"])
151159

152160
# export to onnx
153161
onnx_model_path = f"/tmp/{mname}.onnx"
154162
torch.onnx.export(
155163
model=model,
156-
args=(features['input_ids'], features['attention_mask'], features['token_type_ids']),
164+
args=(
165+
features["input_ids"],
166+
features["attention_mask"],
167+
features["token_type_ids"],
168+
),
157169
f=onnx_model_path,
158-
input_names=['input_ids', 'attention_mask', 'token_type_ids'],
159-
output_names=['output'],
170+
input_names=["input_ids", "attention_mask", "token_type_ids"],
171+
output_names=["output"],
160172
dynamic_axes={
161-
'input_ids': {0: 'batch_size', 1: 'sequence_length'},
162-
'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
163-
'token_type_ids': {0: 'batch_size', 1: 'sequence_length'},
164-
'output': {0: 'batch_size'}
173+
"input_ids": {0: "batch_size", 1: "sequence_length"},
174+
"attention_mask": {0: "batch_size", 1: "sequence_length"},
175+
"token_type_ids": {0: "batch_size", 1: "sequence_length"},
176+
"output": {0: "batch_size"},
165177
},
166-
verbose=True
178+
verbose=True,
167179
)
168180

169181
# save tokenizer file
@@ -172,7 +184,9 @@ def _zip_model_onnx(self):
172184
_fix_tokenizer(tk.model_max_length, tk_path)
173185

174186
# get apache license
175-
r = requests.get("https://github.com/opensearch-project/opensearch-py-ml/raw/main/LICENSE")
187+
r = requests.get(
188+
"https://github.com/opensearch-project/opensearch-py-ml/raw/main/LICENSE"
189+
)
176190
with ZipFile(self._folder_path / "model.zip", "w") as f:
177191
f.write(onnx_model_path, arcname=f"{mname}.pt")
178192
f.write(tk_path + "/tokenizer.json", arcname="tokenizer.json")
@@ -183,15 +197,14 @@ def _zip_model_onnx(self):
183197
os.remove(onnx_model_path)
184198
return self._folder_path / "model.zip"
185199

186-
187200
def make_model_config_json(
188-
self,
189-
model_name: str = None,
190-
version_number: str = 1,
191-
description: str = None,
192-
all_config: str = None,
193-
model_type: str = None,
194-
verbose: bool = False,
201+
self,
202+
model_name: str = None,
203+
version_number: str = 1,
204+
description: str = None,
205+
all_config: str = None,
206+
model_type: str = None,
207+
verbose: bool = False,
195208
):
196209
"""
197210
Parse from config.json file of pre-trained hugging-face model to generate a ml-commons_model_config.json file.
@@ -223,7 +236,9 @@ def make_model_config_json(
223236
"""
224237
if not (self._folder_path / "model.zip").exists():
225238
raise Exception("Generate the model zip before generating the config")
226-
hash_value = _generate_model_content_hash_value(str(self._folder_path / "model.zip"))
239+
hash_value = _generate_model_content_hash_value(
240+
str(self._folder_path / "model.zip")
241+
)
227242
if model_name is None:
228243
model_name = Path(self._hf_model_id).name
229244
if description is None:
@@ -235,12 +250,11 @@ def make_model_config_json(
235250
model_type = "bert"
236251
model_format = None
237252
if self._framework is not None:
238-
model_format = {
239-
'pt': 'TORCH_SCRIPT',
240-
'onnx': 'ONNX'
241-
}.get(self._framework)
253+
model_format = {"pt": "TORCH_SCRIPT", "onnx": "ONNX"}.get(self._framework)
242254
if model_format is None:
243-
raise Exception("Model format either not found or not supported. Zip the model before generating the config")
255+
raise Exception(
256+
"Model format either not found or not supported. Zip the model before generating the config"
257+
)
244258
model_config_content = {
245259
"name": model_name,
246260
"version": f"1.0.{version_number}",
@@ -253,15 +267,21 @@ def make_model_config_json(
253267
"embedding_dimension": 1,
254268
"framework_type": "huggingface_transformers",
255269
"all_config": all_config,
256-
}
270+
},
257271
}
258272
if verbose:
259273
print(json.dumps(model_config_content, indent=2))
260274
with open(self._folder_path / "config.json", "w") as f:
261275
json.dump(model_config_content, f)
262276
return self._folder_path / "config.json"
263277

264-
def upload(self, client: OpenSearch, framework: str = 'pt', model_group_id: str = "", verbose: bool = False):
278+
def upload(
279+
self,
280+
client: OpenSearch,
281+
framework: str = "pt",
282+
model_group_id: str = "",
283+
verbose: bool = False,
284+
):
265285
"""
266286
Upload the model to OpenSearch
267287
@@ -283,7 +303,6 @@ def upload(self, client: OpenSearch, framework: str = 'pt', model_group_id: str
283303
if not config_path.exists() or gen_cfg:
284304
self.make_model_config_json()
285305
uploader = ModelUploader(client)
286-
uploader._register_model(str(model_path), str(config_path), model_group_id, verbose)
287-
288-
289-
306+
uploader._register_model(
307+
str(model_path), str(config_path), model_group_id, verbose
308+
)

0 commit comments

Comments
 (0)