Skip to content

Commit 15e7549

Browse files
authored
Merge pull request #38 from kadirnar/remove_predict_function
remove predict function
2 parents faf9ca4 + 4e63462 commit 15e7549

File tree

6 files changed

+198
-100
lines changed

6 files changed

+198
-100
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ from metaseg import SegAutoMaskPredictor, SegManualMaskPredictor
2626

2727
# For image
2828

29-
autoseg_image = SegAutoMaskPredictor().save_image(
29+
autoseg_image = SegAutoMaskPredictor().image_predict(
3030
source="image.jpg",
3131
model_type="vit_l", # vit_l, vit_h, vit_b
3232
points_per_side=16,
@@ -36,7 +36,7 @@ autoseg_image = SegAutoMaskPredictor().save_image(
3636

3737
# For video
3838

39-
autoseg_video = SegAutoMaskPredictor().save_video(
39+
autoseg_video = SegAutoMaskPredictor().video_predict(
4040
source="video.mp4",
4141
model_type="vit_l", # vit_l, vit_h, vit_b
4242
points_per_side=16,
@@ -46,7 +46,7 @@ autoseg_video = SegAutoMaskPredictor().save_video(
4646

4747
# For manual box and point selection
4848

49-
seg_manual_mask_generator = SegManualMaskPredictor().save_image(
49+
seg_manual_mask_generator = SegManualMaskPredictor().image_predict(
5050
source="image.jpg",
5151
model_type="vit_l", # vit_l, vit_h, vit_b
5252
input_point=[[100, 100], [200, 200]],
@@ -58,7 +58,7 @@ seg_manual_mask_generator = SegManualMaskPredictor().save_image(
5858

5959
# For multi box selection
6060

61-
seg_manual_mask_generator = SegManualMaskPredictor().save_image(
61+
seg_manual_mask_generator = SegManualMaskPredictor().image_predict(
6262
source="data/brain.png",
6363
model_type="vit_l",
6464
input_point=None,

metaseg/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@
99
from metaseg.generator.predictor import SamPredictor
1010
from metaseg.mask_predictor import SegAutoMaskPredictor, SegManualMaskPredictor
1111

12-
__version__ = "0.4.4"
12+
__version__ = "0.4.5"

metaseg/mask_predictor.py

Lines changed: 49 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
from tqdm import tqdm
77

88
from metaseg import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
9-
from metaseg.utils import download_model, load_image, load_video
9+
from metaseg.utils import download_model, load_box, load_image, load_mask, load_video, multi_boxes
1010

1111

1212
class SegAutoMaskPredictor:
1313
def __init__(self):
1414
self.model = None
1515
self.device = "cuda" if torch.cuda.is_available() else "cpu"
16+
self.save = False
17+
self.show = False
1618

1719
def load_model(self, model_type):
1820
if self.model is None:
@@ -22,24 +24,17 @@ def load_model(self, model_type):
2224

2325
return self.model
2426

25-
def predict(self, frame, model_type, points_per_side, points_per_batch, min_area):
27+
def image_predict(self, source, model_type, points_per_side, points_per_batch, min_area, output_path="output.png"):
28+
read_image = load_image(source)
2629
model = self.load_model(model_type)
2730
mask_generator = SamAutomaticMaskGenerator(
2831
model, points_per_side=points_per_side, points_per_batch=points_per_batch, min_mask_region_area=min_area
2932
)
3033

31-
masks = mask_generator.generate(frame)
32-
33-
return frame, masks
34+
masks = mask_generator.generate(read_image)
3435

35-
def save_image(self, source, model_type, points_per_side, points_per_batch, min_area, output_path="output.png"):
36-
read_image = load_image(source)
37-
image, anns = self.predict(read_image, model_type, points_per_side, points_per_batch, min_area)
38-
if len(anns) == 0:
39-
return
40-
41-
sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
42-
mask_image = np.zeros((anns[0]["segmentation"].shape[0], anns[0]["segmentation"].shape[1], 3), dtype=np.uint8)
36+
sorted_anns = sorted(masks, key=(lambda x: x["area"]), reverse=True)
37+
mask_image = np.zeros((masks[0]["segmentation"].shape[0], masks[0]["segmentation"].shape[1], 3), dtype=np.uint8)
4338
colors = np.random.randint(0, 255, size=(256, 3), dtype=np.uint8)
4439
for i, ann in enumerate(sorted_anns):
4540
m = ann["segmentation"]
@@ -53,12 +48,18 @@ def save_image(self, source, model_type, points_per_side, points_per_batch, min_
5348
img = cv2.addWeighted(img, 0.35, np.zeros_like(img), 0.65, 0)
5449
mask_image = cv2.add(mask_image, img)
5550

56-
combined_mask = cv2.add(image, mask_image)
57-
cv2.imwrite(output_path, combined_mask)
51+
combined_mask = cv2.add(read_image, mask_image)
52+
if self.save:
53+
cv2.imwrite(output_path, combined_mask)
54+
55+
if self.show:
56+
cv2.imshow("Output", combined_mask)
57+
cv2.waitKey(0)
58+
cv2.destroyAllWindows()
5859

5960
return output_path
6061

61-
def save_video(self, source, model_type, points_per_side, points_per_batch, min_area, output_path="output.mp4"):
62+
def video_predict(self, source, model_type, points_per_side, points_per_batch, min_area, output_path="output.mp4"):
6263
cap, out = load_video(source, output_path)
6364
length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
6465
colors = np.random.randint(0, 255, size=(256, 3), dtype=np.uint8)
@@ -68,18 +69,23 @@ def save_video(self, source, model_type, points_per_side, points_per_batch, min_
6869
if not ret:
6970
break
7071

71-
image, anns = self.predict(frame, model_type, points_per_side, points_per_batch, min_area)
72-
if len(anns) == 0:
72+
model = self.load_model(model_type)
73+
mask_generator = SamAutomaticMaskGenerator(
74+
model, points_per_side=points_per_side, points_per_batch=points_per_batch, min_mask_region_area=min_area
75+
)
76+
masks = mask_generator.generate(frame)
77+
78+
if len(masks) == 0:
7379
continue
7480

75-
sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
81+
sorted_anns = sorted(masks, key=(lambda x: x["area"]), reverse=True)
7682
mask_image = np.zeros(
77-
(anns[0]["segmentation"].shape[0], anns[0]["segmentation"].shape[1], 3), dtype=np.uint8
83+
(masks[0]["segmentation"].shape[0], masks[0]["segmentation"].shape[1], 3), dtype=np.uint8
7884
)
7985

8086
for i, ann in enumerate(sorted_anns):
8187
m = ann["segmentation"]
82-
color = colors[i % 256] # Her nesne için farklı bir renk kullan
88+
color = colors[i % 256]
8389
img = np.zeros((m.shape[0], m.shape[1], 3), dtype=np.uint8)
8490
img[:, :, 0] = color[0]
8591
img[:, :, 1] = color[1]
@@ -102,6 +108,8 @@ class SegManualMaskPredictor:
102108
def __init__(self):
103109
self.model = None
104110
self.device = "cuda" if torch.cuda.is_available() else "cpu"
111+
self.save = False
112+
self.show = False
105113

106114
def load_model(self, model_type):
107115
if self.model is None:
@@ -111,49 +119,35 @@ def load_model(self, model_type):
111119

112120
return self.model
113121

114-
def load_mask(self, mask, random_color):
115-
if random_color:
116-
color = np.random.rand(3) * 255
117-
else:
118-
color = np.array([100, 50, 0])
119-
120-
h, w = mask.shape[-2:]
121-
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
122-
mask_image = mask_image.astype(np.uint8)
123-
return mask_image
124-
125-
def load_box(self, box, image):
126-
x, y, w, h = int(box[0]), int(box[1]), int(box[2]), int(box[3])
127-
cv2.rectangle(image, (x, y), (w, h), (0, 255, 0), 2)
128-
return image
129-
130-
def multi_boxes(self, boxes, predictor, image):
131-
input_boxes = torch.tensor(boxes, device=predictor.device)
132-
transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
133-
return input_boxes, transformed_boxes
134-
135-
def predict(
122+
def image_predict(
136123
self,
137-
frame,
124+
source,
138125
model_type,
139126
input_box=None,
140127
input_point=None,
141128
input_label=None,
142129
multimask_output=False,
130+
output_path="output.png",
143131
):
132+
image = load_image(source)
144133
model = self.load_model(model_type)
145134
predictor = SamPredictor(model)
146-
predictor.set_image(frame)
135+
predictor.set_image(image)
147136

148137
if type(input_box[0]) == list:
149-
input_boxes, new_boxes = self.multi_boxes(input_box, predictor, frame)
138+
input_boxes, new_boxes = multi_boxes(input_box, predictor, image)
150139

151140
masks, _, _ = predictor.predict_torch(
152141
point_coords=None,
153142
point_labels=None,
154143
boxes=new_boxes,
155144
multimask_output=False,
156145
)
146+
for mask in masks:
147+
mask_image = load_mask(mask.cpu().numpy(), False)
148+
149+
for box in input_boxes:
150+
image = load_box(box.cpu().numpy(), image)
157151

158152
elif type(input_box[0]) == int:
159153
input_boxes = np.array(input_box)[None, :]
@@ -164,36 +158,16 @@ def predict(
164158
box=input_boxes,
165159
multimask_output=multimask_output,
166160
)
167-
168-
return frame, masks, input_boxes
169-
170-
def save_image(
171-
self,
172-
source,
173-
model_type,
174-
input_box=None,
175-
input_point=None,
176-
input_label=None,
177-
multimask_output=False,
178-
output_path="output.png",
179-
):
180-
read_image = load_image(source)
181-
image, anns, boxes = self.predict(read_image, model_type, input_box, input_point, input_label, multimask_output)
182-
if len(anns) == 0:
183-
return
184-
185-
if type(input_box[0]) == list:
186-
for mask in anns:
187-
mask_image = self.load_mask(mask.cpu().numpy(), False)
188-
189-
for box in boxes:
190-
image = self.load_box(box.cpu().numpy(), image)
191-
192-
elif type(input_box[0]) == int:
193-
mask_image = self.load_mask(anns, True)
194-
image = self.load_box(input_box, image)
161+
mask_image = load_mask(masks, True)
162+
image = load_box(input_box, image)
195163

196164
combined_mask = cv2.add(image, mask_image)
197-
cv2.imwrite(output_path, combined_mask)
165+
if self.save:
166+
cv2.imwrite(output_path, combined_mask)
167+
168+
if self.show:
169+
cv2.imshow("Output", combined_mask)
170+
cv2.waitKey(0)
171+
cv2.destroyAllWindows()
198172

199173
return output_path

metaseg/sahi_predict.py

Lines changed: 80 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,45 @@
1-
from metaseg import SegManualMaskPredictor
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
import torch
24

5+
from metaseg import SamPredictor, sam_model_registry
6+
from metaseg.utils import download_model, load_image, multi_boxes, plt_load_box, plt_load_mask
37

4-
def sahi_predictor(image_path, model_type, model_path, conf_th, device):
8+
9+
def sahi_predict(
10+
image_path,
11+
detection_model_type,
12+
detection_model_path,
13+
conf_th,
14+
image_size,
15+
slice_height,
16+
slice_width,
17+
overlap_height_ratio,
18+
overlap_width_ratio,
19+
):
520

621
try:
722
from sahi import AutoDetectionModel
8-
from sahi.predict import get_prediction, get_sliced_prediction, predict
23+
from sahi.predict import get_prediction, get_sliced_prediction
924
except ImportError:
1025
raise ImportError("Please install SAHI library using 'pip install sahi'.")
1126

27+
device = "cuda" if torch.cuda.is_available() else "cpu"
28+
1229
detection_model = AutoDetectionModel.from_pretrained(
13-
model_type=model_type,
14-
model_path=model_path,
30+
image_size=image_size,
31+
model_type=detection_model_type,
32+
model_path=detection_model_path,
1533
confidence_threshold=conf_th,
1634
device=device,
1735
)
1836
result = get_sliced_prediction(
1937
image_path,
2038
detection_model,
21-
slice_height=256,
22-
slice_width=256,
23-
overlap_height_ratio=0.2,
24-
overlap_width_ratio=0.2,
39+
slice_height=slice_height,
40+
slice_width=slice_width,
41+
overlap_height_ratio=overlap_height_ratio,
42+
overlap_width_ratio=overlap_width_ratio,
2543
)
2644

2745
result = get_prediction(image_path, detection_model)
@@ -30,12 +48,59 @@ def sahi_predictor(image_path, model_type, model_path, conf_th, device):
3048
for i in output:
3149
boxes.append(i.bbox.to_xyxy())
3250

33-
seg_manual_mask_generator = SegManualMaskPredictor().save_image(
34-
source=image_path,
35-
model_type="vit_l",
51+
return boxes
52+
53+
54+
class SahiPredictor:
55+
def __init__(self):
56+
self.model = None
57+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
58+
59+
def load_model(self, model_type):
60+
if self.model is None:
61+
self.model_path = download_model(model_type)
62+
self.model = sam_model_registry[model_type](checkpoint=self.model_path)
63+
self.model.to(device=self.device)
64+
65+
def save_image(
66+
self,
67+
source,
68+
model_type,
69+
input_box=None,
3670
input_point=None,
3771
input_label=None,
38-
input_box=boxes,
3972
multimask_output=False,
40-
)
41-
return seg_manual_mask_generator
73+
):
74+
read_image = load_image(source)
75+
model = self.load_model(model_type)
76+
predictor = SamPredictor(model)
77+
predictor.set_image(read_image)
78+
79+
if type(input_box[0]) == list:
80+
input_boxes, new_boxes = multi_boxes(input_box, predictor, read_image)
81+
82+
masks, _, _ = predictor.predict_torch(
83+
point_coords=None,
84+
point_labels=None,
85+
boxes=new_boxes,
86+
multimask_output=False,
87+
)
88+
89+
elif type(input_box[0]) == int:
90+
input_boxes = np.array(input_box)[None, :]
91+
92+
masks, _, _ = predictor.predict(
93+
point_coords=input_point,
94+
point_labels=input_label,
95+
box=input_boxes,
96+
multimask_output=multimask_output,
97+
)
98+
99+
plt.figure(figsize=(10, 10))
100+
plt.imshow(read_image)
101+
for mask in masks:
102+
plt_load_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
103+
for box in input_boxes:
104+
plt_load_box(box.cpu().numpy(), plt.gca())
105+
plt.axis("off")
106+
plt.show()

metaseg/utils/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from metaseg.utils.data_utils import load_image, load_video
8-
from metaseg.utils.file_utils import download_model
7+
from metaseg.utils.data_utils import *
8+
from metaseg.utils.file_utils import *

0 commit comments

Comments
 (0)