Skip to content

Commit 94d555b

Browse files
committed
Fix funtional tests
1 parent f1a9157 commit 94d555b

File tree

1 file changed

+21
-23
lines changed

1 file changed

+21
-23
lines changed

tests/python/functional/test_save.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,60 +12,58 @@
1212
from model_api.adapters.utils import load_parameters_from_onnx
1313

1414

15-
def test_detector_save(tmp_path):
16-
downloaded = Model.create_model(
17-
"ssd_mobilenet_v1_fpn_coco",
18-
configuration={"mean_values": [0, 0, 0], "confidence_threshold": 0.6},
15+
def test_detector_save(tmp_path, data):
16+
detector = Model.create_model(
17+
Path(data) / "otx_models/detection_model_with_xai_head.xml",
1918
)
2019
xml_path = str(tmp_path / "a.xml")
21-
downloaded.save(xml_path)
20+
detector.save(xml_path)
2221
deserialized = Model.create_model(xml_path)
2322

2423
assert (
2524
deserialized.get_model()
2625
.get_rt_info(["model_info", "embedded_processing"])
2726
.astype(bool)
2827
)
29-
assert type(downloaded) is type(deserialized)
30-
for attr in downloaded.parameters():
31-
assert getattr(downloaded, attr) == getattr(deserialized, attr)
28+
assert type(detector) is type(deserialized)
29+
for attr in detector.parameters():
30+
assert getattr(detector, attr) == getattr(deserialized, attr)
3231

3332

34-
def test_classifier_save(tmp_path):
35-
downloaded = Model.create_model(
36-
"efficientnet-b0-pytorch", configuration={"scale_values": [1, 1, 1], "topk": 6}
33+
def test_classifier_save(tmp_path, data):
34+
classifier = Model.create_model(
35+
Path(data) / "otx_models/tinynet_imagenet.xml",
3736
)
3837
xml_path = str(tmp_path / "a.xml")
39-
downloaded.save(xml_path)
38+
classifier.save(xml_path)
4039
deserialized = Model.create_model(xml_path)
4140

4241
assert (
4342
deserialized.get_model()
4443
.get_rt_info(["model_info", "embedded_processing"])
4544
.astype(bool)
4645
)
47-
assert type(downloaded) is type(deserialized)
48-
for attr in downloaded.parameters():
49-
assert getattr(downloaded, attr) == getattr(deserialized, attr)
46+
assert type(classifier) is type(deserialized)
47+
for attr in classifier.parameters():
48+
assert getattr(classifier, attr) == getattr(deserialized, attr)
5049

5150

52-
def test_segmentor_save(tmp_path):
53-
downloaded = Model.create_model(
54-
"hrnet-v2-c1-segmentation",
55-
configuration={"reverse_input_channels": True, "labels": ["first", "second"]},
51+
def test_segmentor_save(tmp_path, data):
52+
segmenter = Model.create_model(
53+
Path(data) / "otx_models/Lite-hrnet-18_mod2.xml",
5654
)
5755
xml_path = str(tmp_path / "a.xml")
58-
downloaded.save(xml_path)
56+
segmenter.save(xml_path)
5957
deserialized = Model.create_model(xml_path)
6058

6159
assert (
6260
deserialized.get_model()
6361
.get_rt_info(["model_info", "embedded_processing"])
6462
.astype(bool)
6563
)
66-
assert type(downloaded) is type(deserialized)
67-
for attr in downloaded.parameters():
68-
assert getattr(downloaded, attr) == getattr(deserialized, attr)
64+
assert type(segmenter) is type(deserialized)
65+
for attr in segmenter.parameters():
66+
assert getattr(segmenter, attr) == getattr(deserialized, attr)
6967

7068

7169
def test_onnx_save(tmp_path, data):

0 commit comments

Comments
 (0)