Skip to content

Commit 86f007a

Browse files
Refactor peft test (#810)
Refactor PEFT test
1 parent 98c4b07 commit 86f007a

File tree

1 file changed

+14
-55
lines changed

1 file changed

+14
-55
lines changed

test/test_models/test_peft_sam.py

Lines changed: 14 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,7 @@
88
class TestPEFTSam(unittest.TestCase):
99
model_type = "vit_b"
1010

11-
def test_lora_sam(self):
12-
from micro_sam.models.peft_sam import PEFT_Sam, LoRASurgery
13-
14-
_, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
15-
peft_sam = PEFT_Sam(sam, rank=2, peft_module=LoRASurgery)
16-
11+
def _check_output(self, peft_sam, shape=(3, 1024, 1024)):
    """Run a forward pass through a PEFT-wrapped SAM model and verify the mask shape.

    Args:
        peft_sam: The PEFT-wrapped SAM model to exercise. Must accept SAM's
            batched-input format and a ``multimask_output`` flag.
        shape: Input image shape as ``(channels, height, width)``. Defaults to
            the standard SAM input of a 3-channel 1024x1024 image.
    """
    # Batch of 1; the leading 3 in the mask shape is SAM's fixed number of
    # multimask outputs (multimask_output=True), NOT the input channel count,
    # so only the spatial dims follow the input shape.
    expected_shape = (1, 3) + shape[1:]
    # Shape check only — no gradients needed for inference.
    with torch.no_grad():
        batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
        output = peft_sam(batched_input, multimask_output=True)
        masks = output[0]["masks"]
    self.assertEqual(masks.shape, expected_shape)
2419

20+
def test_lora_sam(self):
    """The LoRA-adapted SAM model must still produce correctly shaped masks."""
    from micro_sam.models.peft_sam import PEFT_Sam, LoRASurgery

    # Fresh vit_b SAM on CPU, wrapped with rank-2 LoRA adapters.
    _, base_sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
    self._check_output(PEFT_Sam(base_sam, rank=2, peft_module=LoRASurgery))
26+
2527
def test_fact_sam(self):
    """The FacT-adapted SAM model must still produce correctly shaped masks."""
    from micro_sam.models.peft_sam import PEFT_Sam, FacTSurgery

    # Fresh vit_b SAM on CPU, wrapped with rank-2 FacT adapters.
    _, base_sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
    self._check_output(PEFT_Sam(base_sam, rank=2, peft_module=FacTSurgery))
3833

3934
def test_attention_layer_peft_sam(self):
    """Attention-layer surgery must leave SAM's output mask shape intact."""
    from micro_sam.models.peft_sam import PEFT_Sam, AttentionSurgery

    # Fresh vit_b SAM on CPU with attention-layer adaptation applied.
    _, base_sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
    adapted = PEFT_Sam(base_sam, rank=2, peft_module=AttentionSurgery)
    self._check_output(adapted)
5240

5341
def test_norm_layer_peft_sam(self):
    """LayerNorm surgery must leave SAM's output mask shape intact."""
    from micro_sam.models.peft_sam import PEFT_Sam, LayerNormSurgery

    # Fresh vit_b SAM on CPU with norm-layer adaptation applied.
    _, base_sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
    adapted = PEFT_Sam(base_sam, rank=2, peft_module=LayerNormSurgery)
    self._check_output(adapted)
6647

6748
def test_bias_layer_peft_sam(self):
    """Bias-tuning surgery must leave SAM's output mask shape intact."""
    from micro_sam.models.peft_sam import PEFT_Sam, BiasSurgery

    # Fresh vit_b SAM on CPU with bias-only adaptation applied.
    _, base_sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
    self._check_output(PEFT_Sam(base_sam, rank=2, peft_module=BiasSurgery))
8054

8155
def test_ssf_peft_sam(self):
    """SSF (scale-shift) surgery must leave SAM's output mask shape intact."""
    from micro_sam.models.peft_sam import PEFT_Sam, SSFSurgery

    # Fresh vit_b SAM on CPU with scale-and-shift adaptation applied.
    _, base_sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
    self._check_output(PEFT_Sam(base_sam, rank=2, peft_module=SSFSurgery))
9461

9562
def test_adaptformer_peft_sam(self):
    """AdaptFormer surgery must leave SAM's output mask shape intact."""
    from micro_sam.models.peft_sam import PEFT_Sam, AdaptFormer

    # Fresh vit_b SAM on CPU; AdaptFormer takes extra adapter hyperparameters
    # (projection size, scaling alpha, dropout) on top of the rank.
    _, base_sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
    adapted = PEFT_Sam(
        base_sam, rank=2, peft_module=AdaptFormer, projection_size=64, alpha=2.0, dropout=0.5,
    )
    self._check_output(adapted)
10968

11069

11170
if __name__ == "__main__":

0 commit comments

Comments
 (0)