
Commit 6af2097

Merge branch 'main' into flux-control-lora
2 parents: 908d151 + c96bfa5

22 files changed: +2,142 −229 lines

docs/source/en/api/pipelines/flux.md

Lines changed: 1 addition & 1 deletion
@@ -148,7 +148,7 @@ image.save("output.png")
 **Note:** `black-forest-labs/Flux.1-Depth-dev` is _not_ a ControlNet model. [`ControlNetModel`] models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. Depth Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by using channel-wise concatenation with input control condition and ensuring the transformer learns structure control by following the condition as closely as possible.

 ```python
-# !pip install git+https://github.com/asomoza/image_gen_aux.git
+# !pip install git+https://github.com/huggingface/image_gen_aux
 import torch
 from diffusers import FluxControlPipeline, FluxTransformer2DModel
 from diffusers.utils import load_image
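The diff context cuts the depth-control snippet off after the imports. As a rough sketch of how such an example typically continues, under stated assumptions: the `DepthPreprocessor` class from `image_gen_aux`, the `control_image` argument, the checkpoint id, and the input image URL below are assumptions for illustration, not part of this commit.

```python
# Hypothetical continuation of the truncated snippet above (a sketch, not the committed docs).
import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor  # assumption: depth preprocessor provided by image_gen_aux

pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Depth-dev", torch_dtype=torch.bfloat16  # assumed checkpoint id
).to("cuda")

# placeholder input image; any RGB image works
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")  # depth map fed as the control condition

image = pipe(
    prompt="A robot made of exotic candies",
    control_image=control_image,  # concatenated channel-wise with the latents, per the note above
    num_inference_steps=30,
    guidance_scale=10.0,
).images[0]
image.save("output.png")
```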

examples/community/README.md

Lines changed: 108 additions & 10 deletions
@@ -2619,16 +2619,17 @@ for obj in range(bs):

 ### Stable Diffusion XL Reference

-This pipeline uses the Reference. Refer to the [stable_diffusion_reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-reference).
+This pipeline uses the Reference. Refer to the [Stable Diffusion Reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-reference) section for more information.

 ```py
 import torch
-from PIL import Image
+# from diffusers import DiffusionPipeline
 from diffusers.utils import load_image
-from diffusers import DiffusionPipeline
 from diffusers.schedulers import UniPCMultistepScheduler

-input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+from .stable_diffusion_xl_reference import StableDiffusionXLReferencePipeline
+
+input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg")

 # pipe = DiffusionPipeline.from_pretrained(
 #     "stabilityai/stable-diffusion-xl-base-1.0",
@@ -2646,22 +2647,22 @@ pipe = StableDiffusionXLReferencePipeline.from_pretrained(
 pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

 result_img = pipe(ref_image=input_image,
-      prompt="1girl",
+      prompt="a dog",
       num_inference_steps=20,
       reference_attn=True,
       reference_adain=True).images[0]
 ```

 Reference Image

-![reference_image](https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png)
+![reference_image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg)

 Output Image

-`prompt: 1 girl`
+`prompt: a dog`

-`reference_attn=True, reference_adain=True, num_inference_steps=20`
-![Output_image](https://github.com/zideliu/diffusers/assets/34944964/743848da-a215-48f9-ae39-b5e2ae49fb13)
+`reference_attn=False, reference_adain=True, num_inference_steps=20`
+![Output_image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_adain_dog.png)

 Reference Image
 ![reference_image](https://github.com/huggingface/diffusers/assets/34944964/449bdab6-e744-4fb2-9620-d4068d9a741b)
@@ -2683,6 +2684,88 @@ Output Image
 `reference_attn=True, reference_adain=True, num_inference_steps=20`
 ![output_image](https://github.com/huggingface/diffusers/assets/34944964/9b2f1aca-886f-49c3-89ec-d2031c8e3670)

+### Stable Diffusion XL ControlNet Reference
+
+This pipeline uses Reference Control together with ControlNet. Refer to the [Stable Diffusion ControlNet Reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-controlnet-reference) and [Stable Diffusion XL Reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-xl-reference) sections for more information.
+
+```py
+from diffusers import ControlNetModel, AutoencoderKL
+from diffusers.schedulers import UniPCMultistepScheduler
+from diffusers.utils import load_image
+import numpy as np
+import torch
+
+import cv2
+from PIL import Image
+
+from .stable_diffusion_xl_controlnet_reference import StableDiffusionXLControlNetReferencePipeline
+
+# download an image
+canny_image = load_image(
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg"
+)
+
+ref_image = load_image(
+    "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
+)
+
+# initialize the models and pipeline
+controlnet_conditioning_scale = 0.5  # recommended for good generalization
+controlnet = ControlNetModel.from_pretrained(
+    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
+)
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+pipe = StableDiffusionXLControlNetReferencePipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
+).to("cuda:0")
+
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+# get canny image
+image = np.array(canny_image)
+image = cv2.Canny(image, 100, 200)
+image = image[:, :, None]
+image = np.concatenate([image, image, image], axis=2)
+canny_image = Image.fromarray(image)
+
+# generate image
+image = pipe(
+    prompt="a cat",
+    num_inference_steps=20,
+    controlnet_conditioning_scale=controlnet_conditioning_scale,
+    image=canny_image,
+    ref_image=ref_image,
+    reference_attn=False,
+    reference_adain=True,
+    style_fidelity=1.0,
+    generator=torch.Generator("cuda").manual_seed(42)
+).images[0]
+```
+
+Canny ControlNet Image
+
+![canny_image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg)
+
+Reference Image
+
+![ref_image](https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png)
+
+Output Image
+
+`prompt: a cat`
+
+`reference_attn=True, reference_adain=True, num_inference_steps=20, style_fidelity=1.0`
+
+![Output_image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_attn_adain_canny_cat.png)
+
+`reference_attn=False, reference_adain=True, num_inference_steps=20, style_fidelity=1.0`
+
+![Output_image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_adain_canny_cat.png)
+
+`reference_attn=True, reference_adain=False, num_inference_steps=20, style_fidelity=1.0`
+
+![Output_image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_attn_canny_cat.png)
+
 ### Stable diffusion fabric pipeline

 FABRIC approach applicable to a wide range of popular diffusion models, which exploits
@@ -3378,6 +3461,20 @@ best quality, 3persons in garden, a boy blue shirt BREAK
 best quality, 3persons in garden, an old man red suit
 ```

+### Use base prompt
+
+You can use a base prompt that is applied to all regions. Set the base prompt by adding `ADDBASE` at the end of it. Base prompts can also be combined with common prompts, but the base prompt must be specified first.
+
+```
+2d animation style ADDBASE
+masterpiece, high quality ADDCOMM
+(blue sky)++ BREAK
+green hair twintail BREAK
+book shelf BREAK
+messy desk BREAK
+orange++ dress and sofa
+```
+
 ### Negative prompt

 Negative prompts are equally effective across all regions, but it is possible to set region-specific prompts for negative prompts as well. The number of BREAKs must be the same as the number of prompts. If the number of prompts does not match, the negative prompts will be used without being divided into regions.
@@ -3408,6 +3505,7 @@ pipe(prompt=prompt, rp_args=rp_args)
 ### Optional Parameters

 - `save_mask`: In `Prompt` mode, choose whether to output the generated mask along with the image. The default is `False`.
+- `base_ratio`: Used with `ADDBASE`. Sets the ratio of the base prompt; if `base_ratio` is set to 0.2, the resulting image is composed of `20%*BASE_PROMPT + 80%*REGION_PROMPT`.

 The Pipeline supports `compel` syntax. Input prompts using the `compel` structure will be automatically applied and processed.

@@ -4696,4 +4794,4 @@ with torch.no_grad():
 ```

 In the folder examples/pixart there is also a script that can be used to train new models.
-Please check the script `train_controlnet_hf_diffusers.sh` on how to start the training.
+Please check the script `train_controlnet_hf_diffusers.sh` on how to start the training.
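To tie the README additions together, here is a minimal usage sketch for the new `ADDBASE` keyword and the `base_ratio` option. It is a sketch under stated assumptions: the base checkpoint id is a placeholder, the region-division settings are only hinted at in a comment, and only `prompt`, `rp_args`, `base_ratio`, and `save_mask` come from the documentation above.

```python
# Minimal sketch (assumptions noted): load the community regional prompting pipeline
# and pass a base prompt plus `base_ratio` through `rp_args`, as documented in the new
# "Use base prompt" and "Optional Parameters" entries above.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",           # placeholder base checkpoint
    custom_pipeline="regional_prompting_stable_diffusion",   # community pipeline in this folder
    torch_dtype=torch.float16,
).to("cuda")

prompt = """
2d animation style ADDBASE
masterpiece, high quality ADDCOMM
blue sky BREAK
green hair twintail BREAK
book shelf BREAK
messy desk BREAK
orange dress and sofa
"""

rp_args = {
    # "mode": "rows", "div": "1;1;1",  # assumed region-division keys from earlier README sections
    "base_ratio": "0.2",  # 20% * BASE_PROMPT + 80% * REGION_PROMPT
    "save_mask": False,
}

image = pipe(prompt=prompt, negative_prompt="lowres, bad anatomy", rp_args=rp_args).images[0]
image.save("regional_base_prompt.png")
```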

examples/community/regional_prompting_stable_diffusion.py

Lines changed: 63 additions & 16 deletions
@@ -3,20 +3,20 @@

 import torch
 import torchvision.transforms.functional as FF
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

 from diffusers import StableDiffusionPipeline
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import USE_PEFT_BACKEND


 try:
     from compel import Compel
 except ImportError:
     Compel = None

+KBASE = "ADDBASE"
 KCOMM = "ADDCOMM"
 KBRK = "BREAK"

@@ -34,6 +34,11 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):

     Optional
     rp_args["save_mask"]: True/False (save masks in prompt mode)
+    rp_args["power"]: int (power for attention maps in prompt mode)
+    rp_args["base_ratio"]:
+        float (Sets the ratio of the base prompt)
+        ex) 0.2 (20%*BASE_PROMPT + 80%*REGION_PROMPT)
+        [Use base prompt](https://github.com/hako-mikan/sd-webui-regional-prompter?tab=readme-ov-file#use-base-prompt)

     Pipeline for text-to-image generation using Stable Diffusion.

@@ -70,6 +75,7 @@
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__(
@@ -80,6 +86,7 @@
             scheduler,
             safety_checker,
             feature_extractor,
+            image_encoder,
             requires_safety_checker,
         )
         self.register_modules(
@@ -90,6 +97,7 @@
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )

     @torch.no_grad()
@@ -110,17 +118,40 @@ def __call__(
         rp_args: Dict[str, str] = None,
     ):
         active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
+        use_base = KBASE in prompt[0] if isinstance(prompt, list) else KBASE in prompt
         if negative_prompt is None:
             negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)

         device = self._execution_device
         regions = 0

+        self.base_ratio = float(rp_args["base_ratio"]) if "base_ratio" in rp_args else 0.0
         self.power = int(rp_args["power"]) if "power" in rp_args else 1

         prompts = prompt if isinstance(prompt, list) else [prompt]
-        n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt]
+        n_prompts = negative_prompt if isinstance(prompt, list) else [negative_prompt]
         self.batch = batch = num_images_per_prompt * len(prompts)
+
+        if use_base:
+            bases = prompts.copy()
+            n_bases = n_prompts.copy()
+
+            for i, prompt in enumerate(prompts):
+                parts = prompt.split(KBASE)
+                if len(parts) == 2:
+                    bases[i], prompts[i] = parts
+                elif len(parts) > 2:
+                    raise ValueError(f"Multiple instances of {KBASE} found in prompt: {prompt}")
+            for i, prompt in enumerate(n_prompts):
+                n_parts = prompt.split(KBASE)
+                if len(n_parts) == 2:
+                    n_bases[i], n_prompts[i] = n_parts
+                elif len(n_parts) > 2:
+                    raise ValueError(f"Multiple instances of {KBASE} found in negative prompt: {prompt}")
+
+            all_bases_cn, _ = promptsmaker(bases, num_images_per_prompt)
+            all_n_bases_cn, _ = promptsmaker(n_bases, num_images_per_prompt)
+
         all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
         all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)

@@ -137,8 +168,16 @@ def getcompelembs(prps):

             conds = getcompelembs(all_prompts_cn)
             unconds = getcompelembs(all_n_prompts_cn)
-            embs = getcompelembs(prompts)
-            n_embs = getcompelembs(n_prompts)
+            base_embs = getcompelembs(all_bases_cn) if use_base else None
+            base_n_embs = getcompelembs(all_n_bases_cn) if use_base else None
+            # When using base, it seems more reasonable to use base prompts as prompt_embeddings rather than regional prompts
+            embs = getcompelembs(prompts) if not use_base else base_embs
+            n_embs = getcompelembs(n_prompts) if not use_base else base_n_embs
+
+            if use_base and self.base_ratio > 0:
+                conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
+                unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds
+
             prompt = negative_prompt = None
         else:
             conds = self.encode_prompt(prompts, device, 1, True)[0]
@@ -147,6 +186,18 @@ def getcompelembs(prps):
                 if equal
                 else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
             )
+
+            if use_base and self.base_ratio > 0:
+                base_embs = self.encode_prompt(bases, device, 1, True)[0]
+                base_n_embs = (
+                    self.encode_prompt(n_bases, device, 1, True)[0]
+                    if equal
+                    else self.encode_prompt(all_n_bases_cn, device, 1, True)[0]
+                )
+
+                conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
+                unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds
+
             embs = n_embs = None

         if not active:
@@ -225,8 +276,6 @@ def forward(

            residual = hidden_states

-           args = () if USE_PEFT_BACKEND else (scale,)
-
            if attn.spatial_norm is not None:
                hidden_states = attn.spatial_norm(hidden_states, temb)

@@ -247,16 +296,15 @@ def forward(
            if attn.group_norm is not None:
                hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-           args = () if USE_PEFT_BACKEND else (scale,)
-           query = attn.to_q(hidden_states, *args)
+           query = attn.to_q(hidden_states)

            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            elif attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-           key = attn.to_k(encoder_hidden_states, *args)
-           value = attn.to_v(encoder_hidden_states, *args)
+           key = attn.to_k(encoder_hidden_states)
+           value = attn.to_v(encoder_hidden_states)

            inner_dim = key.shape[-1]
            head_dim = inner_dim // attn.heads
@@ -283,7 +331,7 @@ def forward(
            hidden_states = hidden_states.to(query.dtype)

            # linear proj
-           hidden_states = attn.to_out[0](hidden_states, *args)
+           hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)

@@ -410,9 +458,9 @@ def promptsmaker(prompts, batch):
        add = ""
        if KCOMM in prompt:
            add, prompt = prompt.split(KCOMM)
-           add = add + " "
-       prompts = prompt.split(KBRK)
-       out_p.append([add + p for p in prompts])
+           add = add.strip() + " "
+       prompts = [p.strip() for p in prompt.split(KBRK)]
+       out_p.append([add + p for i, p in enumerate(prompts)])
    out = [None] * batch * len(out_p[0]) * len(out_p)
    for p, prs in enumerate(out_p):  # inputs prompts
        for r, pr in enumerate(prs):  # prompts for regions
@@ -449,7 +497,6 @@ def startend(cells, array):
        add = []
        startend(add, inratios[1:])
        icells.append(add)
-
    return ocells, icells, sum(len(cell) for cell in icells)

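As a quick illustration of what the new code in this file does (a standalone sketch, not part of the commit), the `ADDBASE` keyword splits each prompt into a base part and the regional parts, and when `base_ratio > 0` the base embedding is linearly blended into the regional embeddings. The helper function and the tensor shapes below are illustrative only.

```python
# Standalone sketch of the ADDBASE handling added above (hypothetical helper, illustrative shapes).
import torch

KBASE, KBRK = "ADDBASE", "BREAK"

def split_base(prompt: str):
    # Mirrors the new parsing logic: at most one ADDBASE, splitting base from regional prompts.
    parts = prompt.split(KBASE)
    if len(parts) > 2:
        raise ValueError(f"Multiple instances of {KBASE} found in prompt: {prompt}")
    return (parts[0], parts[1]) if len(parts) == 2 else ("", prompt)

base, regions = split_base("2d animation style ADDBASE blue sky BREAK green hair twintail")
print(base.strip())                               # "2d animation style"
print([p.strip() for p in regions.split(KBRK)])   # ["blue sky", "green hair twintail"]

# Embedding blend applied when base_ratio > 0 (stand-in tensors for encoded prompts):
base_ratio = 0.2
base_embs = torch.randn(2, 77, 768)   # encoded base prompts
conds = torch.randn(2, 77, 768)       # encoded regional prompts
conds = base_ratio * base_embs + (1 - base_ratio) * conds
```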
