Skip to content

Commit 210e245

Browse files
committed
change framework name pt -> torch_script
Signed-off-by: HenryL27 <[email protected]>
1 parent 8b0d263 commit 210e245

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

opensearch_py_ml/ml_models/crossencodermodel.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,14 @@ def __init__(
8282
self._model_zip = None
8383
self._model_config = None
8484

85-
def zip_model(self, framework: str = "pt", zip_fname: str = "model.zip") -> Path:
85+
def zip_model(
86+
self, framework: str = "torch_script", zip_fname: str = "model.zip"
87+
) -> Path:
8688
"""
8789
Compiles and zips the model to {self._folder_path}/{zip_fname}
8890
89-
:param framework: one of "pt", "onnx". The framework to zip the model as.
90-
default: "pt"
91+
:param framework: one of "torch_script", "onnx". The framework to zip the model as.
92+
default: "torch_script"
9193
:type framework: str
9294
:param zip_fname: path to place resulting zip file inside of self._folder_path.
9395
Example: if folder_path is "/tmp/models" and zip_path is "zipped_up.zip" then
@@ -106,15 +108,15 @@ def zip_model(self, framework: str = "pt", zip_fname: str = "model.zip") -> Path
106108
if mname.startswith("bge"):
107109
features["token_type_ids"] = torch.zeros_like(features["input_ids"])
108110

109-
if framework == "pt":
110-
self._framework = "pt"
111+
if framework == "torch_script":
112+
self._framework = "torch_script"
111113
model_loc = CrossEncoderModel._trace_pytorch(model, features, mname)
112114
elif framework == "onnx":
113115
self._framework = "onnx"
114116
model_loc = CrossEncoderModel._trace_onnx(model, features, mname)
115117
else:
116118
raise Exception(
117-
f"Unrecognized framework {framework}. Accepted values are `pt`, `onnx`"
119+
f"Unrecognized framework {framework}. Accepted values are `torch_script`, `onnx`"
118120
)
119121

120122
# save tokenizer file
@@ -258,7 +260,9 @@ def make_model_config_json(
258260
model_type = "bert"
259261
model_format = None
260262
if self._framework is not None:
261-
model_format = {"pt": "TORCH_SCRIPT", "onnx": "ONNX"}.get(self._framework)
263+
model_format = {"torch_script": "TORCH_SCRIPT", "onnx": "ONNX"}.get(
264+
self._framework
265+
)
262266
if model_format is None:
263267
raise Exception(
264268
"Model format either not found or not supported. Zip the model before generating the config"
@@ -287,7 +291,7 @@ def make_model_config_json(
287291
def upload(
288292
self,
289293
client: OpenSearch,
290-
framework: str = "pt",
294+
framework: str = "torch_script",
291295
model_group_id: str = "",
292296
verbose: bool = False,
293297
):
@@ -296,7 +300,7 @@ def upload(
296300
297301
:param client: OpenSearch client
298302
:type client: OpenSearch
299-
:param framework: either 'pt' or 'onnx'
303+
:param framework: either 'torch_script' or 'onnx'
300304
:type framework: str
301305
:param model_group_id: model group id to upload this model to
302306
:type model_group_id: str

tests/ml_models/test_crossencodermodel_pytest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ def test_onnx_has_correct_files(tinybert):
7272

7373

7474
def test_can_pick_names_for_files(tinybert):
75-
zip_path = tinybert.zip_model(framework="onnx", zip_fname="funky-model-filename.pt")
75+
zip_path = tinybert.zip_model(
76+
framework="torch_script", zip_fname="funky-model-filename.pt"
77+
)
7678
config_path = tinybert.make_model_config_json(
7779
config_fname="funky-model-config.json"
7880
)

0 commit comments

Comments
 (0)