Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 37 additions & 10 deletions javascript/sam.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,27 @@ function switchToInpaintUpload() {
return arguments;
}

function samSpecialTabForUI() {
return 'sam_special_tab_for_ui' in opts && opts.sam_special_tab_for_ui;
}

function samTabPrefix() {
const tabs = gradioApp().querySelector('#tabs');
if (tabs) {
const buttons = tabs.querySelectorAll('button');
if (buttons) {
if (buttons[0].className.includes("selected")) {
return "txt2img_sam_"
} else if (buttons[1].className.includes("selected")) {
return "img2img_sam_"
if (samSpecialTabForUI()) {
return "img2img_sam_";
} else {
const tabs = gradioApp().querySelector('#tabs');
if (tabs) {
const buttons = tabs.querySelectorAll('button');
if (buttons) {
if (buttons[0].className.includes("selected")) {
return "txt2img_sam_";
} else if (buttons[1].className.includes("selected")) {
return "img2img_sam_";
}
}
}
return "_sam_";
}
return "_sam_"
}

function samImmediatelyGenerate() {
Expand Down Expand Up @@ -187,4 +195,23 @@ onUiUpdate(() => {
samPrevImg[samTabPrefix()] = null;
}
}
})
});


async function samWaitForOpts() {
for (; ;) {
if (window.opts && Object.keys(window.opts).length) {
return window.opts;
}
await new Promise(resolve => setTimeout(resolve, 100));
}
}

onUiLoaded(async () => {
const opts = await samWaitForOpts();
if (samSpecialTabForUI()) {
let accordion = gradioApp().getElementById('segment_anything_accordion_img2img');
let tab = gradioApp().getElementById('tab_segment_anything');
tab.appendChild(accordion)
}
});
10 changes: 2 additions & 8 deletions scripts/api.py → lib_segment_anything/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import numpy as np

from modules.api.api import encode_pil_to_base64, decode_base64_to_image
from scripts.sam import sam_predict, dino_predict, update_mask, cnet_seg, categorical_mask
from scripts.sam import sam_model_list
from lib_segment_anything.sam import (sam_predict, dino_predict, update_mask,
cnet_seg, categorical_mask, sam_model_list)


def decode_to_pil(image):
Expand Down Expand Up @@ -238,9 +238,3 @@ async def api_category_mask(payload: CategoryMaskRequest = Body(...),
result["resized_input"] = encode_to_base64(resized_input_img)
return result


try:
import modules.script_callbacks as script_callbacks
script_callbacks.on_app_started(sam_api)
except:
print("SAM Web UI API failed to initialize")
Loading