88class TestPEFTSam (unittest .TestCase ):
99 model_type = "vit_b"
1010
11- def test_lora_sam (self ):
12- from micro_sam .models .peft_sam import PEFT_Sam , LoRASurgery
13-
14- _ , sam = util .get_sam_model (model_type = self .model_type , return_sam = True , device = "cpu" )
15- peft_sam = PEFT_Sam (sam , rank = 2 , peft_module = LoRASurgery )
16-
11+ def _check_output (self , peft_sam ):
1712 shape = (3 , 1024 , 1024 )
1813 expected_shape = (1 , 3 , 1024 , 1024 )
1914 with torch .no_grad ():
@@ -22,90 +17,54 @@ def test_lora_sam(self):
2217 masks = output [0 ]["masks" ]
2318 self .assertEqual (masks .shape , expected_shape )
2419
20+ def test_lora_sam (self ):
21+ from micro_sam .models .peft_sam import PEFT_Sam , LoRASurgery
22+
23+ _ , sam = util .get_sam_model (model_type = self .model_type , return_sam = True , device = "cpu" )
24+ peft_sam = PEFT_Sam (sam , rank = 2 , peft_module = LoRASurgery )
25+ self ._check_output (peft_sam )
26+
2527 def test_fact_sam (self ):
2628 from micro_sam .models .peft_sam import PEFT_Sam , FacTSurgery
2729
2830 _ , sam = util .get_sam_model (model_type = self .model_type , return_sam = True , device = "cpu" )
2931 peft_sam = PEFT_Sam (sam , rank = 2 , peft_module = FacTSurgery )
30-
31- shape = (3 , 1024 , 1024 )
32- expected_shape = (1 , 3 , 1024 , 1024 )
33- with torch .no_grad ():
34- batched_input = [{"image" : torch .rand (* shape ), "original_size" : shape [1 :]}]
35- output = peft_sam (batched_input , multimask_output = True )
36- masks = output [0 ]["masks" ]
37- self .assertEqual (masks .shape , expected_shape )
32+ self ._check_output (peft_sam )
3833
3934 def test_attention_layer_peft_sam (self ):
4035 from micro_sam .models .peft_sam import PEFT_Sam , AttentionSurgery
4136
4237 _ , sam = util .get_sam_model (model_type = self .model_type , return_sam = True , device = "cpu" )
4338 peft_sam = PEFT_Sam (sam , rank = 2 , peft_module = AttentionSurgery )
44-
45- shape = (3 , 1024 , 1024 )
46- expected_shape = (1 , 3 , 1024 , 1024 )
47- with torch .no_grad ():
48- batched_input = [{"image" : torch .rand (* shape ), "original_size" : shape [1 :]}]
49- output = peft_sam (batched_input , multimask_output = True )
50- masks = output [0 ]["masks" ]
51- self .assertEqual (masks .shape , expected_shape )
39+ self ._check_output (peft_sam )
5240
5341 def test_norm_layer_peft_sam (self ):
5442 from micro_sam .models .peft_sam import PEFT_Sam , LayerNormSurgery
5543
5644 _ , sam = util .get_sam_model (model_type = self .model_type , return_sam = True , device = "cpu" )
5745 peft_sam = PEFT_Sam (sam , rank = 2 , peft_module = LayerNormSurgery )
58-
59- shape = (3 , 1024 , 1024 )
60- expected_shape = (1 , 3 , 1024 , 1024 )
61- with torch .no_grad ():
62- batched_input = [{"image" : torch .rand (* shape ), "original_size" : shape [1 :]}]
63- output = peft_sam (batched_input , multimask_output = True )
64- masks = output [0 ]["masks" ]
65- self .assertEqual (masks .shape , expected_shape )
46+ self ._check_output (peft_sam )
6647
6748 def test_bias_layer_peft_sam (self ):
6849 from micro_sam .models .peft_sam import PEFT_Sam , BiasSurgery
6950
7051 _ , sam = util .get_sam_model (model_type = self .model_type , return_sam = True , device = "cpu" )
7152 peft_sam = PEFT_Sam (sam , rank = 2 , peft_module = BiasSurgery )
72-
73- shape = (3 , 1024 , 1024 )
74- expected_shape = (1 , 3 , 1024 , 1024 )
75- with torch .no_grad ():
76- batched_input = [{"image" : torch .rand (* shape ), "original_size" : shape [1 :]}]
77- output = peft_sam (batched_input , multimask_output = True )
78- masks = output [0 ]["masks" ]
79- self .assertEqual (masks .shape , expected_shape )
53+ self ._check_output (peft_sam )
8054
8155 def test_ssf_peft_sam (self ):
8256 from micro_sam .models .peft_sam import PEFT_Sam , SSFSurgery
8357
8458 _ , sam = util .get_sam_model (model_type = self .model_type , return_sam = True , device = "cpu" )
8559 peft_sam = PEFT_Sam (sam , rank = 2 , peft_module = SSFSurgery )
86-
87- shape = (3 , 1024 , 1024 )
88- expected_shape = (1 , 3 , 1024 , 1024 )
89- with torch .no_grad ():
90- batched_input = [{"image" : torch .rand (* shape ), "original_size" : shape [1 :]}]
91- output = peft_sam (batched_input , multimask_output = True )
92- masks = output [0 ]["masks" ]
93- self .assertEqual (masks .shape , expected_shape )
60+ self ._check_output (peft_sam )
9461
9562 def test_adaptformer_peft_sam (self ):
9663 from micro_sam .models .peft_sam import PEFT_Sam , AdaptFormer
9764
9865 _ , sam = util .get_sam_model (model_type = self .model_type , return_sam = True , device = "cpu" )
9966 peft_sam = PEFT_Sam (sam , rank = 2 , peft_module = AdaptFormer , projection_size = 64 , alpha = 2.0 , dropout = 0.5 )
100-
101-
102- shape = (3 , 1024 , 1024 )
103- expected_shape = (1 , 3 , 1024 , 1024 )
104- with torch .no_grad ():
105- batched_input = [{"image" : torch .rand (* shape ), "original_size" : shape [1 :]}]
106- output = peft_sam (batched_input , multimask_output = True )
107- masks = output [0 ]["masks" ]
108- self .assertEqual (masks .shape , expected_shape )
67+ self ._check_output (peft_sam )
10968
11069
11170if __name__ == "__main__" :
0 commit comments