Skip to content

Commit 64396b8

Browse files
authored
Merge pull request #39 from kadirnar/add_sahi
Add return to load_model function
2 parents 15e7549 + 59b273a commit 64396b8

File tree

5 files changed

+83
-41
lines changed

5 files changed

+83
-41
lines changed

README.md

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,48 +26,71 @@ from metaseg import SegAutoMaskPredictor, SegManualMaskPredictor
2626

2727
# For image
2828

29-
autoseg_image = SegAutoMaskPredictor().image_predict(
29+
results = SegAutoMaskPredictor().image_predict(
3030
source="image.jpg",
3131
model_type="vit_l", # vit_l, vit_h, vit_b
3232
points_per_side=16,
3333
points_per_batch=64,
3434
min_area=0,
35+
output_path="output.jpg",
36+
show=True,
37+
save=False,
3538
)
3639

3740
# For video
3841

39-
autoseg_video = SegAutoMaskPredictor().video_predict(
42+
results = SegAutoMaskPredictor().video_predict(
4043
source="video.mp4",
4144
model_type="vit_l", # vit_l, vit_h, vit_b
4245
points_per_side=16,
4346
points_per_batch=64,
4447
min_area=1000,
48+
output_path="output.mp4",
4549
)
4650

4751
# For manual box and point selection
4852

49-
seg_manual_mask_generator = SegManualMaskPredictor().image_predict(
53+
results = SegManualMaskPredictor().image_predict(
5054
source="image.jpg",
5155
model_type="vit_l", # vit_l, vit_h, vit_b
5256
input_point=[[100, 100], [200, 200]],
5357
input_label=[0, 1],
54-
input_box=[100, 100, 200, 200], # x,y,w,h
58+
input_box=[100, 100, 200, 200], # or [[100, 100, 200, 200], [100, 100, 200, 200]]
5559
multimask_output=False,
56-
60+
random_color=False,
61+
show=True,
62+
save=False,
5763
)
64+
```
5865

59-
# For multi box selection
66+
# SAHI + Segment Anything
6067

61-
seg_manual_mask_generator = SegManualMaskPredictor().image_predict(
62-
source="data/brain.png",
63-
model_type="vit_l",
64-
input_point=None,
65-
input_label=None,
66-
input_box= [[100, 100, 400, 400]],
67-
multimask_output=False,
68+
```python
69+
image_path = "test.jpg"
70+
boxes = sahi_sliced_predict(
71+
image_path=image_path,
72+
detection_model_type="yolov5", #yolov8, detectron2, mmdetection, torchvision
73+
detection_model_path="yolov5l6.pt",
74+
conf_th=0.25,
75+
image_size=1280,
76+
slice_height=256,
77+
slice_width=256,
78+
overlap_height_ratio=0.2,
79+
overlap_width_ratio=0.2,
80+
)
6881

82+
SahiAutoSegmentation().save_image(
83+
source=image_path,
84+
model_type="vit_b",
85+
input_box=boxes,
86+
multimask_output=False,
87+
random_color=False,
88+
show=True,
89+
save=False,
6990
)
7091
```
92+
<img width="1000" alt="teaser" src="https://github.com/kadirnar/segment-anything-pip/releases/download/v0.5.0/sahi_autoseg.png">
93+
7194
# Extra Features
7295

7396
- [x] Support for Yolov5/8, Detectron2, Mmdetection, Torchvision models

metaseg/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@
99
from metaseg.generator.predictor import SamPredictor
1010
from metaseg.mask_predictor import SegAutoMaskPredictor, SegManualMaskPredictor
1111

12-
__version__ = "0.4.5"
12+
__version__ = "0.5.0"

metaseg/mask_predictor.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,13 @@
66
from tqdm import tqdm
77

88
from metaseg import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
9-
from metaseg.utils import download_model, load_box, load_image, load_mask, load_video, multi_boxes
9+
from metaseg.utils import download_model, load_box, load_image, load_mask, load_video, multi_boxes,show_image, save_image
1010

1111

1212
class SegAutoMaskPredictor:
1313
def __init__(self):
1414
self.model = None
1515
self.device = "cuda" if torch.cuda.is_available() else "cpu"
16-
self.save = False
17-
self.show = False
1816

1917
def load_model(self, model_type):
2018
if self.model is None:
@@ -24,7 +22,7 @@ def load_model(self, model_type):
2422

2523
return self.model
2624

27-
def image_predict(self, source, model_type, points_per_side, points_per_batch, min_area, output_path="output.png"):
25+
def image_predict(self, source, model_type, points_per_side, points_per_batch, min_area, output_path="output.png", show=False, save=False):
2826
read_image = load_image(source)
2927
model = self.load_model(model_type)
3028
mask_generator = SamAutomaticMaskGenerator(
@@ -49,15 +47,15 @@ def image_predict(self, source, model_type, points_per_side, points_per_batch, m
4947
mask_image = cv2.add(mask_image, img)
5048

5149
combined_mask = cv2.add(read_image, mask_image)
52-
if self.save:
53-
cv2.imwrite(output_path, combined_mask)
54-
55-
if self.show:
56-
cv2.imshow("Output", combined_mask)
57-
cv2.waitKey(0)
58-
cv2.destroyAllWindows()
59-
60-
return output_path
50+
self.combined_mask = combined_mask
51+
if show:
52+
show_image(combined_mask)
53+
54+
if save:
55+
save_image(output_path=output_path, image=combined_mask)
56+
57+
return masks
58+
6159

6260
def video_predict(self, source, model_type, points_per_side, points_per_batch, min_area, output_path="output.mp4"):
6361
cap, out = load_video(source, output_path)
@@ -128,6 +126,9 @@ def image_predict(
128126
input_label=None,
129127
multimask_output=False,
130128
output_path="output.png",
129+
random_color=False,
130+
show=False,
131+
save=False,
131132
):
132133
image = load_image(source)
133134
model = self.load_model(model_type)
@@ -144,7 +145,7 @@ def image_predict(
144145
multimask_output=False,
145146
)
146147
for mask in masks:
147-
mask_image = load_mask(mask.cpu().numpy(), False)
148+
mask_image = load_mask(mask.cpu().numpy(), random_color)
148149

149150
for box in input_boxes:
150151
image = load_box(box.cpu().numpy(), image)
@@ -158,16 +159,14 @@ def image_predict(
158159
box=input_boxes,
159160
multimask_output=multimask_output,
160161
)
161-
mask_image = load_mask(masks, True)
162+
mask_image = load_mask(masks, random_color)
162163
image = load_box(input_box, image)
163164

164165
combined_mask = cv2.add(image, mask_image)
165-
if self.save:
166-
cv2.imwrite(output_path, combined_mask)
166+
if save:
167+
save_image(output_path=output_path, image=combined_mask)
167168

168-
if self.show:
169-
cv2.imshow("Output", combined_mask)
170-
cv2.waitKey(0)
171-
cv2.destroyAllWindows()
169+
if show:
170+
show_image(combined_mask)
172171

173-
return output_path
172+
return masks

metaseg/sahi_predict.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from metaseg.utils import download_model, load_image, multi_boxes, plt_load_box, plt_load_mask
77

88

9-
def sahi_predict(
9+
def sahi_sliced_predict(
1010
image_path,
1111
detection_model_type,
1212
detection_model_path,
@@ -51,7 +51,7 @@ def sahi_predict(
5151
return boxes
5252

5353

54-
class SahiPredictor:
54+
class SahiAutoSegmentation:
5555
def __init__(self):
5656
self.model = None
5757
self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -62,15 +62,21 @@ def load_model(self, model_type):
6262
self.model = sam_model_registry[model_type](checkpoint=self.model_path)
6363
self.model.to(device=self.device)
6464

65-
def save_image(
65+
return self.model
66+
67+
def predict(
6668
self,
6769
source,
6870
model_type,
6971
input_box=None,
7072
input_point=None,
7173
input_label=None,
7274
multimask_output=False,
75+
random_color=False,
76+
save=False,
77+
show=True,
7378
):
79+
7480
read_image = load_image(source)
7581
model = self.load_model(model_type)
7682
predictor = SamPredictor(model)
@@ -99,8 +105,11 @@ def save_image(
99105
plt.figure(figsize=(10, 10))
100106
plt.imshow(read_image)
101107
for mask in masks:
102-
plt_load_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
108+
plt_load_mask(mask.cpu().numpy(), plt.gca(), random_color=random_color)
103109
for box in input_boxes:
104110
plt_load_box(box.cpu().numpy(), plt.gca())
105111
plt.axis("off")
106-
plt.show()
112+
if save:
113+
plt.savefig("output.png")
114+
if show:
115+
plt.show()

metaseg/utils/data_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,14 @@ def multi_boxes(boxes, predictor, image):
7575
input_boxes = torch.tensor(boxes, device=predictor.device)
7676
transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
7777
return input_boxes, transformed_boxes
78+
79+
def show_image(output_image):
80+
import cv2
81+
82+
cv2.imshow("output", output_image)
83+
cv2.waitKey(0)
84+
cv2.destroyAllWindows()
85+
86+
def save_image(output_image, output_path):
87+
import cv2
88+
cv2.imwrite(output_path, output_image)

0 commit comments

Comments
 (0)