Skip to content

Commit 01cee1a

Browse files
committed
sam2 test image
1 parent 237a20a commit 01cee1a

File tree

1 file changed

+108
-0
lines changed

1 file changed

+108
-0
lines changed

sam2_image.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import torch
2+
from sam2.build_sam import build_sam2
3+
from sam2.sam2_image_predictor import SAM2ImagePredictor
4+
from PIL import Image
5+
import numpy as np
6+
from matplotlib import pyplot as plt
7+
import os
8+
9+
def sam2_image_seg(image_path, prompt=None, output_path=None,
                   checkpoint=None, model_cfg=None):
    """
    Segment an image using SAM2 and optionally save a mask overlay.

    Args:
        image_path (str): Path to the input image.
        prompt (dict, optional): Dictionary with prompts for the model.
            Can contain 'point_coords', 'point_labels', and/or 'box';
            missing keys are forwarded to the predictor as None.
            Example: {'point_coords': np.array([[x, y]]), 'point_labels': np.array([1])}
        output_path (str, optional): Path to save a visualization of the
            first returned mask overlaid on the image. If None, no
            visualization is saved.
        checkpoint (str, optional): Path to the SAM2 checkpoint file.
            Defaults to the sam2.1_hiera_small checkpoint path used by the
            original script.
        model_cfg (str, optional): SAM2 model config name. Defaults to
            "configs/sam2.1/sam2.1_hiera_s.yaml".

    Returns:
        np.ndarray: All masks returned by ``predictor.predict`` (called with
        ``multimask_output=True``), not only the one drawn in the overlay.
    """
    # Fall back to the paths the original script hard-coded.
    if checkpoint is None:
        checkpoint = "/home/ti_wang/AmadeusGPT/sam2/checkpoints/sam2.1_hiera_small.pt"
    if model_cfg is None:
        model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"

    # Pick the device BEFORE building the model: build_sam2 defaults to
    # device="cuda", which would fail on a CPU-only machine.
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    # NOTE: the model is rebuilt on every call; callers segmenting many
    # images may want to cache the predictor themselves.
    predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint, device=device))

    # Load the image as an RGB array; the context manager releases the file
    # handle as soon as the pixels have been copied out.
    with Image.open(image_path) as img:
        image = np.array(img.convert("RGB"))

    # As in the original, the image embedding is computed outside the
    # autocast context; only predict() below runs under bfloat16 on CUDA.
    predictor.set_image(image)

    # A missing prompt is the same as a prompt with no entries: every
    # predictor argument is then None.
    prompt_args = {} if prompt is None else prompt

    # bfloat16 autocast is only enabled on CUDA; with enabled=False the
    # autocast context is a no-op, so the CPU path runs in default precision.
    with torch.inference_mode(), torch.autocast(
        device_type=device, dtype=torch.bfloat16, enabled=use_cuda
    ):
        masks, _scores, _logits = predictor.predict(
            point_coords=prompt_args.get("point_coords"),
            point_labels=prompt_args.get("point_labels"),
            box=prompt_args.get("box"),
            multimask_output=True,
        )

    # Save visualization if output_path is provided
    if output_path is not None:
        # dirname is "" for a bare filename, and os.makedirs("") raises,
        # so only create the directory when there is one to create.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            ax.imshow(image)  # Show the original image
            ax.imshow(masks[0], cmap="jet", alpha=0.5)  # Overlay the first mask with transparency
            ax.axis("off")  # Remove axes for better visualization
            ax.set_title("Image with Predicted Mask")
            fig.savefig(output_path)
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)

    return masks
77+
78+
79+
if __name__ == "__main__":
    # Manual smoke test: segment the sample truck image without any prompt
    # and write the mask overlay under ./ti_test/.
    sam2_image_seg(
        image_path="./sam2/notebooks/images/truck.jpg",
        output_path="./ti_test/mask_overlay_2.png",
    )

0 commit comments

Comments
 (0)