@@ -419,7 +419,7 @@ def train_model(
419419 """
420420
421421 if model_id is None :
422- model_id = "sentence-transformers/msmarco-distilbert-base-tas-b"
422+ model_id = self . model_id
423423 if output_model_name is None :
424424 output_model_name = str (self .model_id .split ("/" )[- 1 ] + ".pt" )
425425
@@ -765,7 +765,7 @@ def _fill_null_truncation_field(
765765 def save_as_pt (
766766 self ,
767767 sentences : [str ],
768- model_id = "sentence-transformers/msmarco-distilbert-base-tas-b" ,
768+ model_id : str = None ,
769769 model_name : str = None ,
770770 save_json_folder_path : str = None ,
771771 model_output_path : str = None ,
@@ -806,6 +806,9 @@ def save_as_pt(
806806 :rtype: string
807807 """
808808
809+ if model_id is None :
810+ model_id = self .model_id
811+
809812 model = SentenceTransformer (model_id )
810813
811814 if model_name is None :
@@ -877,7 +880,7 @@ def save_as_pt(
877880
878881 def save_as_onnx (
879882 self ,
880- model_id = "sentence-transformers/msmarco-distilbert-base-tas-b" ,
883+ model_id : str = None ,
881884 model_name : str = None ,
882885 save_json_folder_path : str = None ,
883886 model_output_path : str = None ,
@@ -915,6 +918,9 @@ def save_as_onnx(
915918 :rtype: string
916919 """
917920
921+ if model_id is None :
922+ model_id = self .model_id
923+
918924 model = SentenceTransformer (model_id )
919925
920926 if model_name is None :
@@ -1143,6 +1149,7 @@ def make_model_config_json(
11431149 model_name : str = None ,
11441150 version_number : str = 1 ,
11451151 model_format : str = "TORCH_SCRIPT" ,
1152+ config_out_path : str = None ,
11461153 model_zip_file_path : str = None ,
11471154 embedding_dimension : int = None ,
11481155 pooling_mode : str = None ,
@@ -1163,6 +1170,10 @@ def make_model_config_json(
11631170 :param model_format:
11641171 Optional, the format of the model. Default is "TORCH_SCRIPT".
11651172 :type model_format: string
1173+ :param config_output_path:
1174+ Optional, path to save model config json file. If None, default as
1175+ default_folder_path from the constructor
1176+ :type config_output_path: string
11661177 :param model_zip_file_path:
11671178 Optional, path to the model zip file. Default is the zip file path used in save_as_pt or save_as_onnx
11681179 depending on model_format. This zip file is used to compute model_content_size_in_bytes and
@@ -1198,6 +1209,8 @@ def make_model_config_json(
11981209 :rtype: string
11991210 """
12001211 folder_path = self .folder_path
1212+ if config_output_path is None :
1213+ config_output_path = self .folder_path
12011214 config_json_file_path = os .path .join (folder_path , "config.json" )
12021215 if model_name is None :
12031216 model_name = self .model_id
@@ -1313,7 +1326,7 @@ def make_model_config_json(
13131326 print (json .dumps (model_config_content , indent = 4 ))
13141327
13151328 model_config_file_path = os .path .join (
1316- folder_path , "ml-commons_model_config.json"
1329+ config_output_path , "ml-commons_model_config.json"
13171330 )
13181331 os .makedirs (os .path .dirname (model_config_file_path ), exist_ok = True )
13191332 with open (model_config_file_path , "w" ) as file :
0 commit comments