
Commit f97adaf

Merge pull request #3292 from lllyasviel/develop
Release v2.5.0
2 parents 5a71495 + 97a8475

39 files changed, +2754 −927 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@ __pycache__
 *.partial
 *.onnx
 sorted_styles.json
+hash_cache.txt
 /input
 /cache
 /language/default.json

args_manager.py

Lines changed: 6 additions & 3 deletions

@@ -28,11 +28,14 @@
 args_parser.parser.add_argument("--disable-preset-download", action='store_true',
                                 help="Disables downloading models for presets", default=False)

-args_parser.parser.add_argument("--enable-describe-uov-image", action='store_true',
-                                help="Disables automatic description of uov images when prompt is empty", default=False)
+args_parser.parser.add_argument("--enable-auto-describe-image", action='store_true',
+                                help="Enables automatic description of uov and enhance image when prompt is empty", default=False)

 args_parser.parser.add_argument("--always-download-new-model", action='store_true',
-                                help="Always download newer models ", default=False)
+                                help="Always download newer models", default=False)
+
+args_parser.parser.add_argument("--rebuild-hash-cache", help="Generates missing model and LoRA hashes.",
+                                type=int, nargs="?", metavar="CPU_NUM_THREADS", const=-1)

 args_parser.parser.set_defaults(
     disable_cuda_malloc=True,
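
The new --rebuild-hash-cache flag combines nargs="?" with const=-1, so it works both as a bare switch and with an explicit thread count. A minimal standalone sketch of that argparse behavior (demo parser only, not the Fooocus launcher):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--rebuild-hash-cache", type=int, nargs="?",
                    metavar="CPU_NUM_THREADS", const=-1)

print(parser.parse_args([]))                             # Namespace(rebuild_hash_cache=None)
print(parser.parse_args(["--rebuild-hash-cache"]))       # Namespace(rebuild_hash_cache=-1)
print(parser.parse_args(["--rebuild-hash-cache", "8"]))  # Namespace(rebuild_hash_cache=8)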

css/style.css

Lines changed: 1 addition & 1 deletion

@@ -99,7 +99,7 @@ div:has(> #positive_prompt) {
 }

 .advanced_check_row {
-    width: 250px !important;
+    width: 330px !important;
 }

 .min_check {

experiments_mask_generation.py

Lines changed: 24 additions & 0 deletions

@@ -0,0 +1,24 @@
+# https://github.com/sail-sg/EditAnything/blob/main/sam2groundingdino_edit.py
+
+import numpy as np
+from PIL import Image
+
+from extras.inpaint_mask import SAMOptions, generate_mask_from_image
+
+original_image = Image.open('cat.webp')
+image = np.array(original_image, dtype=np.uint8)
+
+sam_options = SAMOptions(
+    dino_prompt='eye',
+    dino_box_threshold=0.3,
+    dino_text_threshold=0.25,
+    dino_erode_or_dilate=0,
+    dino_debug=False,
+    max_detections=2,
+    model_type='vit_b'
+)
+
+mask_image, _, _, _ = generate_mask_from_image(image, sam_options=sam_options)
+
+merged_masks_img = Image.fromarray(mask_image)
+merged_masks_img.show()
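
The three discarded return values in this script are the detection counters that generate_mask_from_image (added in extras/inpaint_mask.py below) reports alongside the mask: the number of GroundingDINO boxes, the number of SAM masks, and the number of masks merged into the final output.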

extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+batch_size = 1
+modelname = "groundingdino"
+backbone = "swin_T_224_1k"
+position_embedding = "sine"
+pe_temperatureH = 20
+pe_temperatureW = 20
+return_interm_indices = [1, 2, 3]
+backbone_freeze_keywords = None
+enc_layers = 6
+dec_layers = 6
+pre_norm = False
+dim_feedforward = 2048
+hidden_dim = 256
+dropout = 0.0
+nheads = 8
+num_queries = 900
+query_dim = 4
+num_patterns = 0
+num_feature_levels = 4
+enc_n_points = 4
+dec_n_points = 4
+two_stage_type = "standard"
+two_stage_bbox_embed_share = False
+two_stage_class_embed_share = False
+transformer_activation = "relu"
+dec_pred_bbox_embed_share = True
+dn_box_noise_scale = 1.0
+dn_label_noise_ratio = 0.5
+dn_label_coef = 1.0
+dn_bbox_coef = 1.0
+embed_init_tgt = True
+dn_labelbook_size = 2000
+max_text_len = 256
+text_encoder_type = "bert-base-uncased"
+use_text_enhancer = True
+use_fusion_layer = True
+use_checkpoint = True
+use_transformer_ckpt = True
+use_text_cross_attention = True
+text_dropout = 0.0
+fusion_dropout = 0.0
+fusion_droppath = 0.1
+sub_sentence_present = True

extras/GroundingDINO/util/inference.py

Lines changed: 100 additions & 0 deletions

@@ -0,0 +1,100 @@
+from typing import Tuple, List
+
+import ldm_patched.modules.model_management as model_management
+from ldm_patched.modules.model_patcher import ModelPatcher
+from modules.config import path_inpaint
+from modules.model_loader import load_file_from_url
+
+import numpy as np
+import supervision as sv
+import torch
+from groundingdino.util.inference import Model
+from groundingdino.util.inference import load_model, preprocess_caption, get_phrases_from_posmap
+
+
+class GroundingDinoModel(Model):
+    def __init__(self):
+        self.config_file = 'extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py'
+        self.model = None
+        self.load_device = torch.device('cpu')
+        self.offload_device = torch.device('cpu')
+
+    @torch.no_grad()
+    @torch.inference_mode()
+    def predict_with_caption(
+            self,
+            image: np.ndarray,
+            caption: str,
+            box_threshold: float = 0.35,
+            text_threshold: float = 0.25
+    ) -> Tuple[sv.Detections, torch.Tensor, torch.Tensor, List[str]]:
+        if self.model is None:
+            filename = load_file_from_url(
+                url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
+                file_name='groundingdino_swint_ogc.pth',
+                model_dir=path_inpaint)
+            model = load_model(model_config_path=self.config_file, model_checkpoint_path=filename)
+
+            self.load_device = model_management.text_encoder_device()
+            self.offload_device = model_management.text_encoder_offload_device()
+
+            model.to(self.offload_device)
+
+            self.model = ModelPatcher(model, load_device=self.load_device, offload_device=self.offload_device)
+
+        model_management.load_model_gpu(self.model)
+
+        processed_image = GroundingDinoModel.preprocess_image(image_bgr=image).to(self.load_device)
+        boxes, logits, phrases = predict(
+            model=self.model,
+            image=processed_image,
+            caption=caption,
+            box_threshold=box_threshold,
+            text_threshold=text_threshold,
+            device=self.load_device)
+        source_h, source_w, _ = image.shape
+        detections = GroundingDinoModel.post_process_result(
+            source_h=source_h,
+            source_w=source_w,
+            boxes=boxes,
+            logits=logits)
+        return detections, boxes, logits, phrases
+
+
+def predict(
+        model,
+        image: torch.Tensor,
+        caption: str,
+        box_threshold: float,
+        text_threshold: float,
+        device: str = "cuda"
+) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
+    caption = preprocess_caption(caption=caption)
+
+    # override to use model wrapped by patcher
+    model = model.model.to(device)
+    image = image.to(device)
+
+    with torch.no_grad():
+        outputs = model(image[None], captions=[caption])
+
+    prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0]  # prediction_logits.shape = (nq, 256)
+    prediction_boxes = outputs["pred_boxes"].cpu()[0]  # prediction_boxes.shape = (nq, 4)
+
+    mask = prediction_logits.max(dim=1)[0] > box_threshold
+    logits = prediction_logits[mask]  # logits.shape = (n, 256)
+    boxes = prediction_boxes[mask]  # boxes.shape = (n, 4)
+
+    tokenizer = model.tokenizer
+    tokenized = tokenizer(caption)
+
+    phrases = [
+        get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
+        for logit
+        in logits
+    ]
+
+    return boxes, logits.max(dim=1)[0], phrases
+
+
+default_groundingdino = GroundingDinoModel().predict_with_caption
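
predict_with_caption downloads the checkpoint and wraps it in a ModelPatcher only on the first call; model_management can then shuttle the weights between the load and offload devices on every subsequent call. A minimal sketch of that lazy-load pattern in isolation (hypothetical stand-in class, not the Fooocus API):

class LazyWrapper:
    def __init__(self):
        self.model = None  # nothing loaded at construction time

    def predict(self, x):
        if self.model is None:
            # first call pays the one-time download/load cost
            self.model = self._load()
        return self.model(x)

    @staticmethod
    def _load():
        return lambda x: x * 2  # placeholder for load_model(...) + ModelPatcher(...)

wrapper = LazyWrapper()
print(wrapper.predict(21))  # 42; the model is loaded exactly once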

extras/censor.py

Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ def censor(self, images: list | np.ndarray) -> list | np.ndarray:
         model_management.load_model_gpu(self.safety_checker_model)

         single = False
-        if not isinstance(images, list) or isinstance(images, np.ndarray):
+        if not isinstance(images, (list, np.ndarray)):
             images = [images]
             single = True
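
This one-line change flips how np.ndarray inputs are treated: the old predicate was true for every ndarray (including a batch), so batch arrays were wrapped as a single image, while the new one only wraps inputs that are neither list nor ndarray. A small demo of the two conditions (hypothetical helper names, for illustration only):

import numpy as np

def wraps_old(images):
    return not isinstance(images, list) or isinstance(images, np.ndarray)

def wraps_new(images):
    return not isinstance(images, (list, np.ndarray))

batch = np.zeros((4, 64, 64, 3), dtype=np.uint8)  # a batch of 4 RGB images
print(wraps_old(batch))  # True  (old: whole batch wrapped as one "image")
print(wraps_new(batch))  # False (new: batch is iterated image by image)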

extras/inpaint_mask.py

Lines changed: 130 additions & 0 deletions

@@ -0,0 +1,130 @@
+import sys
+
+import modules.config
+import numpy as np
+import torch
+from extras.GroundingDINO.util.inference import default_groundingdino
+from extras.sam.predictor import SamPredictor
+from rembg import remove, new_session
+from segment_anything import sam_model_registry
+from segment_anything.utils.amg import remove_small_regions
+
+
+class SAMOptions:
+    def __init__(self,
+                 # GroundingDINO
+                 dino_prompt: str = '',
+                 dino_box_threshold=0.3,
+                 dino_text_threshold=0.25,
+                 dino_erode_or_dilate=0,
+                 dino_debug=False,
+
+                 # SAM
+                 max_detections=2,
+                 model_type='vit_b'
+                 ):
+        self.dino_prompt = dino_prompt
+        self.dino_box_threshold = dino_box_threshold
+        self.dino_text_threshold = dino_text_threshold
+        self.dino_erode_or_dilate = dino_erode_or_dilate
+        self.dino_debug = dino_debug
+        self.max_detections = max_detections
+        self.model_type = model_type
+
+
+def optimize_masks(masks: torch.Tensor) -> torch.Tensor:
+    """
+    removes small disconnected regions and holes
+    """
+    fine_masks = []
+    for mask in masks.to('cpu').numpy():  # masks: [num_masks, 1, h, w]
+        fine_masks.append(remove_small_regions(mask[0], 400, mode="holes")[0])
+    masks = np.stack(fine_masks, axis=0)[:, np.newaxis]
+    return torch.from_numpy(masks)
+
+
+def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=None,
+                             sam_options: SAMOptions | None = SAMOptions()) -> tuple[np.ndarray | None, int | None, int | None, int | None]:
+    dino_detection_count = 0
+    sam_detection_count = 0
+    sam_detection_on_mask_count = 0
+
+    if image is None:
+        return None, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
+
+    if extras is None:
+        extras = {}
+
+    if 'image' in image:
+        image = image['image']
+
+    if mask_model != 'sam' or sam_options is None:
+        result = remove(
+            image,
+            session=new_session(mask_model, **extras),
+            only_mask=True,
+            **extras
+        )
+
+        return result, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
+
+    detections, boxes, logits, phrases = default_groundingdino(
+        image=image,
+        caption=sam_options.dino_prompt,
+        box_threshold=sam_options.dino_box_threshold,
+        text_threshold=sam_options.dino_text_threshold
+    )
+
+    H, W = image.shape[0], image.shape[1]
+    boxes = boxes * torch.Tensor([W, H, W, H])
+    boxes[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2
+    boxes[:, 2:] = boxes[:, 2:] + boxes[:, :2]
+
+    sam_checkpoint = modules.config.download_sam_model(sam_options.model_type)
+    sam = sam_model_registry[sam_options.model_type](checkpoint=sam_checkpoint)
+
+    sam_predictor = SamPredictor(sam)
+    final_mask_tensor = torch.zeros((image.shape[0], image.shape[1]))
+    dino_detection_count = boxes.size(0)
+
+    if dino_detection_count > 0:
+        sam_predictor.set_image(image)
+
+        if sam_options.dino_erode_or_dilate != 0:
+            for index in range(boxes.size(0)):
+                assert boxes.size(1) == 4
+                boxes[index][0] -= sam_options.dino_erode_or_dilate
+                boxes[index][1] -= sam_options.dino_erode_or_dilate
+                boxes[index][2] += sam_options.dino_erode_or_dilate
+                boxes[index][3] += sam_options.dino_erode_or_dilate
+
+        if sam_options.dino_debug:
+            from PIL import ImageDraw, Image
+            debug_dino_image = Image.new("RGB", (image.shape[1], image.shape[0]), color="black")
+            draw = ImageDraw.Draw(debug_dino_image)
+            for box in boxes.numpy():
+                draw.rectangle(box.tolist(), fill="white")
+            return np.array(debug_dino_image), dino_detection_count, sam_detection_count, sam_detection_on_mask_count
+
+        transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
+        masks, _, _ = sam_predictor.predict_torch(
+            point_coords=None,
+            point_labels=None,
+            boxes=transformed_boxes,
+            multimask_output=False,
+        )
+
+        masks = optimize_masks(masks)
+        sam_detection_count = len(masks)
+        if sam_options.max_detections == 0:
+            sam_options.max_detections = sys.maxsize
+        sam_objects = min(len(logits), sam_options.max_detections)
+        for obj_ind in range(sam_objects):
+            mask_tensor = masks[obj_ind][0]
+            final_mask_tensor += mask_tensor
+            sam_detection_on_mask_count += 1
+
+    final_mask_tensor = (final_mask_tensor > 0).to('cpu').numpy()
+    mask_image = np.dstack((final_mask_tensor, final_mask_tensor, final_mask_tensor)) * 255
+    mask_image = np.array(mask_image, dtype=np.uint8)
+    return mask_image, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
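
The box arithmetic in generate_mask_from_image converts GroundingDINO's normalized center-format boxes (cx, cy, w, h) into pixel-space corners (x1, y1, x2, y2) before handing them to SAM. A worked standalone example of the same three lines with made-up numbers:

import torch

boxes = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # normalized (cx, cy, w, h)
H, W = 600, 800

boxes = boxes * torch.Tensor([W, H, W, H])      # scale to pixels: (400, 300, 160, 240)
boxes[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2  # top-left corner: (320, 180)
boxes[:, 2:] = boxes[:, 2:] + boxes[:, :2]      # bottom-right corner: (480, 420)

print(boxes)  # tensor([[320., 180., 480., 420.]])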
