Skip to content

Commit 077e7e8

Browse files
committed
add simple unit tests for model saving
Signed-off-by: HenryL27 <[email protected]>
1 parent 61af8ca commit 077e7e8

File tree

1 file changed

+94
-0
lines changed

1 file changed

+94
-0
lines changed
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# The OpenSearch Contributors require contributions made to
3+
# this file be licensed under the Apache-2.0 license or a
4+
# compatible open source license.
5+
# Any modifications Copyright OpenSearch Contributors. See
6+
# GitHub history for details.
7+
import shutil
8+
from pathlib import Path
9+
10+
import pytest
11+
12+
from opensearch_py_ml.ml_models import CrossEncoderModel
13+
from tests.ml_models.test_sentencetransformermodel_pytest import (
14+
compare_model_config,
15+
compare_model_zip_file,
16+
)
17+
18+
# Folder path built from this module's own file path plus "tests/test_model_files".
# NOTE(review): Path(__file__) points at this .py file, not at a directory, so
# the resulting path sits *beneath a file* and cannot exist on disk. The
# intended base is presumably a parent directory (e.g. `.parent` / `.parents[n]`)
# -- confirm before relying on it. It is not referenced by the tests visible in
# this file.
TEST_FOLDER = Path(__file__).joinpath("tests", "test_model_files")
19+
20+
21+
@pytest.fixture(scope="function")
def tinybert() -> CrossEncoderModel:
    """Yield a TinyBERT cross-encoder and delete its model folder afterwards.

    A new ``CrossEncoderModel`` is built for every test (function scope). After
    the test finishes, the folder under ``/tmp/models`` is removed on a
    best-effort basis.

    The cleanup path is derived from the same model id that is passed to the
    constructor. It used to be hard-coded with the spelling "TinyBert" instead
    of "TinyBERT", so on a case-sensitive filesystem ``rmtree`` never matched
    the real folder and ``ignore_errors=True`` silently hid the miss.
    """
    model_id = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
    model = CrossEncoderModel(model_id)
    yield model
    # presumably CrossEncoderModel writes to /tmp/models/<model id> when no
    # folder is given (the original teardown targeted that location) --
    # confirm against the class if its default folder ever changes.
    shutil.rmtree(f"/tmp/models/{model_id}", ignore_errors=True)
28+
29+
30+
def test_pt_has_correct_files(tinybert):
    """Zip the model with default arguments and check the zip and the config.

    Calls ``zip_model()`` and ``make_model_config_json()`` on the fixture, then
    hands the two returned paths to the shared comparison helpers together with
    the TorchScript expectations listed below.
    """
    archive_path = tinybert.zip_model()
    config_json_path = tinybert.make_model_config_json()

    # NOTE(review): the expected names are passed as a list -- confirm that
    # compare_model_zip_file does not compare them against a set, since a
    # set never equals a list.
    expected_files = ["ms-marco-TinyBERT-L-2-v2.pt", "tokenizer.json", "LICENSE"]
    expected_description = {
        "model_type": "bert",
        "embedding_dimension": 1,
        "framework_type": "huggingface_transformers",
    }

    compare_model_zip_file(
        zip_file_path=archive_path,
        expected_filenames=expected_files,
        model_format="TORCH_SCRIPT",
    )
    compare_model_config(
        model_config_path=config_json_path,
        model_id="cross-encoder/ms-marco-TinyBERT-L-2-v2",
        model_format="TORCH_SCRIPT",
        expected_model_description=expected_description,
    )
48+
49+
50+
def test_onnx_has_correct_files(tinybert):
    """Zip the model with ``framework="onnx"`` and check the zip and the config.

    Calls ``zip_model(framework="onnx")`` and ``make_model_config_json()`` on
    the fixture, then hands the two returned paths to the shared comparison
    helpers together with the ONNX expectations listed below.
    """
    archive_path = tinybert.zip_model(framework="onnx")
    config_json_path = tinybert.make_model_config_json()

    # NOTE(review): the expected names are passed as a list -- confirm that
    # compare_model_zip_file does not compare them against a set, since a
    # set never equals a list.
    expected_files = [
        "ms-marco-TinyBERT-L-2-v2.onnx",
        "tokenizer.json",
        "LICENSE",
    ]
    expected_description = {
        "model_type": "bert",
        "embedding_dimension": 1,
        "framework_type": "huggingface_transformers",
    }

    compare_model_zip_file(
        zip_file_path=archive_path,
        expected_filenames=expected_files,
        model_format="ONNX",
    )
    compare_model_config(
        model_config_path=config_json_path,
        model_id="cross-encoder/ms-marco-TinyBERT-L-2-v2",
        model_format="ONNX",
        expected_model_description=expected_description,
    )
72+
73+
74+
def test_can_pick_names_for_files(tinybert):
    """Check that caller-chosen file names are used for the zip and the config.

    Passes custom ``zip_fname`` / ``config_fname`` values to the fixture's
    ``zip_model`` and ``make_model_config_json`` methods, runs the shared
    comparison helpers with TorchScript expectations, and asserts that the
    returned config path ends with the requested config file name.
    """
    # Every expectation below is TorchScript (".pt" name, "TORCH_SCRIPT"
    # format), so export with the default framework, as
    # test_pt_has_correct_files does for the same expectations. This call
    # previously passed framework="onnx", which contradicted those checks.
    zip_path = tinybert.zip_model(zip_fname="funky-model-filename.pt")
    config_path = tinybert.make_model_config_json(
        config_fname="funky-model-config.json"
    )

    # NOTE(review): zip_fname presumably names the archive itself; confirm
    # that a member called "funky-model-filename.pt" is really expected
    # *inside* the zip. Also confirm compare_model_zip_file accepts a list
    # rather than requiring a set.
    compare_model_zip_file(
        zip_file_path=zip_path,
        expected_filenames=["funky-model-filename.pt", "tokenizer.json", "LICENSE"],
        model_format="TORCH_SCRIPT",
    )
    compare_model_config(
        model_config_path=config_path,
        model_id="cross-encoder/ms-marco-TinyBERT-L-2-v2",
        model_format="TORCH_SCRIPT",
        expected_model_description={
            "model_type": "bert",
            "embedding_dimension": 1,
            "framework_type": "huggingface_transformers",
        },
    )
    assert config_path.endswith("funky-model-config.json")

0 commit comments

Comments
 (0)