Skip to content

Commit d68554e

Browse files
update mmar tests (#4091)
* update mmar tests Signed-off-by: Wenqi Li <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes pylint error Signed-off-by: Wenqi Li <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b6b2cfb commit d68554e

File tree

1 file changed

+27
-2
lines changed

1 file changed

+27
-2
lines changed

tests/ngc_mmar_loading.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
# limitations under the License.
1111

1212
import os
13+
import sys
1314
import unittest
1415

1516
import torch
1617
from parameterized import parameterized
1718

1819
from monai.apps.mmars import MODEL_DESC, load_from_mmar
1920
from monai.config import print_debug_info
21+
from monai.networks.utils import copy_model_state
2022

2123

2224
class TestAllDownloadingMMAR(unittest.TestCase):
@@ -26,10 +28,33 @@ def setUp(self):
2628

2729
@parameterized.expand((item,) for item in MODEL_DESC)
2830
def test_loading_mmar(self, item):
31+
if item["name"] == "clara_pt_self_supervised_learning_segmentation": # test the byow model
32+
default_model_file = os.path.join("ssl_models_2gpu", "best_metric_model.pt")
33+
pretrained_weights = load_from_mmar(
34+
item=item["name"],
35+
mmar_dir="./",
36+
map_location="cpu",
37+
api=True,
38+
model_file=default_model_file,
39+
weights_only=True,
40+
)
41+
pretrained_weights = {k.split(".", 1)[1]: v for k, v in pretrained_weights["state_dict"].items()}
42+
sys.path.append(os.path.join(f"{item['name']}", "custom")) # custom model folder
43+
from vit_network import ViTAutoEnc # pylint: disable=E0401
44+
45+
model = ViTAutoEnc(
46+
in_channels=1,
47+
img_size=(96, 96, 96),
48+
patch_size=(16, 16, 16),
49+
pos_embed="conv",
50+
hidden_size=768,
51+
mlp_dim=3072,
52+
)
53+
_, loaded, not_loaded = copy_model_state(model, pretrained_weights)
54+
self.assertTrue(len(loaded) > 0 and len(not_loaded) == 0)
55+
return
2956
if item["name"] == "clara_pt_fed_learning_brain_tumor_mri_segmentation":
3057
default_model_file = os.path.join("models", "server", "best_FL_global_model.pt")
31-
elif item["name"] == "clara_pt_self_supervised_learning_segmentation":
32-
default_model_file = os.path.join("models_2gpu", "best_metric_model.pt")
3358
else:
3459
default_model_file = None
3560
pretrained_model = load_from_mmar(

0 commit comments

Comments
 (0)