Skip to content

Commit b2bc31e

Browse files
JebqJean-Baptiste Oger
authored andcommitted
test: update tests
Signed-off-by: Jean-Baptiste Oger <[email protected]>
1 parent ee7b3d1 commit b2bc31e

File tree

2 files changed

+111
-16
lines changed

2 files changed

+111
-16
lines changed

opensearch_py_ml/ml_models/sentencetransformermodel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,7 @@ def save_as_pt(
780780
Required, for example sentences = ['today is sunny']
781781
:type sentences: List of string [str]
782782
:param model_id:
783-
sentence transformer model id to download model from sentence transformers.
783+
Optional, sentence transformer model id to download model from sentence transformers.
784784
default model_id = "sentence-transformers/msmarco-distilbert-base-tas-b"
785785
:type model_id: string
786786
:param model_name:
@@ -892,7 +892,7 @@ def save_as_onnx(
892892
zip the model file and its tokenizer.json file to prepare to upload to the Open Search cluster
893893
894894
:param model_id:
895-
sentence transformer model id to download model from sentence transformers.
895+
Optional, sentence transformer model id to download model from sentence transformers.
896896
default model_id = "sentence-transformers/msmarco-distilbert-base-tas-b"
897897
:type model_id: string
898898
:param model_name:

tests/ml_models/test_sentencetransformermodel_pytest.py

Lines changed: 109 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def test_make_model_config_json_for_torch_script():
251251
model_id=model_id,
252252
)
253253

254-
test_model5.save_as_pt(model_id=model_id, sentences=["today is sunny"])
254+
test_model5.save_as_pt(sentences=["today is sunny"])
255255
model_config_path_torch = test_model5.make_model_config_json(
256256
model_format="TORCH_SCRIPT", verbose=True
257257
)
@@ -267,6 +267,38 @@ def test_make_model_config_json_for_torch_script():
267267
clean_test_folder(TEST_FOLDER)
268268

269269

270+
def test_make_model_config_json_set_path_for_torch_script():
271+
model_id = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
272+
model_format = "TORCH_SCRIPT"
273+
expected_model_description = "This is a sentence-transformers model: It maps sentences & paragraphs to a 384 dimensional dense vector space and was designed for semantic search. It has been trained on 215M pairs from diverse sources."
274+
expected_model_config_data = {
275+
"embedding_dimension": 384,
276+
"pooling_mode": "MEAN",
277+
"normalize_result": True,
278+
}
279+
280+
clean_test_folder(TEST_FOLDER)
281+
test_model5 = SentenceTransformerModel(
282+
folder_path=TEST_FOLDER,
283+
model_id=model_id,
284+
)
285+
286+
test_model5.save_as_pt(sentences=["today is sunny"])
287+
model_config_path_torch = test_model5.make_model_config_json(
288+
config_output_path=TEST_FOLDER, model_format="TORCH_SCRIPT", verbose=True
289+
)
290+
291+
compare_model_config(
292+
model_config_path_torch,
293+
model_id,
294+
model_format,
295+
expected_model_description=expected_model_description,
296+
expected_model_config_data=expected_model_config_data,
297+
)
298+
299+
clean_test_folder(TEST_FOLDER)
300+
301+
270302
def test_make_model_config_json_for_onnx():
271303
model_id = "sentence-transformers/paraphrase-MiniLM-L3-v2"
272304
model_format = "ONNX"
@@ -283,7 +315,7 @@ def test_make_model_config_json_for_onnx():
283315
model_id=model_id,
284316
)
285317

286-
test_model6.save_as_onnx(model_id=model_id)
318+
test_model6.save_as_onnx()
287319
model_config_path_onnx = test_model6.make_model_config_json(model_format="ONNX")
288320

289321
compare_model_config(
@@ -297,6 +329,38 @@ def test_make_model_config_json_for_onnx():
297329
clean_test_folder(TEST_FOLDER)
298330

299331

332+
def test_make_model_config_json_set_path_for_onnx():
333+
model_id = "sentence-transformers/paraphrase-MiniLM-L3-v2"
334+
model_format = "ONNX"
335+
expected_model_description = "This is a sentence-transformers model: It maps sentences & paragraphs to a 384 dimensional dense vector space and can be used for tasks like clustering or semantic search."
336+
expected_model_config_data = {
337+
"embedding_dimension": 384,
338+
"pooling_mode": "MEAN",
339+
"normalize_result": False,
340+
}
341+
342+
clean_test_folder(TEST_FOLDER)
343+
test_model6 = SentenceTransformerModel(
344+
folder_path=TEST_FOLDER,
345+
model_id=model_id,
346+
)
347+
348+
test_model6.save_as_onnx()
349+
model_config_path_onnx = test_model6.make_model_config_json(
350+
config_output_path=TEST_FOLDER, model_format="ONNX"
351+
)
352+
353+
compare_model_config(
354+
model_config_path_onnx,
355+
model_id,
356+
model_format,
357+
expected_model_description=expected_model_description,
358+
expected_model_config_data=expected_model_config_data,
359+
)
360+
361+
clean_test_folder(TEST_FOLDER)
362+
363+
300364
def test_overwrite_fields_in_model_config():
301365
model_id = "sentence-transformers/all-distilroberta-v1"
302366
model_format = "TORCH_SCRIPT"
@@ -318,7 +382,7 @@ def test_overwrite_fields_in_model_config():
318382
model_id=model_id,
319383
)
320384

321-
test_model7.save_as_pt(model_id=model_id, sentences=["today is sunny"])
385+
test_model7.save_as_pt(sentences=["today is sunny"])
322386
model_config_path_torch = test_model7.make_model_config_json(
323387
model_format="TORCH_SCRIPT"
324388
)
@@ -337,7 +401,7 @@ def test_overwrite_fields_in_model_config():
337401
model_id=model_id,
338402
)
339403

340-
test_model8.save_as_pt(model_id=model_id, sentences=["today is sunny"])
404+
test_model8.save_as_pt(sentences=["today is sunny"])
341405
model_config_path_torch = test_model8.make_model_config_json(
342406
model_format="TORCH_SCRIPT",
343407
embedding_dimension=overwritten_model_config_data["embedding_dimension"],
@@ -367,7 +431,7 @@ def test_missing_readme_md_file():
367431
model_id=model_id,
368432
)
369433

370-
test_model9.save_as_pt(model_id=model_id, sentences=["today is sunny"])
434+
test_model9.save_as_pt(sentences=["today is sunny"])
371435
temp_path = os.path.join(
372436
TEST_FOLDER,
373437
"README.md",
@@ -403,7 +467,7 @@ def test_missing_expected_description_in_readme_file():
403467
model_id=model_id,
404468
)
405469

406-
test_model10.save_as_pt(model_id=model_id, sentences=["today is sunny"])
470+
test_model10.save_as_pt(sentences=["today is sunny"])
407471
temp_path = os.path.join(
408472
TEST_FOLDER,
409473
"README.md",
@@ -440,7 +504,7 @@ def test_overwrite_description():
440504
model_id=model_id,
441505
)
442506

443-
test_model11.save_as_pt(model_id=model_id, sentences=["today is sunny"])
507+
test_model11.save_as_pt(sentences=["today is sunny"])
444508
model_config_path_torch = test_model11.make_model_config_json(
445509
model_format=model_format, description=expected_model_description
446510
)
@@ -471,7 +535,7 @@ def test_long_description():
471535
model_id=model_id,
472536
)
473537

474-
test_model12.save_as_pt(model_id=model_id, sentences=["today is sunny"])
538+
test_model12.save_as_pt(sentences=["today is sunny"])
475539
model_config_path_torch = test_model12.make_model_config_json(
476540
model_format=model_format
477541
)
@@ -501,7 +565,7 @@ def test_truncation_parameter():
501565
model_id=model_id,
502566
)
503567

504-
test_model13.save_as_pt(model_id=model_id, sentences=["today is sunny"])
568+
test_model13.save_as_pt(sentences=["today is sunny"])
505569

506570
tokenizer_json_file_path = os.path.join(TEST_FOLDER, "tokenizer.json")
507571
try:
@@ -534,7 +598,7 @@ def test_undefined_model_max_length_in_tokenizer_for_torch_script():
534598
model_id=model_id,
535599
)
536600

537-
test_model14.save_as_pt(model_id=model_id, sentences=["today is sunny"])
601+
test_model14.save_as_pt(sentences=["today is sunny"])
538602

539603
tokenizer_json_file_path = os.path.join(TEST_FOLDER, "tokenizer.json")
540604
try:
@@ -563,7 +627,7 @@ def test_undefined_model_max_length_in_tokenizer_for_onnx():
563627
model_id=model_id,
564628
)
565629

566-
test_model14.save_as_onnx(model_id=model_id)
630+
test_model14.save_as_onnx()
567631

568632
tokenizer_json_file_path = os.path.join(TEST_FOLDER, "tokenizer.json")
569633
try:
@@ -598,7 +662,6 @@ def test_save_as_pt_with_license():
598662
)
599663

600664
test_model15.save_as_pt(
601-
model_id=model_id,
602665
sentences=["today is sunny"],
603666
add_apache_license=True,
604667
)
@@ -622,7 +685,7 @@ def test_save_as_onnx_with_license():
622685
model_id=model_id,
623686
)
624687

625-
test_model16.save_as_onnx(model_id=model_id, add_apache_license=True)
688+
test_model16.save_as_onnx(add_apache_license=True)
626689

627690
compare_model_zip_file(onnx_zip_file_path, onnx_expected_filenames, model_format)
628691

@@ -649,7 +712,7 @@ def test_zip_model_with_license():
649712
model_id=model_id,
650713
)
651714

652-
test_model17.save_as_pt(model_id=model_id, sentences=["today is sunny"])
715+
test_model17.save_as_pt(sentences=["today is sunny"])
653716
compare_model_zip_file(zip_file_path, expected_filenames_wo_license, model_format)
654717

655718
test_model17.zip_model(add_apache_license=True)
@@ -658,5 +721,37 @@ def test_zip_model_with_license():
658721
clean_test_folder(TEST_FOLDER)
659722

660723

724+
def test_save_as_pt_model_with_different_id():
725+
model_id = "sentence-transformers/msmarco-distilbert-base-tas-b"
726+
model_id2 = "sentence-transformers/all-MiniLM-L6-v2"
727+
model_format = "TORCH_SCRIPT"
728+
zip_file_path = os.path.join(TEST_FOLDER, "msmarco-distilbert-base-tas-b.zip")
729+
zip_file_path2 = os.path.join(TEST_FOLDER, "all-MiniLM-L6-v2")
730+
expected_filenames_wo_model_id = {
731+
"msmarco-distilbert-base-tas-b.pt",
732+
"tokenizer.json",
733+
}
734+
expected_filenames_with_model_id = {
735+
"msmarco-distilbert-base-tas-b.pt",
736+
"tokenizer.json",
737+
}
738+
739+
clean_test_folder(TEST_FOLDER)
740+
test_model17 = SentenceTransformerModel(
741+
folder_path=TEST_FOLDER,
742+
model_id=model_id,
743+
)
744+
745+
test_model17.save_as_pt(sentences=["today is sunny"])
746+
compare_model_zip_file(zip_file_path, expected_filenames_wo_model_id, model_format)
747+
748+
test_model17.save_as_pt(model_id=model_id2, sentences=["today is sunny"])
749+
compare_model_zip_file(
750+
zip_file_path2, expected_filenames_with_model_id, model_format
751+
)
752+
753+
clean_test_folder(TEST_FOLDER)
754+
755+
661756
clean_test_folder(TEST_FOLDER)
662757
clean_test_folder(TESTDATA_UNZIP_FOLDER)

0 commit comments

Comments
 (0)