|
12 | 12 | from model_api.adapters.utils import load_parameters_from_onnx |
13 | 13 |
|
14 | 14 |
|
15 | | -def test_detector_save(tmp_path): |
16 | | - downloaded = Model.create_model( |
17 | | - "ssd_mobilenet_v1_fpn_coco", |
18 | | - configuration={"mean_values": [0, 0, 0], "confidence_threshold": 0.6}, |
| 15 | +def test_detector_save(tmp_path, data): |
| 16 | + detector = Model.create_model( |
| 17 | + Path(data) / "otx_models/detection_model_with_xai_head.xml", |
19 | 18 | ) |
20 | 19 | xml_path = str(tmp_path / "a.xml") |
21 | | - downloaded.save(xml_path) |
| 20 | + detector.save(xml_path) |
22 | 21 | deserialized = Model.create_model(xml_path) |
23 | 22 |
|
24 | 23 | assert ( |
25 | 24 | deserialized.get_model() |
26 | 25 | .get_rt_info(["model_info", "embedded_processing"]) |
27 | 26 | .astype(bool) |
28 | 27 | ) |
29 | | - assert type(downloaded) is type(deserialized) |
30 | | - for attr in downloaded.parameters(): |
31 | | - assert getattr(downloaded, attr) == getattr(deserialized, attr) |
| 28 | + assert type(detector) is type(deserialized) |
| 29 | + for attr in detector.parameters(): |
| 30 | + assert getattr(detector, attr) == getattr(deserialized, attr) |
32 | 31 |
|
33 | 32 |
|
34 | | -def test_classifier_save(tmp_path): |
35 | | - downloaded = Model.create_model( |
36 | | - "efficientnet-b0-pytorch", configuration={"scale_values": [1, 1, 1], "topk": 6} |
| 33 | +def test_classifier_save(tmp_path, data): |
| 34 | + classifier = Model.create_model( |
| 35 | + Path(data) / "otx_models/tinynet_imagenet.xml", |
37 | 36 | ) |
38 | 37 | xml_path = str(tmp_path / "a.xml") |
39 | | - downloaded.save(xml_path) |
| 38 | + classifier.save(xml_path) |
40 | 39 | deserialized = Model.create_model(xml_path) |
41 | 40 |
|
42 | 41 | assert ( |
43 | 42 | deserialized.get_model() |
44 | 43 | .get_rt_info(["model_info", "embedded_processing"]) |
45 | 44 | .astype(bool) |
46 | 45 | ) |
47 | | - assert type(downloaded) is type(deserialized) |
48 | | - for attr in downloaded.parameters(): |
49 | | - assert getattr(downloaded, attr) == getattr(deserialized, attr) |
| 46 | + assert type(classifier) is type(deserialized) |
| 47 | + for attr in classifier.parameters(): |
| 48 | + assert getattr(classifier, attr) == getattr(deserialized, attr) |
50 | 49 |
|
51 | 50 |
|
52 | | -def test_segmentor_save(tmp_path): |
53 | | - downloaded = Model.create_model( |
54 | | - "hrnet-v2-c1-segmentation", |
55 | | - configuration={"reverse_input_channels": True, "labels": ["first", "second"]}, |
| 51 | +def test_segmentor_save(tmp_path, data): |
| 52 | + segmenter = Model.create_model( |
| 53 | + Path(data) / "otx_models/Lite-hrnet-18_mod2.xml", |
56 | 54 | ) |
57 | 55 | xml_path = str(tmp_path / "a.xml") |
58 | | - downloaded.save(xml_path) |
| 56 | + segmenter.save(xml_path) |
59 | 57 | deserialized = Model.create_model(xml_path) |
60 | 58 |
|
61 | 59 | assert ( |
62 | 60 | deserialized.get_model() |
63 | 61 | .get_rt_info(["model_info", "embedded_processing"]) |
64 | 62 | .astype(bool) |
65 | 63 | ) |
66 | | - assert type(downloaded) is type(deserialized) |
67 | | - for attr in downloaded.parameters(): |
68 | | - assert getattr(downloaded, attr) == getattr(deserialized, attr) |
| 64 | + assert type(segmenter) is type(deserialized) |
| 65 | + for attr in segmenter.parameters(): |
| 66 | + assert getattr(segmenter, attr) == getattr(deserialized, attr) |
69 | 67 |
|
70 | 68 |
|
71 | 69 | def test_onnx_save(tmp_path, data): |
|
0 commit comments