Skip to content

Commit 2fca093

Browse files
committed
refactoring
1 parent 982138c commit 2fca093

File tree

4 files changed

+943
-625
lines changed

4 files changed

+943
-625
lines changed

javascript/sam.js

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,27 @@ function switchToInpaintUpload() {
2121
return arguments;
2222
}
2323

24+
// True when the user enabled the extension's dedicated UI tab in settings.
function samSpecialTabForUI() {
    const key = 'sam_special_tab_for_ui';
    return key in opts && opts[key];
}
27+
2428
// Elem-id prefix for the currently active WebUI tab:
// "txt2img_sam_", "img2img_sam_", or the "_sam_" fallback.
// When the dedicated-tab option is on, the img2img components are used.
function samTabPrefix() {
    if (samSpecialTabForUI()) {
        return "img2img_sam_";
    }
    const tabRoot = gradioApp().querySelector('#tabs');
    if (tabRoot) {
        const tabButtons = tabRoot.querySelectorAll('button');
        if (tabButtons) {
            if (tabButtons[0].className.includes("selected")) {
                return "txt2img_sam_";
            }
            if (tabButtons[1].className.includes("selected")) {
                return "img2img_sam_";
            }
        }
    }
    return "_sam_";
}
3846

3947
function samImmediatelyGenerate() {
@@ -187,4 +195,23 @@ onUiUpdate(() => {
187195
samPrevImg[samTabPrefix()] = null;
188196
}
189197
}
190-
})
198+
});
199+
200+
201+
// Resolves with window.opts once the WebUI has populated it,
// polling every 100 ms until then.
async function samWaitForOpts() {
    while (!(window.opts && Object.keys(window.opts).length)) {
        await new Promise(resolve => setTimeout(resolve, 100));
    }
    return window.opts;
}
209+
210+
onUiLoaded(async () => {
    // Wait for settings so samSpecialTabForUI() can read opts reliably.
    await samWaitForOpts();
    if (samSpecialTabForUI()) {
        // Move the img2img SAM accordion into the extension's dedicated tab.
        const accordion = gradioApp().getElementById('segment_anything_accordion_img2img');
        const tab = gradioApp().getElementById('tab_segment_anything');
        // Bug fix: the original called tab.appendChild(accordion) unguarded,
        // throwing a TypeError if either element is absent from the DOM.
        if (accordion && tab) {
            tab.appendChild(accordion);
        }
    }
});

lib_segment_anything/api.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
import os
2+
from fastapi import FastAPI, Body
3+
from pydantic import BaseModel
4+
from typing import Any, Optional, List
5+
import gradio as gr
6+
from PIL import Image
7+
import numpy as np
8+
9+
from modules.api.api import encode_pil_to_base64, decode_base64_to_image
10+
from lib_segment_anything.sam import (sam_predict, dino_predict, update_mask,
11+
cnet_seg, categorical_mask, sam_model_list)
12+
13+
14+
def decode_to_pil(image):
    """Convert *image* into a ``PIL.Image.Image``.

    Accepts a PIL image (returned as-is), a numpy array, a file path,
    or a base64-encoded string.

    Raises:
        TypeError: if *image* is none of the supported types.
    """
    # Bug fix: the original tested os.path.exists(image) FIRST, which
    # raises TypeError for PIL.Image / np.ndarray inputs, making those
    # branches unreachable — so type checks now come before the path probe.
    if isinstance(image, Image.Image):
        return image
    if isinstance(image, np.ndarray):
        return Image.fromarray(image)
    if isinstance(image, str):
        if os.path.exists(image):
            return Image.open(image)
        # Not an existing path — assume a base64 payload.
        return decode_base64_to_image(image)
    # Bug fix: the original built Exception("Not an image") but never
    # raised it, silently returning None for unsupported inputs.
    raise TypeError("Not an image")
25+
26+
27+
def encode_to_base64(image):
    """Encode *image* as a base64 string.

    A str is assumed to already be base64 and is passed through; a numpy
    array is first converted to a PIL image; a PIL image is encoded via
    the WebUI helper.

    Raises:
        TypeError: if *image* is none of the supported types.
    """
    if isinstance(image, str):
        return image
    if isinstance(image, np.ndarray):
        # Normalize arrays to PIL so both branches share one encode path.
        image = Image.fromarray(image)
    if isinstance(image, Image.Image):
        return encode_pil_to_base64(image).decode()
    # Bug fix: the original built Exception("Invalid type") but never
    # raised it, silently returning None for unsupported inputs.
    raise TypeError("Invalid type")
37+
38+
39+
def sam_api(_: gr.Blocks, app: FastAPI):
    """Mount the Segment Anything extension's REST endpoints on *app*.

    Args:
        _: gradio Blocks instance (unused; required by the WebUI
           app-started callback signature — TODO confirm against caller).
        app: the WebUI's FastAPI application to register routes on.
    """
    @app.get("/sam/heartbeat")
    async def heartbeat():
        # Liveness probe: confirms the extension API is mounted.
        return {
            "msg": "Success!"
        }

    @app.get("/sam/sam-model", description='Query available SAM model')
    async def api_sam_model() -> List[str]:
        # Checkpoint names exposed by lib_segment_anything.sam.
        return sam_model_list

    # Request schema for /sam/sam-predict.
    class SamPredictRequest(BaseModel):
        sam_model_name: str = "sam_vit_h_4b8939.pth"
        input_image: str
        sam_positive_points: List[List[float]] = []
        sam_negative_points: List[List[float]] = []
        dino_enabled: bool = False
        dino_model_name: Optional[str] = "GroundingDINO_SwinT_OGC (694MB)"
        dino_text_prompt: Optional[str] = None
        dino_box_threshold: Optional[float] = 0.3
        dino_preview_checkbox: bool = False
        dino_preview_boxes_selection: Optional[List[int]] = None

    @app.post("/sam/sam-predict")
    async def api_sam_predict(payload: SamPredictRequest = Body(...)) -> Any:
        print(f"SAM API /sam/sam-predict received request")
        # input_image arrives as str (base64 or path); replaced in-place
        # with the decoded RGBA PIL image before prediction.
        payload.input_image = decode_to_pil(payload.input_image).convert('RGBA')
        sam_output_mask_gallery, sam_message = sam_predict(
            payload.sam_model_name,
            payload.input_image,
            payload.sam_positive_points,
            payload.sam_negative_points,
            payload.dino_enabled,
            payload.dino_model_name,
            payload.dino_text_prompt,
            payload.dino_box_threshold,
            payload.dino_preview_checkbox,
            payload.dino_preview_boxes_selection)
        print(f"SAM API /sam/sam-predict finished with message: {sam_message}")
        result = {
            "msg": sam_message,
        }
        # A 9-image gallery is split into 3 blends, 3 masks, 3 masked
        # images; any other length means only the message is returned.
        if len(sam_output_mask_gallery) == 9:
            result["blended_images"] = list(map(encode_to_base64, sam_output_mask_gallery[:3]))
            result["masks"] = list(map(encode_to_base64, sam_output_mask_gallery[3:6]))
            result["masked_images"] = list(map(encode_to_base64, sam_output_mask_gallery[6:]))
        return result

    # Request schema for /sam/dino-predict.
    class DINOPredictRequest(BaseModel):
        input_image: str
        dino_model_name: str = "GroundingDINO_SwinT_OGC (694MB)"
        text_prompt: str
        box_threshold: float = 0.3

    @app.post("/sam/dino-predict")
    async def api_dino_predict(payload: DINOPredictRequest = Body(...)) -> Any:
        print(f"SAM API /sam/dino-predict received request")
        payload.input_image = decode_to_pil(payload.input_image)
        dino_output_img, _, dino_msg = dino_predict(
            payload.input_image,
            payload.dino_model_name,
            payload.text_prompt,
            payload.box_threshold)
        # dino_msg may be a dict with a "value" key (presumably a gradio
        # update payload — verify against dino_predict); unwrap it.
        if "value" in dino_msg:
            dino_msg = dino_msg["value"]
        else:
            dino_msg = "Done"
        print(f"SAM API /sam/dino-predict finished with message: {dino_msg}")
        return {
            "msg": dino_msg,
            "image_with_box": encode_to_base64(dino_output_img) if dino_output_img is not None else None,
        }

    # Request schema for /sam/dilate-mask.
    class DilateMaskRequest(BaseModel):
        input_image: str
        mask: str
        dilate_amount: int = 10

    @app.post("/sam/dilate-mask")
    async def api_dilate_mask(payload: DilateMaskRequest = Body(...)) -> Any:
        print(f"SAM API /sam/dilate-mask received request")
        payload.input_image = decode_to_pil(payload.input_image).convert("RGBA")
        payload.mask = decode_to_pil(payload.mask)
        # update_mask(mask, 0, amount, image) returns (blend, mask, masked).
        dilate_result = list(map(encode_to_base64, update_mask(payload.mask, 0, payload.dilate_amount, payload.input_image)))
        print(f"SAM API /sam/dilate-mask finished")
        return {"blended_image": dilate_result[0], "mask": dilate_result[1], "masked_image": dilate_result[2]}

    # Tunables forwarded verbatim to the automatic mask generator.
    class AutoSAMConfig(BaseModel):
        points_per_side: Optional[int] = 32
        points_per_batch: int = 64
        pred_iou_thresh: float = 0.88
        stability_score_thresh: float = 0.95
        stability_score_offset: float = 1.0
        box_nms_thresh: float = 0.7
        crop_n_layers: int = 0
        crop_nms_thresh: float = 0.7
        crop_overlap_ratio: float = 512 / 1500
        crop_n_points_downscale_factor: int = 1
        min_mask_region_area: int = 0

    # Request schema for /sam/controlnet-seg.
    class ControlNetSegRequest(BaseModel):
        sam_model_name: str = "sam_vit_h_4b8939.pth"
        input_image: str
        processor: str = "seg_ofade20k"
        processor_res: int = 512
        pixel_perfect: bool = False
        resize_mode: Optional[int] = 1  # 0: just resize, 1: crop and resize, 2: resize and fill
        target_W: Optional[int] = None
        target_H: Optional[int] = None

    @app.post("/sam/controlnet-seg")
    async def api_controlnet_seg(payload: ControlNetSegRequest = Body(...),
                                 autosam_conf: AutoSAMConfig = Body(...)) -> Any:
        print(f"SAM API /sam/controlnet-seg received request")
        payload.input_image = decode_to_pil(payload.input_image)
        cnet_seg_img, cnet_seg_msg = cnet_seg(
            payload.sam_model_name,
            payload.input_image,
            payload.processor,
            payload.processor_res,
            payload.pixel_perfect,
            payload.resize_mode,
            payload.target_W,
            payload.target_H,
            autosam_conf.points_per_side,
            autosam_conf.points_per_batch,
            autosam_conf.pred_iou_thresh,
            autosam_conf.stability_score_thresh,
            autosam_conf.stability_score_offset,
            autosam_conf.box_nms_thresh,
            autosam_conf.crop_n_layers,
            autosam_conf.crop_nms_thresh,
            autosam_conf.crop_overlap_ratio,
            autosam_conf.crop_n_points_downscale_factor,
            autosam_conf.min_mask_region_area)
        cnet_seg_img = list(map(encode_to_base64, cnet_seg_img))
        print(f"SAM API /sam/controlnet-seg finished with message {cnet_seg_msg}")
        result = {
            "msg": cnet_seg_msg,
        }
        # cnet_seg returns either 3 or 4 images; the count selects which
        # set of response keys is populated (semantics of each image come
        # from lib_segment_anything.sam — confirm there).
        if len(cnet_seg_img) == 3:
            result["blended_images"] = cnet_seg_img[0]
            result["random_seg"] = cnet_seg_img[1]
            result["edit_anything_control"] = cnet_seg_img[2]
        elif len(cnet_seg_img) == 4:
            result["sem_presam"] = cnet_seg_img[0]
            result["sem_postsam"] = cnet_seg_img[1]
            result["blended_presam"] = cnet_seg_img[2]
            result["blended_postsam"] = cnet_seg_img[3]
        return result

    # Request schema for /sam/category-mask.
    class CategoryMaskRequest(BaseModel):
        sam_model_name: str = "sam_vit_h_4b8939.pth"
        processor: str = "seg_ofade20k"
        processor_res: int = 512
        pixel_perfect: bool = False
        resize_mode: Optional[int] = 1
        target_W: Optional[int] = None
        target_H: Optional[int] = None
        category: str
        input_image: str

    @app.post("/sam/category-mask")
    async def api_category_mask(payload: CategoryMaskRequest = Body(...),
                                autosam_conf: AutoSAMConfig = Body(...)) -> Any:
        print(f"SAM API /sam/category-mask received request")
        payload.input_image = decode_to_pil(payload.input_image)
        category_mask_img, category_mask_msg, resized_input_img = categorical_mask(
            payload.sam_model_name,
            payload.processor,
            payload.processor_res,
            payload.pixel_perfect,
            payload.resize_mode,
            payload.target_W,
            payload.target_H,
            payload.category,
            payload.input_image,
            autosam_conf.points_per_side,
            autosam_conf.points_per_batch,
            autosam_conf.pred_iou_thresh,
            autosam_conf.stability_score_thresh,
            autosam_conf.stability_score_offset,
            autosam_conf.box_nms_thresh,
            autosam_conf.crop_n_layers,
            autosam_conf.crop_nms_thresh,
            autosam_conf.crop_overlap_ratio,
            autosam_conf.crop_n_points_downscale_factor,
            autosam_conf.min_mask_region_area)
        category_mask_img = list(map(encode_to_base64, category_mask_img))
        print(f"SAM API /sam/category-mask finished with message {category_mask_msg}")
        result = {
            "msg": category_mask_msg,
        }
        # 3-image success case: blend, mask, masked image.
        if len(category_mask_img) == 3:
            result["blended_image"] = category_mask_img[0]
            result["mask"] = category_mask_img[1]
            result["masked_image"] = category_mask_img[2]
        if resized_input_img is not None:
            result["resized_input"] = encode_to_base64(resized_input_img)
        return result
240+

0 commit comments

Comments
 (0)