1+ import torch
2+ from sam2 .build_sam import build_sam2
3+ from sam2 .sam2_image_predictor import SAM2ImagePredictor
4+ from PIL import Image
5+ import numpy as np
6+ from matplotlib import pyplot as plt
7+ import os
8+
9+ def sam2_image_seg (image_path , prompt = None , output_path = None ):
10+ """
11+ Segment an image using SAM2.
12+
13+ Args:
14+ image_path (str): Path to the input image
15+ prompt (dict, optional): Dictionary with prompts for the model.
16+ Can contain 'point_coords', 'point_labels', and/or 'box'.
17+ Example: {'point_coords': np.array([[x, y]]), 'point_labels': np.array([1])}
18+ output_path (str, optional): Path to save visualization. If None, no visualization is saved.
19+
20+ Returns:
21+ np.ndarray: Segmentation mask
22+ """
23+ # Model paths
24+ checkpoint = "/home/ti_wang/AmadeusGPT/sam2/checkpoints/sam2.1_hiera_small.pt"
25+ model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
26+
27+ # Initialize predictor
28+ predictor = SAM2ImagePredictor (build_sam2 (model_cfg , checkpoint ))
29+
30+ # Set device
31+ device = "cuda" if torch .cuda .is_available () else "cpu"
32+ predictor .model = predictor .model .to (device )
33+
34+ # Load image
35+ image = Image .open (image_path )
36+ image = np .array (image .convert ("RGB" ))
37+ predictor .set_image (image )
38+
39+ # Run prediction with appropriate precision
40+ if device == "cuda" :
41+ with torch .inference_mode (), torch .autocast ("cuda" , dtype = torch .bfloat16 ):
42+ if prompt is None :
43+ masks , scores , logits = predictor .predict ()
44+ else :
45+ masks , scores , logits = predictor .predict (
46+ point_coords = prompt .get ('point_coords' , None ),
47+ point_labels = prompt .get ('point_labels' , None ),
48+ box = prompt .get ('box' , None ),
49+ multimask_output = True
50+ )
51+ else :
52+ with torch .inference_mode ():
53+ if prompt is None :
54+ masks , scores , logits = predictor .predict ()
55+ else :
56+ masks , scores , logits = predictor .predict (
57+ point_coords = prompt .get ('point_coords' , None ),
58+ point_labels = prompt .get ('point_labels' , None ),
59+ box = prompt .get ('box' , None ),
60+ multimask_output = True
61+ )
62+
63+ # Save visualization if output_path is provided
64+ if output_path is not None :
65+ # Create output directory if it doesn't exist
66+ os .makedirs (os .path .dirname (output_path ), exist_ok = True )
67+
68+ plt .figure (figsize = (10 , 10 ))
69+ plt .imshow (image ) # Show the original image
70+ plt .imshow (masks [0 ], cmap = "jet" , alpha = 0.5 ) # Overlay the first mask with transparency
71+ plt .axis ("off" ) # Remove axes for better visualization
72+ plt .title ("Image with Predicted Mask" )
73+ plt .savefig (output_path )
74+ plt .close ()
75+
76+ return masks
77+
78+
79+ if __name__ == "__main__" :
80+
81+ # checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
82+ # model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
83+ # checkpoint = "/home/ti_wang/AmadeusGPT/sam2/checkpoints/sam2.1_hiera_small.pt"
84+ # model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
85+
86+ # predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))
87+
88+ # with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
89+ # image_path = "./notebooks/images/cars.jpg"
90+
91+ # image = Image.open(image_path)
92+ # image = np.array(image.convert("RGB"))
93+ # # image = np.array(image)
94+
95+ # predictor.set_image(image)
96+ # masks, _, _ = predictor.predict()
97+
98+ # # Plot the original image and overlay the mask
99+ # plt.figure(figsize=(10, 10))
100+ # plt.imshow(image) # Show the original image
101+ # plt.imshow(masks[0], cmap="jet", alpha=0.5) # Overlay the first mask with transparency
102+ # plt.axis("off") # Remove axes for better visualization
103+ # plt.title("Image with Predicted Mask")
104+ # plt.savefig("./test_images/mask_overlay.png") # Save the figure
105+
106+ image_path = "./sam2/notebooks/images/truck.jpg"
107+ output_path = "./ti_test/mask_overlay_2.png"
108+ sam2_image_seg (image_path , output_path = output_path )
0 commit comments