Skip to content

Commit 2a6b638

Browse files
authored
Added text prompting segmentation (#65)
1 parent 52f29fa commit 2a6b638

File tree

2 files changed

+158
-0
lines changed

2 files changed

+158
-0
lines changed

samgeo/requirements.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
rasterio
2+
torch
3+
torchvision
4+
segment-anything
5+
huggingface_hub
6+
git+https://github.com/IDEA-Research/GroundingDINO.git

samgeo/text_sam.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import os
2+
import argparse
3+
import numpy as np
4+
import torch
5+
import rasterio
6+
import matplotlib.pyplot as plt
7+
import groundingdino.datasets.transforms as T
8+
from PIL import Image
9+
from rasterio.plot import show
10+
from matplotlib.patches import Rectangle
11+
from groundingdino.models import build_model
12+
from groundingdino.util import box_ops
13+
from groundingdino.util.inference import predict
14+
from groundingdino.util.slconfig import SLConfig
15+
from groundingdino.util.utils import clean_state_dict
16+
from huggingface_hub import hf_hub_download
17+
from segment_anything import sam_model_registry
18+
from segment_anything import SamPredictor
19+
20+
# Download URLs for the Segment Anything checkpoints, keyed by the same
# model-type string that is used to index ``sam_model_registry``.
SAM_MODELS = {
    "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
}

# Directory holding checkpoints downloaded through torch.hub.
# TORCH_HOME names the torch cache *root*, so the "hub/checkpoints" suffix is
# appended whether the root comes from the environment or from the default;
# previously an explicit TORCH_HOME was used verbatim, without the suffix.
CACHE_PATH = os.path.join(
    os.environ.get("TORCH_HOME", os.path.expanduser("~/.cache/torch")),
    "hub",
    "checkpoints",
)
27+
28+
def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    """Build a GroundingDINO model from files hosted on the Hugging Face Hub.

    Args:
        repo_id: Hub repository that hosts both the config and the checkpoint.
        filename: Name of the checkpoint file inside the repository.
        ckpt_config_filename: Name of the model config file inside the repository.
        device: Device the freshly built model is moved to (default ``'cpu'``).

    Returns:
        The model in evaluation mode with the checkpoint weights applied.
    """
    # Fetch the config first; the architecture is built from it.
    config_path = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
    model = build_model(SLConfig.fromfile(config_path))
    model.to(device)

    # Weights are deserialised onto the CPU and copied into the model;
    # strict=False tolerates key mismatches between checkpoint and model.
    weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
    state = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(clean_state_dict(state['model']), strict=False)

    model.eval()
    return model
38+
39+
def transform_image(image) -> torch.Tensor:
    """Run ``image`` through the resize / to-tensor / normalize pipeline.

    Args:
        image: The image handed to the GroundingDINO transforms
            (callers in this file pass a PIL image).

    Returns:
        The transformed image tensor.
    """
    steps = [
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    pipeline = T.Compose(steps)
    # The pipeline takes and returns an (image, target) pair; no target is
    # supplied here, so the second element is discarded.
    tensor, _ = pipeline(image, None)
    return tensor
47+
48+
class LangSAM:
    """Text-prompted segmentation built from GroundingDINO and SAM.

    GroundingDINO turns a text prompt into bounding boxes; those boxes are
    then used as prompts for a ``SamPredictor`` to obtain one mask per box.

    Attributes set by ``__init__``:
        device: CUDA when available, otherwise CPU.
        groundingdino: The loaded GroundingDINO model.
        sam: A ``SamPredictor`` wrapping the loaded SAM model.
    """

    def __init__(self, sam_type="vit_h"):
        """Load both models.

        Args:
            sam_type: Key into ``SAM_MODELS`` / ``sam_model_registry``
                selecting the SAM variant (default ``"vit_h"``).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.build_groundingdino()
        self.build_sam(sam_type)

    def build_sam(self, sam_type):
        """Download the SAM checkpoint for ``sam_type`` and create ``self.sam``.

        Raises:
            KeyError: If ``sam_type`` is not a key of ``SAM_MODELS``.
        """
        checkpoint_url = SAM_MODELS[sam_type]
        sam = sam_model_registry[sam_type]()
        # Deserialise onto the CPU (as load_model_hf does) so loading does not
        # depend on the device the checkpoint was saved from; the model is
        # moved to the target device afterwards.
        state_dict = torch.hub.load_state_dict_from_url(
            checkpoint_url, map_location="cpu"
        )
        sam.load_state_dict(state_dict, strict=True)
        sam.to(device=self.device)
        self.sam = SamPredictor(sam)

    def build_groundingdino(self):
        """Load the GroundingDINO SwinB checkpoint into ``self.groundingdino``."""
        ckpt_repo_id = "ShilongLiu/GroundingDINO"
        ckpt_filename = "groundingdino_swinb_cogcoor.pth"
        ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"
        self.groundingdino = load_model_hf(
            ckpt_repo_id, ckpt_filename, ckpt_config_filename, self.device
        )

    def predict_dino(self, image_pil, text_prompt, box_threshold, text_threshold):
        """Detect boxes matching ``text_prompt`` with GroundingDINO.

        Args:
            image_pil: Input PIL image.
            text_prompt: Caption describing the objects to find.
            box_threshold: Box threshold forwarded to GroundingDINO.
            text_threshold: Text threshold forwarded to GroundingDINO.

        Returns:
            ``(boxes, logits, phrases)`` where ``boxes`` have been converted
            from cxcywh to xyxy and scaled by the image width and height.
        """
        image_trans = transform_image(image_pil)
        boxes, logits, phrases = predict(
            model=self.groundingdino,
            image=image_trans,
            caption=text_prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=self.device,
        )
        W, H = image_pil.size
        # Scale each (x, y, x, y) box by the image size to get pixel units.
        boxes = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])

        return boxes, logits, phrases

    def predict_sam(self, image_pil, boxes):
        """Segment ``image_pil`` with SAM, using ``boxes`` as prompts.

        Args:
            image_pil: Input PIL image.
            boxes: Box tensor in the image's pixel coordinates (xyxy).

        Returns:
            The predicted masks, moved to the CPU (one mask per box, since
            ``multimask_output`` is False).
        """
        image_array = np.asarray(image_pil)
        self.sam.set_image(image_array)
        # Map the boxes into the coordinate frame of SAM's resized input.
        transformed_boxes = self.sam.transform.apply_boxes_torch(
            boxes, image_array.shape[:2]
        )
        masks, _, _ = self.sam.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes.to(self.sam.device),
            multimask_output=False,
        )
        return masks.cpu()

    def predict(self, image_pil, text_prompt, box_threshold, text_threshold):
        """Run detection followed by segmentation.

        Returns:
            ``(masks, boxes, phrases, logits)``. ``masks`` is an empty tensor
            when no box was detected; otherwise the SAM masks with
            dimension 1 squeezed out.
        """
        boxes, logits, phrases = self.predict_dino(
            image_pil, text_prompt, box_threshold, text_threshold
        )
        masks = torch.tensor([])
        if len(boxes) > 0:
            masks = self.predict_sam(image_pil, boxes)
            masks = masks.squeeze(1)
        return masks, boxes, phrases, logits
101+
102+
def main():
    """Command-line entry point: segment a raster from a text prompt.

    Reads ``--image`` with rasterio, runs ``LangSAM`` on its first three
    bands with ``--prompt`` and the two thresholds, and, when at least one
    box is detected, writes a single-band uint8 GeoTIFF that is 255 wherever
    any mask is set and 0 elsewhere. The output reuses the input's CRS and
    affine transform and is written to ``--output`` (default ``mask.tif``).
    When nothing is detected a message is printed and no file is written.
    """
    parser = argparse.ArgumentParser(description='LangSAM')
    parser.add_argument('--image', required=True, help='path to the image')
    parser.add_argument('--prompt', required=True, help='text prompt')
    parser.add_argument('--box_threshold', default=0.5, type=float, help='box threshold')
    parser.add_argument('--text_threshold', default=0.5, type=float, help='text threshold')
    parser.add_argument('--output', default='mask.tif', help='path of the output mask GeoTIFF')
    args = parser.parse_args()

    with rasterio.open(args.image) as src:
        # Move the band axis from first to last for PIL.
        image_np = src.read().transpose((1, 2, 0))
        transform = src.transform  # Georeferencing, reused for the output
        crs = src.crs  # Coordinate Reference System, reused for the output

    model = LangSAM()

    # Keep only the first three bands (drops e.g. an alpha channel).
    image_pil = Image.fromarray(image_np[:, :, :3])

    masks, boxes, _phrases, _logits = model.predict(
        image_pil, args.prompt, args.box_threshold, args.text_threshold
    )

    if boxes.nelement() == 0:  # No "object" instances found
        print('No objects found in the image.')
        return

    # The output is binary, so per-instance ids are unnecessary: take the
    # union of all masks in one vectorised step.
    any_mask = (masks.cpu().numpy() > 0).any(axis=0)
    mask_overlay = np.where(any_mask, 255, 0).astype(rasterio.uint8)

    with rasterio.open(
        args.output,
        'w',
        driver='GTiff',
        height=mask_overlay.shape[0],
        width=mask_overlay.shape[1],
        count=1,
        dtype=mask_overlay.dtype,
        crs=crs,
        transform=transform,
    ) as dst:
        dst.write(mask_overlay, 1)
150+
151+
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)