Skip to content

Commit f6ae513

Browse files
JebqJean-Baptiste Oger
authored andcommitted
refactor: use self.model_id where applicable
Signed-off-by: Jean-Baptiste Oger <[email protected]>
1 parent e0b1bcf commit f6ae513

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

opensearch_py_ml/ml_models/sentencetransformermodel.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ def train_model(
419419
"""
420420

421421
if model_id is None:
422-
model_id = "sentence-transformers/msmarco-distilbert-base-tas-b"
422+
model_id = self.model_id
423423
if output_model_name is None:
424424
output_model_name = str(self.model_id.split("/")[-1] + ".pt")
425425

@@ -765,7 +765,7 @@ def _fill_null_truncation_field(
765765
def save_as_pt(
766766
self,
767767
sentences: [str],
768-
model_id="sentence-transformers/msmarco-distilbert-base-tas-b",
768+
model_id: str = None,
769769
model_name: str = None,
770770
save_json_folder_path: str = None,
771771
model_output_path: str = None,
@@ -806,6 +806,9 @@ def save_as_pt(
806806
:rtype: string
807807
"""
808808

809+
if model_id is None:
810+
model_id = self.model_id
811+
809812
model = SentenceTransformer(model_id)
810813

811814
if model_name is None:
@@ -877,7 +880,7 @@ def save_as_pt(
877880

878881
def save_as_onnx(
879882
self,
880-
model_id="sentence-transformers/msmarco-distilbert-base-tas-b",
883+
model_id: str = None,
881884
model_name: str = None,
882885
save_json_folder_path: str = None,
883886
model_output_path: str = None,
@@ -915,6 +918,9 @@ def save_as_onnx(
915918
:rtype: string
916919
"""
917920

921+
if model_id is None:
922+
model_id = self.model_id
923+
918924
model = SentenceTransformer(model_id)
919925

920926
if model_name is None:
@@ -1143,6 +1149,7 @@ def make_model_config_json(
11431149
model_name: str = None,
11441150
version_number: str = 1,
11451151
model_format: str = "TORCH_SCRIPT",
1152+
config_out_path: str = None,
11461153
model_zip_file_path: str = None,
11471154
embedding_dimension: int = None,
11481155
pooling_mode: str = None,
@@ -1163,6 +1170,10 @@ def make_model_config_json(
11631170
:param model_format:
11641171
Optional, the format of the model. Default is "TORCH_SCRIPT".
11651172
:type model_format: string
1173+
:param config_output_path:
1174+
Optional, path to save model config json file. If None, default as
1175+
default_folder_path from the constructor
1176+
:type config_output_path: string
11661177
:param model_zip_file_path:
11671178
Optional, path to the model zip file. Default is the zip file path used in save_as_pt or save_as_onnx
11681179
depending on model_format. This zip file is used to compute model_content_size_in_bytes and
@@ -1198,6 +1209,8 @@ def make_model_config_json(
11981209
:rtype: string
11991210
"""
12001211
folder_path = self.folder_path
1212+
if config_output_path is None:
1213+
config_output_path = self.folder_path
12011214
config_json_file_path = os.path.join(folder_path, "config.json")
12021215
if model_name is None:
12031216
model_name = self.model_id
@@ -1313,7 +1326,7 @@ def make_model_config_json(
13131326
print(json.dumps(model_config_content, indent=4))
13141327

13151328
model_config_file_path = os.path.join(
1316-
folder_path, "ml-commons_model_config.json"
1329+
config_output_path, "ml-commons_model_config.json"
13171330
)
13181331
os.makedirs(os.path.dirname(model_config_file_path), exist_ok=True)
13191332
with open(model_config_file_path, "w") as file:

0 commit comments

Comments
 (0)