-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathobjectSegmentation.py
More file actions
188 lines (156 loc) · 6.74 KB
/
objectSegmentation.py
File metadata and controls
188 lines (156 loc) · 6.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
from google import genai
from google.genai import types
from PIL import Image, ImageDraw
import io
import base64
import json
import numpy as np
import os
import load_dotenv
from difflib import get_close_matches
import re
# Load variables from a local .env file (if present) into os.environ.
# NOTE(review): python-dotenv exposes this as `dotenv.load_dotenv`; confirm the
# `import load_dotenv` module used here resolves in the target environment.
load_dotenv.load_dotenv()
# genai.Client() with no arguments picks the API key up from the environment
# (GOOGLE_API_KEY / GEMINI_API_KEY), which load_dotenv() has just populated.
client = genai.Client()
# Canonical (lower-case) object names mapped to the plural forms, common
# misspellings and regional names that find_matching_label() treats as the
# same object. Insertion order matters: canonical names are tried in order.
COMMON_SYNONYMS = dict(
    vegetables=['veggie', 'veg', 'vegetable'],
    tomato=['tomatoes', 'tomatos', 'tomatoe'],
    potato=['potatoes', 'potatos', 'potatoe'],
    carrot=['carrots', 'carot', 'carots'],
    pepper=['peppers', 'capsicum', 'bell pepper', 'chili'],
    onion=['onions', 'unions'],
    cucumber=['cucumbers', 'cucumbre'],
    broccoli=['brocoli', 'brocolli', 'broccolli'],
    cauliflower=['cauliflour', 'cauli'],
    lettuce=['letuce', 'lettus', 'lattuce'],
    zucchini=['zuchini', 'courgette', 'zuccini'],
    # Add more synonyms as needed
)
def find_matching_label(query: str, available_labels: list) -> "str | None":
    """Return the entry of *available_labels* that best matches *query*.

    Matching is case-insensitive, ignores surrounding whitespace in *query*,
    and tries three stages in order, returning the first hit:

    1. exact match;
    2. synonym match via ``COMMON_SYNONYMS``: when *query* is a canonical name
       or one of its listed variants, return the first label that contains
       the canonical name as a substring;
    3. fuzzy match with ``difflib.get_close_matches`` (cutoff 0.6).

    Args:
        query: Free-text object name to look up.
        available_labels: Candidate label strings.

    Returns:
        The matching label with its original spelling, or ``None`` when
        nothing matches, *query* is ``None``/blank, or there are no labels.
    """
    if query is None or not available_labels:
        return None
    needle = query.strip().lower()
    if not needle:
        return None

    # Lower-case every label once instead of on every comparison.
    lowered = [(label, label.lower()) for label in available_labels]

    # 1. Exact (case-insensitive) match.
    for label, low in lowered:
        if low == needle:
            return label

    # 2. Synonym table: resolve the query to its canonical name, then accept
    #    any label that mentions that canonical name.
    for standard_name, variations in COMMON_SYNONYMS.items():
        canonical = standard_name.lower()
        if needle == canonical or needle in variations:
            for label, low in lowered:
                if canonical in low:
                    return label

    # 3. Fuzzy match. Map each lower-cased label back to the first original
    #    label that produced it, so the result keeps its original spelling.
    first_by_lower = {}
    for label, low in lowered:
        first_by_lower.setdefault(low, label)
    matches = get_close_matches(needle, list(first_by_lower), n=1, cutoff=0.6)
    return first_by_lower[matches[0]] if matches else None
def parse_json(json_output: str) -> str:
    """Extract a JSON document from a Gemini model response.

    The model usually wraps its answer in a Markdown code fence. This
    function first looks for a ```` ```json ```` fence (or, failing that, any
    ```` ``` ```` fence). It takes the text up to the closing fence, or to the
    end of the response if the fence is unterminated, and trims it to start
    at whichever of ``[`` / ``{`` appears first. If there is no fence, or
    that snippet is not valid JSON, the whole response is tried instead.

    Args:
        json_output: Raw response text; may be ``None`` if the model
            returned no text.

    Returns:
        A string that ``json.loads`` accepts. On failure the error and the
        raw output are printed and ``"[]"`` (an empty JSON array) is
        returned.
    """
    if json_output is None:
        print("Error parsing JSON: response contained no text")
        print(f"Raw output: {json_output}")
        return "[]"

    candidates = []

    fence_start = json_output.find("```json")
    fence_len = len("```json")
    if fence_start == -1:
        fence_start = json_output.find("```")
        fence_len = len("```")
    if fence_start != -1:
        body_start = fence_start + fence_len
        body_end = json_output.find("```", body_start)
        if body_end == -1:
            # Unterminated fence (e.g. truncated response): use the rest.
            body_end = len(json_output)
        body = json_output[body_start:body_end]
        # Start at the earliest bracket so both arrays and objects survive,
        # and any language tag after a bare fence is skipped.
        bracket_positions = [p for p in (body.find("["), body.find("{")) if p != -1]
        if bracket_positions:
            candidates.append(body[min(bracket_positions):].strip())

    # Fall back to treating the entire response as JSON.
    candidates.append(json_output)

    error = None
    for candidate in candidates:
        try:
            json.loads(candidate)
        except json.JSONDecodeError as exc:
            error = exc
        else:
            return candidate

    print(f"Error parsing JSON: {error}")
    print(f"Raw output: {json_output}")
    return "[]"  # Return empty array as fallback
_SEGMENTATION_PROMPT = """
Analyze this image and provide segmentation masks for all visible objects. Ignore any text visible in the image.
Return the results as a JSON array where each object has:
- "label": descriptive name of the object
- "box_2d": bounding box coordinates [y0, x0, y1, x1] normalized to 1000x1000
- "mask": base64 PNG image of the segmentation mask
Format the response as a valid JSON array only.
"""

_PNG_DATA_URI_PREFIX = "data:image/png;base64,"


def _box_to_pixels(box, width, height):
    """Map a ``[y0, x0, y1, x1]`` box on a 0-1000 grid to ``(x0, y0, x1, y1)`` pixels.

    Coordinates are clamped to the image. Returns ``None`` when the box is
    malformed or has zero/negative area.
    """
    try:
        ny0, nx0, ny1, nx1 = (float(v) for v in box)
    except (TypeError, ValueError):
        return None

    def scale(value, size):
        return min(max(int(value / 1000 * size), 0), size)

    y0, y1 = scale(ny0, height), scale(ny1, height)
    x0, x1 = scale(nx0, width), scale(nx1, width)
    if y0 >= y1 or x0 >= x1:
        return None
    return x0, y0, x1, y1


def _decode_mask(png_str):
    """Decode a ``data:image/png;base64,...`` string into a grayscale ('L') image.

    Returns ``None`` if the value is not a PNG data URI or cannot be decoded.
    """
    if not isinstance(png_str, str) or not png_str.startswith(_PNG_DATA_URI_PREFIX):
        return None
    try:
        raw = base64.b64decode(png_str.removeprefix(_PNG_DATA_URI_PREFIX))
        with Image.open(io.BytesIO(raw)) as img:
            return img.convert("L")
    except (ValueError, OSError):
        # binascii.Error is a ValueError; PIL.UnidentifiedImageError is an OSError.
        return None


def _safe_filename_part(label):
    """Replace characters that are path separators or invalid in filenames."""
    return re.sub(r'[\\/:*?"<>|\x00-\x1f]', "_", label)


def extract_segmentation_masks(image_path: str, output_dir: str = "segmentation_outputs"):
    """Segment the visible objects in an image with Gemini and save one mask per object.

    A copy of the image, downscaled to fit within 1024x1024, is sent to
    ``gemini-2.5-flash``. The prompt asks for a JSON array of
    ``{"label", "box_2d", "mask"}`` objects. Each object needs a valid box
    and a mask given as a ``data:image/png;base64,`` string. Its mask is
    resized to the box (in original-image pixels), pasted onto a black
    full-resolution grayscale canvas, and written to
    ``<output_dir>/<label>_<index>_mask.png``. Characters that are unsafe
    in filenames are replaced with ``_`` in ``<label>``.

    Args:
        image_path: Path of the image to segment.
        output_dir: Directory for the mask PNGs; created if missing.

    Returns:
        One ``{"label", "box_2d", "mask"}`` dict per object reported by the
        model, including objects whose mask could not be saved. Returns
        ``[]`` if any step fails; the error and traceback are printed.
    """
    try:
        # Record the original size, then work on an in-memory copy so the
        # file handle can be closed immediately.
        with Image.open(image_path) as original_image:
            original_width, original_height = original_image.size
            im = original_image.copy()
        im.thumbnail([1024, 1024], Image.Resampling.LANCZOS)
        print(f"Image size after thumbnail: {im.size}")

        config = types.GenerateContentConfig(
            thinking_config=types.ThinkingConfig(thinking_budget=0)
        )
        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents=[_SEGMENTATION_PROMPT, im],
            config=config,
        )

        json_str = parse_json(response.text)
        print(f"Parsed JSON string length: {len(json_str)}")
        items = json.loads(json_str)
        if not isinstance(items, list):
            items = [items]
        print(f"Found {len(items)} objects")

        os.makedirs(output_dir, exist_ok=True)

        results = []
        for i, item in enumerate(items):
            if not isinstance(item, dict):
                continue
            label = str(item.get("label", "object"))
            results.append({
                'label': label,
                'box_2d': item.get("box_2d"),
                'mask': item.get("mask"),
            })

            # Boxes are normalized to 0-1000; map them onto the ORIGINAL
            # image (not the thumbnail) so masks align at full resolution.
            pixel_box = _box_to_pixels(item.get("box_2d"), original_width, original_height)
            if pixel_box is None:
                continue
            x0, y0, x1, y1 = pixel_box

            mask = _decode_mask(item.get("mask"))
            if mask is None:
                continue
            # The mask covers only the box region; stretch it to the box size.
            mask = mask.resize((x1 - x0, y1 - y0), Image.Resampling.BILINEAR)

            full_mask = Image.new('L', (original_width, original_height), 0)
            full_mask.paste(mask, (x0, y0))
            mask_path = os.path.join(output_dir, f"{_safe_filename_part(label)}_{i}_mask.png")
            full_mask.save(mask_path)
            print(f"Saved mask for {label} at coordinates ({x0}, {y0}, {x1}, {y1})")

        return results
    except Exception as e:
        # Top-level boundary for this script: report and return no results.
        print(f"Error in extract_segmentation_masks: {str(e)}")
        import traceback
        traceback.print_exc()
        return []
# Example usage: pass the image path as the first command-line argument.
if __name__ == "__main__":
    import sys

    # Fall back to the placeholder path when no argument is supplied.
    image_arg = sys.argv[1] if len(sys.argv) > 1 else "path/to/image.png"
    extract_segmentation_masks(image_arg)