forked from IDEA-Research/Grounded-Segment-Anything
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgrounded_sam_imagine.py
More file actions
94 lines (80 loc) · 2.81 KB
/
grounded_sam_imagine.py
File metadata and controls
94 lines (80 loc) · 2.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import cv2
import numpy as np
import supervision as sv
import torch
import torchvision
import sys
import os
from groundingdino.util.inference import Model
from segment_anything import sam_model_registry, SamPredictor
# ========== Command-line Arguments ==========
# Usage: <script> SOURCE_IMAGE OUTPUT_PNG CLASS [CLASS ...]
if len(sys.argv) < 4:
    # Fail with a readable usage message instead of a bare IndexError.
    sys.exit(
        f"Usage: {sys.argv[0]} <source_image> <output_png> <class> [<class> ...]"
    )
SOURCE_IMAGE_PATH = sys.argv[1]  # input image path
OUTPUT_IMAGE_PATH = sys.argv[2]  # output transparent image path (.png)
CLASSES = sys.argv[3:]  # target class list
# ============================================

# Read the image before loading the model checkpoints so a bad path fails fast.
# cv2.imread signals failure by returning None rather than raising.
image = cv2.imread(SOURCE_IMAGE_PATH)
if image is None:
    sys.exit(f"Could not read image: {SOURCE_IMAGE_PATH}")

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"
SAM_ENCODER_VERSION = "vit_h"
SAM_CHECKPOINT_PATH = "./sam_vit_h_4b8939.pth"

# NOTE(review): Model is constructed without a device argument, so it uses its
# own default rather than DEVICE -- confirm this is intended on CPU-only hosts.
grounding_dino_model = Model(
    model_config_path=GROUNDING_DINO_CONFIG_PATH,
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
)

sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
sam.to(device=DEVICE)
sam_predictor = SamPredictor(sam)

# Detect boxes for the requested classes.
BOX_THRESHOLD = 0.25
TEXT_THRESHOLD = 0.25
detections = grounding_dino_model.predict_with_classes(
    image=image,
    classes=CLASSES,
    box_threshold=BOX_THRESHOLD,
    text_threshold=TEXT_THRESHOLD,
)

# NMS: keep only the detections that survive non-maximum suppression, and
# index every per-detection array with the same kept indices so they stay
# aligned with each other.
NMS_THRESHOLD = 0.8
nms_idx = torchvision.ops.nms(
    torch.from_numpy(detections.xyxy),
    torch.from_numpy(detections.confidence),
    NMS_THRESHOLD,
).numpy().tolist()
detections.xyxy = detections.xyxy[nms_idx]
detections.confidence = detections.confidence[nms_idx]
detections.class_id = detections.class_id[nms_idx]
# Segmentation
def segment(sam_predictor: "SamPredictor", image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    """Return the best-scoring mask for each box in ``xyxy``.

    The image is handed to ``sam_predictor.set_image`` once, then the
    predictor is queried per box with ``multimask_output=True`` and the
    candidate with the highest score is kept.

    Args:
        sam_predictor: Predictor exposing ``set_image`` and ``predict``.
        image: Image array passed unchanged to ``set_image``; its first two
            dimensions are taken as height and width for the empty result.
        xyxy: Iterable of boxes, one ``predict`` call per box.

    Returns:
        The selected masks stacked along a new leading axis, one per box.
        When ``xyxy`` is empty, an empty boolean array of shape
        ``(0, height, width)`` is returned so callers always receive a
        3-D mask stack.
    """
    sam_predictor.set_image(image)

    best_masks = []
    for box in xyxy:
        # The third return value (logits) is not needed here.
        masks, scores, _ = sam_predictor.predict(box=box, multimask_output=True)
        best_masks.append(masks[np.argmax(scores)])

    if not best_masks:
        # np.array([]) would be a 1-D float array; keep the mask-stack shape.
        height, width = image.shape[:2]
        return np.zeros((0, height, width), dtype=bool)

    return np.array(best_masks)
# Convert once: the predictor is given the RGB image, and the same RGB image
# supplies the colour channels of the output.
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

detections.mask = segment(
    sam_predictor=sam_predictor,
    image=image_rgb,
    xyxy=detections.xyxy,
)

# Compose transparent RGBA image
height, width = image.shape[:2]
final_image = np.zeros((height, width, 4), dtype=np.uint8)

# Only the first mask is composited; later masks are deliberately skipped.
# NOTE(review): after NMS the first detection is presumably the
# highest-confidence one -- confirm that "first only" is the intended output.
# With no detections the loop does not run and the output stays fully
# transparent.
for mask in detections.mask[:1]:
    mask = mask.astype(bool)
    alpha = (mask * 255).astype(np.uint8)  # opaque inside the mask only
    rgb = image_rgb.copy()
    rgb[~mask] = 0  # blank out colour outside the mask
    rgba = np.dstack([rgb, alpha])
    final_image = np.maximum(final_image, rgba)

# cv2.imwrite reports failure through its return value, so check it before
# claiming success.
if not cv2.imwrite(OUTPUT_IMAGE_PATH, cv2.cvtColor(final_image, cv2.COLOR_RGBA2BGRA)):
    sys.exit(f"Failed to write image: {OUTPUT_IMAGE_PATH}")
print(f"Transparent image saved to: {OUTPUT_IMAGE_PATH}")