Skip to content

Commit 45a4d31

Browse files
committed
remove SAM 1 (#62)
1 parent 08b500e commit 45a4d31

File tree

1 file changed

+1
-30
lines changed

1 file changed

+1
-30
lines changed

test/test_segmentor.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
import os
2-
31
import cv2
42
import numpy as np
53
import torch
64

75
from roboreg.detector import OpenCVDetector
8-
from roboreg.segmentor import Sam2Segmentor, SamSegmentor
6+
from roboreg.segmentor import Sam2Segmentor
97

108

119
def test_sam2_segmentor() -> None:
@@ -30,32 +28,5 @@ def test_sam2_segmentor() -> None:
3028
cv2.waitKey(0)
3129

3230

33-
def test_sam_segmentor() -> None:
34-
img = cv2.imread("test/assets/lbr_med7/zed2i/high_res/image_1.png")
35-
36-
# detect
37-
detector = OpenCVDetector(n_positive_samples=5) # number of detected samples
38-
samples, labels = detector.detect(img)
39-
40-
# segment
41-
checkpoint = os.path.join(
42-
os.environ["HOME"],
43-
"Downloads/sam_checkpoints/sam_vit_h_4b8939.pth",
44-
)
45-
model_type = "vit_h"
46-
device = "cuda" if torch.cuda.is_available() else "cpu"
47-
48-
segmentor = SamSegmentor(
49-
checkpoint=checkpoint, model_type=model_type, device=device
50-
)
51-
p = segmentor(img, np.array(samples), np.array(labels))
52-
53-
# visualize
54-
cv2.imshow("masked_img", np.where(np.expand_dims(p > segmentor.pth, -1), img, 0))
55-
cv2.imshow("probability", p)
56-
cv2.waitKey(0)
57-
58-
5931
if __name__ == "__main__":
6032
test_sam2_segmentor()
61-
# test_sam_segmentor()

0 commit comments

Comments
 (0)