1+ import os
2+ import argparse
3+ import numpy as np
4+ import torch
5+ import rasterio
6+ import matplotlib .pyplot as plt
7+ import groundingdino .datasets .transforms as T
8+ from PIL import Image
9+ from rasterio .plot import show
10+ from matplotlib .patches import Rectangle
11+ from groundingdino .models import build_model
12+ from groundingdino .util import box_ops
13+ from groundingdino .util .inference import predict
14+ from groundingdino .util .slconfig import SLConfig
15+ from groundingdino .util .utils import clean_state_dict
16+ from huggingface_hub import hf_hub_download
17+ from segment_anything import sam_model_registry
18+ from segment_anything import SamPredictor
19+
# Download URLs for the Segment Anything checkpoints, keyed by the model-type
# string that is also used to index ``sam_model_registry``.
SAM_MODELS = {
    "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
}
25+
# Directory holding downloaded torch.hub checkpoints.
# TORCH_HOME names the torch cache *root*, so the "hub/checkpoints" suffix is
# appended to it as well as to the built-in default; previously an explicit
# TORCH_HOME yielded the root directory instead of the checkpoints directory.
CACHE_PATH = os.path.join(
    os.environ.get("TORCH_HOME", os.path.expanduser("~/.cache/torch")),
    "hub",
    "checkpoints",
)
27+
def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    """Build a model from a Hugging Face Hub config and load its checkpoint.

    Args:
        repo_id: Hub repository that hosts both the config and the checkpoint.
        filename: Name of the checkpoint file inside the repository.
        ckpt_config_filename: Name of the model config file inside the
            repository.
        device: Device the freshly built model is moved to.

    Returns:
        The model, switched to evaluation mode.
    """
    # Fetch the config first: the architecture must exist before weights load.
    config_path = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
    model = build_model(SLConfig.fromfile(config_path))
    model.to(device)

    # The checkpoint is deserialised on CPU regardless of the target device.
    weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
    state = torch.load(weights_path, map_location='cpu')

    # strict=False: keys that do not match the model are tolerated.
    model.load_state_dict(clean_state_dict(state['model']), strict=False)
    model.eval()
    return model
38+
def transform_image(image) -> torch.Tensor:
    """Resize, tensorise and normalise an image with the GroundingDINO transforms.

    Args:
        image: Image handed to the transform pipeline (callers in this file
            pass a PIL image).

    Returns:
        The transformed image; the pipeline's second output (the target,
        passed in as ``None``) is discarded.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ])
    tensor, _target = pipeline(image, None)
    return tensor
47+
48+ # Class definition for LangSAM
class LangSAM:
    """Text-prompted segmentation pipeline.

    A GroundingDINO model proposes boxes for a text prompt, and a SAM
    predictor is then prompted with those boxes to produce masks.
    """

    def __init__(self, sam_type="vit_h"):
        """Select the device (CUDA when available, else CPU) and build both models.

        Args:
            sam_type: Key into ``SAM_MODELS`` / ``sam_model_registry``.
        """
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.build_groundingdino()
        self.build_sam(sam_type)

    def build_sam(self, sam_type):
        """Create the SAM model for ``sam_type`` and store a predictor in ``self.sam``.

        Raises:
            KeyError: If ``sam_type`` is not a key of ``SAM_MODELS``.
        """
        url = SAM_MODELS[sam_type]
        sam_model = sam_model_registry[sam_type]()
        weights = torch.hub.load_state_dict_from_url(url)
        sam_model.load_state_dict(weights, strict=True)
        sam_model.to(device=self.device)
        self.sam = SamPredictor(sam_model)

    def build_groundingdino(self):
        """Load the GroundingDINO Swin-B checkpoint into ``self.groundingdino``."""
        self.groundingdino = load_model_hf(
            "ShilongLiu/GroundingDINO",
            "groundingdino_swinb_cogcoor.pth",
            "GroundingDINO_SwinB.cfg.py",
            self.device,
        )

    def predict_dino(self, image_pil, text_prompt, box_threshold, text_threshold):
        """Run GroundingDINO on an image for a text prompt.

        Returns:
            ``(boxes, logits, phrases)``. The boxes are converted from
            cxcywh to xyxy and multiplied by the image width/height
            (presumably mapping normalised boxes to pixels - confirm
            against the GroundingDINO ``predict`` contract).
        """
        tensor_image = transform_image(image_pil)
        boxes, logits, phrases = predict(
            model=self.groundingdino,
            image=tensor_image,
            caption=text_prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=self.device,
        )
        width, height = image_pil.size
        scale = torch.Tensor([width, height, width, height])
        boxes = box_ops.box_cxcywh_to_xyxy(boxes) * scale

        return boxes, logits, phrases

    def predict_sam(self, image_pil, boxes):
        """Prompt the SAM predictor with boxes and return the masks on CPU.

        Args:
            image_pil: Image the boxes refer to; converted to a numpy array.
            boxes: Box tensor in the coordinate frame of ``image_pil``.
        """
        pixels = np.asarray(image_pil)
        self.sam.set_image(pixels)
        # Boxes are re-expressed in the predictor's own input frame.
        height_width = pixels.shape[:2]
        sam_boxes = self.sam.transform.apply_boxes_torch(boxes, height_width)
        masks, _, _ = self.sam.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=sam_boxes.to(self.sam.device),
            multimask_output=False,
        )
        return masks.cpu()

    def predict(self, image_pil, text_prompt, box_threshold, text_threshold):
        """Detect boxes for ``text_prompt`` and segment each of them.

        Returns:
            ``(masks, boxes, phrases, logits)``. When no box is detected,
            ``masks`` is an empty tensor and SAM is not invoked; otherwise
            dimension 1 of the SAM output is squeezed away.
        """
        boxes, logits, phrases = self.predict_dino(
            image_pil, text_prompt, box_threshold, text_threshold
        )
        if len(boxes) > 0:
            masks = self.predict_sam(image_pil, boxes).squeeze(1)
        else:
            masks = torch.tensor([])
        return masks, boxes, phrases, logits
101+
def main():
    """Segment a georeferenced raster from a text prompt and write ``mask.tif``.

    Command-line arguments:
        --image: Path to the input raster (required).
        --prompt: Text prompt describing the objects to segment (required).
        --box_threshold: Box threshold passed to the detector (default 0.5).
        --text_threshold: Text threshold passed to the detector (default 0.5).

    When at least one object is found, a single-band uint8 GeoTIFF named
    ``mask.tif`` is written to the working directory, carrying the input's
    CRS and affine transform; pixels covered by any mask are 255, all
    others 0. When nothing is found, a message is printed and no file is
    written.
    """
    parser = argparse.ArgumentParser(description='LangSAM')
    parser.add_argument('--image', required=True, help='path to the image')
    parser.add_argument('--prompt', required=True, help='text prompt')
    parser.add_argument('--box_threshold', default=0.5, type=float, help='box threshold')
    parser.add_argument('--text_threshold', default=0.5, type=float, help='text threshold')
    args = parser.parse_args()

    with rasterio.open(args.image) as src:
        # rasterio reads (bands, rows, cols); reorder to (rows, cols, bands).
        image_np = src.read().transpose((1, 2, 0))
        geo_transform = src.transform  # georeferencing of the input
        crs = src.crs  # coordinate reference system of the input

    model = LangSAM()

    # Only the first three bands are used, which drops an alpha band if present.
    # NOTE(review): presumably expects an 8-bit raster with >= 3 bands - confirm.
    image_pil = Image.fromarray(image_np[:, :, :3])

    masks, boxes, _phrases, _logits = model.predict(
        image_pil, args.prompt, args.box_threshold, args.text_threshold
    )

    if boxes.nelement() == 0:  # no instances found
        print('No objects found in the image.')
        return

    # Union of all instance masks as a binary 0/255 raster. Per-instance
    # labels are not needed because the output is binary.
    covered = np.any(masks.cpu().numpy() > 0, axis=0)
    mask_overlay = (covered * 255).astype(np.uint8)

    with rasterio.open(
        'mask.tif',
        'w',
        driver='GTiff',
        height=mask_overlay.shape[0],
        width=mask_overlay.shape[1],
        count=1,
        dtype=mask_overlay.dtype,
        crs=crs,
        transform=geo_transform,
    ) as dst:
        dst.write(mask_overlay, 1)
150+
# Run the command-line entry point only when executed as a script.
if __name__ == '__main__':
    main()
0 commit comments