Skip to content

Commit 692fac1

Browse files
Fix tests if vit_t not available
1 parent 79e441a commit 692fac1

File tree

4 files changed

+13
-8
lines changed

4 files changed

+13
-8
lines changed

test/test_instance_segmentation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111

1212
class TestInstanceSegmentation(unittest.TestCase):
13+
model_type = "vit_t" if util.VIT_T_SUPPORT else "vit_b"
1314
embedding_path = "./tmp_embeddings.zarr"
1415
tile_shape = (512, 512)
1516
halo = (96, 96)
@@ -37,8 +38,8 @@ def write_object(center, radius):
3738
return mask, image
3839

3940
@staticmethod
40-
def _get_model(image):
41-
predictor = util.get_sam_model(model_type="vit_t")
41+
def _get_model(image, model_type):
42+
predictor = util.get_sam_model(model_type=model_type)
4243
image_embeddings = util.precompute_image_embeddings(predictor, image)
4344
return predictor, image_embeddings
4445

@@ -47,7 +48,7 @@ def _get_model(image):
4748
@classmethod
4849
def setUpClass(cls):
4950
cls.mask, cls.image = cls._get_input()
50-
cls.predictor, cls.image_embeddings = cls._get_model(cls.image)
51+
cls.predictor, cls.image_embeddings = cls._get_model(cls.image, cls.model_type)
5152
cls.large_mask, cls.large_image = cls._get_input(shape=(1024, 1024))
5253
cls.tiled_embeddings = util.precompute_image_embeddings(
5354
cls.predictor, cls.large_image, save_path=cls.embedding_path, tile_shape=cls.tile_shape, halo=cls.halo

test/test_prompt_based_segmentation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88

99
class TestPromptBasedSegmentation(unittest.TestCase):
10+
model_type = "vit_t" if util.VIT_T_SUPPORT else "vit_b"
11+
1012
@staticmethod
1113
def _get_input(shape=(256, 256)):
1214
mask = np.zeros(shape, dtype="uint8")
@@ -17,8 +19,8 @@ def _get_input(shape=(256, 256)):
1719
return mask, image
1820

1921
@staticmethod
20-
def _get_model(image):
21-
predictor = util.get_sam_model(model_type="vit_t")
22+
def _get_model(image, model_type):
23+
predictor = util.get_sam_model(model_type=model_type)
2224
image_embeddings = util.precompute_image_embeddings(predictor, image)
2325
util.set_precomputed(predictor, image_embeddings)
2426
return predictor
@@ -28,7 +30,7 @@ def _get_model(image):
2830
@classmethod
2931
def setUpClass(cls):
3032
cls.mask, cls.image = cls._get_input()
31-
cls.predictor = cls._get_model(cls.image)
33+
cls.predictor = cls._get_model(cls.image, cls.model_type)
3234

3335
def test_segment_from_points(self):
3436
from micro_sam.prompt_based_segmentation import segment_from_points

test/test_training.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
import torch_em
1010

1111
from micro_sam.sample_data import synthetic_data
12+
from micro_sam.util import VIT_T_SUPPORT
1213

1314

15+
@unittest.skipUnless(VIT_T_SUPPORT, "Integration test is only run with vit_t support, otherwise it takes too long.")
1416
class TestTraining(unittest.TestCase):
1517
"""Integration test for training a SAM model.
1618
"""

test/test_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ def test_compute_iou(self):
3434
self.assertTrue(0.0 < compute_iou(x1, x2) < 1.0)
3535

3636
def test_tiled_prediction(self):
37-
from micro_sam.util import precompute_image_embeddings, get_sam_model
37+
from micro_sam.util import precompute_image_embeddings, get_sam_model, VIT_T_SUPPORT
3838

39-
predictor = get_sam_model(model_type="vit_t")
39+
predictor = get_sam_model(model_type="vit_t" if VIT_T_SUPPORT else "vit_b")
4040

4141
tile_shape, halo = (256, 256), (16, 16)
4242
input_ = np.random.rand(512, 512).astype("float32")

0 commit comments

Comments
 (0)