33
44import torch
55import torchvision .transforms .functional as FF
6- from transformers import CLIPImageProcessor , CLIPTextModel , CLIPTokenizer
6+ from transformers import CLIPImageProcessor , CLIPTextModel , CLIPTokenizer , CLIPVisionModelWithProjection
77
88from diffusers import StableDiffusionPipeline
99from diffusers .models import AutoencoderKL , UNet2DConditionModel
1010from diffusers .pipelines .stable_diffusion .safety_checker import StableDiffusionSafetyChecker
1111from diffusers .schedulers import KarrasDiffusionSchedulers
12- from diffusers .utils import USE_PEFT_BACKEND
1312
1413
1514try :
1615 from compel import Compel
1716except ImportError :
1817 Compel = None
1918
19+ KBASE = "ADDBASE"
2020KCOMM = "ADDCOMM"
2121KBRK = "BREAK"
2222
@@ -34,6 +34,11 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
3434
3535 Optional
3636 rp_args["save_mask"]: True/False (save masks in prompt mode)
37+ rp_args["power"]: int (power for attention maps in prompt mode)
38+ rp_args["base_ratio"]:
39+ float (ratio of the base prompt blended into each regional prompt; defaults to 0.0 and is only applied when the prompt contains ADDBASE)
40+ e.g. 0.2 (20%*BASE_PROMPT + 80%*REGION_PROMPT)
41+ [Use base prompt](https://github.com/hako-mikan/sd-webui-regional-prompter?tab=readme-ov-file#use-base-prompt)
3742
3843 Pipeline for text-to-image generation using Stable Diffusion.
3944
@@ -70,6 +75,7 @@ def __init__(
7075 scheduler : KarrasDiffusionSchedulers ,
7176 safety_checker : StableDiffusionSafetyChecker ,
7277 feature_extractor : CLIPImageProcessor ,
78+ image_encoder : CLIPVisionModelWithProjection = None ,
7379 requires_safety_checker : bool = True ,
7480 ):
7581 super ().__init__ (
@@ -80,6 +86,7 @@ def __init__(
8086 scheduler ,
8187 safety_checker ,
8288 feature_extractor ,
89+ image_encoder ,
8390 requires_safety_checker ,
8491 )
8592 self .register_modules (
@@ -90,6 +97,7 @@ def __init__(
9097 scheduler = scheduler ,
9198 safety_checker = safety_checker ,
9299 feature_extractor = feature_extractor ,
100+ image_encoder = image_encoder ,
93101 )
94102
95103 @torch .no_grad ()
@@ -110,17 +118,40 @@ def __call__(
110118 rp_args : Dict [str , str ] = None ,
111119 ):
112120 active = KBRK in prompt [0 ] if isinstance (prompt , list ) else KBRK in prompt
121+ use_base = KBASE in prompt [0 ] if isinstance (prompt , list ) else KBASE in prompt
113122 if negative_prompt is None :
114123 negative_prompt = "" if isinstance (prompt , str ) else ["" ] * len (prompt )
115124
116125 device = self ._execution_device
117126 regions = 0
118127
128+ self .base_ratio = float (rp_args ["base_ratio" ]) if "base_ratio" in rp_args else 0.0
119129 self .power = int (rp_args ["power" ]) if "power" in rp_args else 1
120130
121131 prompts = prompt if isinstance (prompt , list ) else [prompt ]
122- n_prompts = negative_prompt if isinstance (prompt , str ) else [negative_prompt ]
132+ n_prompts = negative_prompt if isinstance (prompt , list ) else [negative_prompt ]
123133 self .batch = batch = num_images_per_prompt * len (prompts )
134+
135+ if use_base :
136+ bases = prompts .copy ()
137+ n_bases = n_prompts .copy ()
138+
139+ for i , prompt in enumerate (prompts ):
140+ parts = prompt .split (KBASE )
141+ if len (parts ) == 2 :
142+ bases [i ], prompts [i ] = parts
143+ elif len (parts ) > 2 :
144+ raise ValueError (f"Multiple instances of { KBASE } found in prompt: { prompt } " )
145+ for i , prompt in enumerate (n_prompts ):
146+ n_parts = prompt .split (KBASE )
147+ if len (n_parts ) == 2 :
148+ n_bases [i ], n_prompts [i ] = n_parts
149+ elif len (n_parts ) > 2 :
150+ raise ValueError (f"Multiple instances of { KBASE } found in negative prompt: { prompt } " )
151+
152+ all_bases_cn , _ = promptsmaker (bases , num_images_per_prompt )
153+ all_n_bases_cn , _ = promptsmaker (n_bases , num_images_per_prompt )
154+
124155 all_prompts_cn , all_prompts_p = promptsmaker (prompts , num_images_per_prompt )
125156 all_n_prompts_cn , _ = promptsmaker (n_prompts , num_images_per_prompt )
126157
@@ -137,8 +168,16 @@ def getcompelembs(prps):
137168
138169 conds = getcompelembs (all_prompts_cn )
139170 unconds = getcompelembs (all_n_prompts_cn )
140- embs = getcompelembs (prompts )
141- n_embs = getcompelembs (n_prompts )
171+ base_embs = getcompelembs (all_bases_cn ) if use_base else None
172+ base_n_embs = getcompelembs (all_n_bases_cn ) if use_base else None
173+ # When using base, it seems more reasonable to use base prompts as prompt_embeddings rather than regional prompts
174+ embs = getcompelembs (prompts ) if not use_base else base_embs
175+ n_embs = getcompelembs (n_prompts ) if not use_base else base_n_embs
176+
177+ if use_base and self .base_ratio > 0 :
178+ conds = self .base_ratio * base_embs + (1 - self .base_ratio ) * conds
179+ unconds = self .base_ratio * base_n_embs + (1 - self .base_ratio ) * unconds
180+
142181 prompt = negative_prompt = None
143182 else :
144183 conds = self .encode_prompt (prompts , device , 1 , True )[0 ]
@@ -147,6 +186,18 @@ def getcompelembs(prps):
147186 if equal
148187 else self .encode_prompt (all_n_prompts_cn , device , 1 , True )[0 ]
149188 )
189+
190+ if use_base and self .base_ratio > 0 :
191+ base_embs = self .encode_prompt (bases , device , 1 , True )[0 ]
192+ base_n_embs = (
193+ self .encode_prompt (n_bases , device , 1 , True )[0 ]
194+ if equal
195+ else self .encode_prompt (all_n_bases_cn , device , 1 , True )[0 ]
196+ )
197+
198+ conds = self .base_ratio * base_embs + (1 - self .base_ratio ) * conds
199+ unconds = self .base_ratio * base_n_embs + (1 - self .base_ratio ) * unconds
200+
150201 embs = n_embs = None
151202
152203 if not active :
@@ -225,8 +276,6 @@ def forward(
225276
226277 residual = hidden_states
227278
228- args = () if USE_PEFT_BACKEND else (scale ,)
229-
230279 if attn .spatial_norm is not None :
231280 hidden_states = attn .spatial_norm (hidden_states , temb )
232281
@@ -247,16 +296,15 @@ def forward(
247296 if attn .group_norm is not None :
248297 hidden_states = attn .group_norm (hidden_states .transpose (1 , 2 )).transpose (1 , 2 )
249298
250- args = () if USE_PEFT_BACKEND else (scale ,)
251- query = attn .to_q (hidden_states , * args )
299+ query = attn .to_q (hidden_states )
252300
253301 if encoder_hidden_states is None :
254302 encoder_hidden_states = hidden_states
255303 elif attn .norm_cross :
256304 encoder_hidden_states = attn .norm_encoder_hidden_states (encoder_hidden_states )
257305
258- key = attn .to_k (encoder_hidden_states , * args )
259- value = attn .to_v (encoder_hidden_states , * args )
306+ key = attn .to_k (encoder_hidden_states )
307+ value = attn .to_v (encoder_hidden_states )
260308
261309 inner_dim = key .shape [- 1 ]
262310 head_dim = inner_dim // attn .heads
@@ -283,7 +331,7 @@ def forward(
283331 hidden_states = hidden_states .to (query .dtype )
284332
285333 # linear proj
286- hidden_states = attn .to_out [0 ](hidden_states , * args )
334+ hidden_states = attn .to_out [0 ](hidden_states )
287335 # dropout
288336 hidden_states = attn .to_out [1 ](hidden_states )
289337
@@ -410,9 +458,9 @@ def promptsmaker(prompts, batch):
410458 add = ""
411459 if KCOMM in prompt :
412460 add , prompt = prompt .split (KCOMM )
413- add = add + " "
414- prompts = prompt .split (KBRK )
415- out_p .append ([add + p for p in prompts ])
461+ add = add . strip () + " "
462+ prompts = [ p . strip () for p in prompt .split (KBRK )]
463+ out_p .append ([add + p for i , p in enumerate ( prompts ) ])
416464 out = [None ] * batch * len (out_p [0 ]) * len (out_p )
417465 for p , prs in enumerate (out_p ): # inputs prompts
418466 for r , pr in enumerate (prs ): # prompts for regions
@@ -449,7 +497,6 @@ def startend(cells, array):
449497 add = []
450498 startend (add , inratios [1 :])
451499 icells .append (add )
452-
453500 return ocells , icells , sum (len (cell ) for cell in icells )
454501
455502
0 commit comments