
Commit 05d5788

Merge branch 'main' into flux_guidancecontrol_inpaint
2 parents: 0643318 + 7db9463

17 files changed: +6263 additions, -86 deletions

docs/source/en/_toctree.yml

Lines changed: 4 additions & 0 deletions

@@ -252,6 +252,8 @@
       title: SD3ControlNetModel
     - local: api/models/controlnet_sparsectrl
       title: SparseControlNetModel
+    - local: api/models/controlnet_union
+      title: ControlNetUnionModel
     title: ControlNets
   - sections:
     - local: api/models/allegro_transformer3d
@@ -368,6 +370,8 @@
       title: ControlNet-XS
     - local: api/pipelines/controlnetxs_sdxl
       title: ControlNet-XS with Stable Diffusion XL
+    - local: api/pipelines/controlnet_union
+      title: ControlNetUnion
     - local: api/pipelines/dance_diffusion
       title: Dance Diffusion
     - local: api/pipelines/ddim
Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
<!--Copyright 2024 The HuggingFace Team and The InstantX Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# ControlNetUnionModel

ControlNetUnionModel is an implementation of ControlNet for Stable Diffusion XL.

The ControlNet model was introduced in [ControlNetPlus](https://github.com/xinsir6/ControlNetPlus) by xinsir6. It supports multiple conditioning inputs without increasing computation.

*We design a new architecture that can support 10+ control types in condition text-to-image generation and can generate high resolution images visually comparable with midjourney. The network is based on the original ControlNet architecture, we propose two new modules to: 1 Extend the original ControlNet to support different image conditions using the same network parameter. 2 Support multiple conditions input without increasing computation offload, which is especially important for designers who want to edit image in detail, different conditions use the same condition encoder, without adding extra computations or parameters.*

## Loading

By default the [`ControlNetUnionModel`] should be loaded with [`~ModelMixin.from_pretrained`].

```py
from diffusers import StableDiffusionXLControlNetUnionPipeline, ControlNetUnionModel

controlnet = ControlNetUnionModel.from_pretrained("xinsir/controlnet-union-sdxl-1.0")
pipe = StableDiffusionXLControlNetUnionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet)
```
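
The checkpoint can also be loaded in half precision and placed on a GPU, as with other SDXL ControlNet pipelines. The sketch below is illustrative only: the loading calls follow standard diffusers usage, but the final call assumes the pipeline takes its conditioning through the `ControlNetUnionInput` container exported elsewhere in this commit and that the container exposes a `canny` field; neither detail is confirmed by this diff, so check the pipeline's `__call__` docstring.

```py
import torch

from diffusers import ControlNetUnionModel, StableDiffusionXLControlNetUnionPipeline
from diffusers.models.controlnets import ControlNetUnionInput  # exported by this commit
from diffusers.utils import load_image

# Standard half-precision loading on a CUDA device.
controlnet = ControlNetUnionModel.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# Hypothetical conditioning call: the `image=ControlNetUnionInput(canny=...)` layout is an
# assumption, not something shown in this diff.
canny = load_image("path/to/canny_edge.png")  # placeholder path
result = pipe("a bird perched on a branch", image=ControlNetUnionInput(canny=canny)).images[0]
```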

## ControlNetUnionModel

[[autodoc]] ControlNetUnionModel
Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# ControlNetUnion

ControlNetUnionModel is an implementation of ControlNet for Stable Diffusion XL.

The ControlNet model was introduced in [ControlNetPlus](https://github.com/xinsir6/ControlNetPlus) by xinsir6. It supports multiple conditioning inputs without increasing computation.

*We design a new architecture that can support 10+ control types in condition text-to-image generation and can generate high resolution images visually comparable with midjourney. The network is based on the original ControlNet architecture, we propose two new modules to: 1 Extend the original ControlNet to support different image conditions using the same network parameter. 2 Support multiple conditions input without increasing computation offload, which is especially important for designers who want to edit image in detail, different conditions use the same condition encoder, without adding extra computations or parameters.*

## StableDiffusionXLControlNetUnionPipeline
[[autodoc]] StableDiffusionXLControlNetUnionPipeline
	- all
	- __call__

## StableDiffusionXLControlNetUnionImg2ImgPipeline
[[autodoc]] StableDiffusionXLControlNetUnionImg2ImgPipeline
	- all
	- __call__

## StableDiffusionXLControlNetUnionInpaintPipeline
[[autodoc]] StableDiffusionXLControlNetUnionInpaintPipeline
	- all
	- __call__

examples/dreambooth/train_dreambooth.py

Lines changed: 7 additions & 6 deletions

@@ -1300,16 +1300,17 @@ def compute_text_embeddings(prompt):
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(noise_scheduler, timesteps)
-                    base_weight = (
-                        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
-                    )
 
                     if noise_scheduler.config.prediction_type == "v_prediction":
                         # Velocity objective needs to be floored to an SNR weight of one.
-                        mse_loss_weights = base_weight + 1
+                        divisor = snr + 1
                     else:
-                        # Epsilon and sample both use the same loss weights.
-                        mse_loss_weights = base_weight
+                        divisor = snr
+
+                    mse_loss_weights = (
+                        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / divisor
+                    )
+
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                     loss = loss.mean()
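
After this change the min-SNR loss weight is min(snr, γ) / snr for the epsilon and sample objectives and min(snr, γ) / (snr + 1) for v-prediction, rather than the previous `base_weight + 1`. A standalone sketch of the resulting weighting, using a made-up `snr` tensor in place of the script's `compute_snr(noise_scheduler, timesteps)` output:

```py
import torch

# Standalone illustration of the loss weighting after this change; `snr` is a made-up tensor.
snr = torch.tensor([0.5, 5.0, 50.0])
snr_gamma = 5.0

# min(snr, gamma), computed the same way as in the training script.
clipped = torch.stack([snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0]

epsilon_weights = clipped / snr        # epsilon / sample objectives -> tensor([1.0000, 1.0000, 0.1000])
v_pred_weights = clipped / (snr + 1)   # v-prediction objective      -> tensor([0.3333, 0.8333, 0.0980])
```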

src/diffusers/__init__.py

Lines changed: 8 additions & 0 deletions

@@ -92,6 +92,7 @@
         "CogView3PlusTransformer2DModel",
         "ConsistencyDecoderVAE",
         "ControlNetModel",
+        "ControlNetUnionModel",
         "ControlNetXSAdapter",
         "DiTTransformer2DModel",
         "FluxControlNetModel",
@@ -379,6 +380,9 @@
         "StableDiffusionXLControlNetPAGImg2ImgPipeline",
         "StableDiffusionXLControlNetPAGPipeline",
         "StableDiffusionXLControlNetPipeline",
+        "StableDiffusionXLControlNetUnionImg2ImgPipeline",
+        "StableDiffusionXLControlNetUnionInpaintPipeline",
+        "StableDiffusionXLControlNetUnionPipeline",
         "StableDiffusionXLControlNetXSPipeline",
         "StableDiffusionXLImg2ImgPipeline",
         "StableDiffusionXLInpaintPipeline",
@@ -587,6 +591,7 @@
         CogView3PlusTransformer2DModel,
         ConsistencyDecoderVAE,
         ControlNetModel,
+        ControlNetUnionModel,
         ControlNetXSAdapter,
         DiTTransformer2DModel,
         FluxControlNetModel,
@@ -852,6 +857,9 @@
         StableDiffusionXLControlNetPAGImg2ImgPipeline,
         StableDiffusionXLControlNetPAGPipeline,
         StableDiffusionXLControlNetPipeline,
+        StableDiffusionXLControlNetUnionImg2ImgPipeline,
+        StableDiffusionXLControlNetUnionInpaintPipeline,
+        StableDiffusionXLControlNetUnionPipeline,
         StableDiffusionXLControlNetXSPipeline,
         StableDiffusionXLImg2ImgPipeline,
         StableDiffusionXLInpaintPipeline,

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -45,6 +45,7 @@
     ]
     _import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
     _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
+    _import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
     _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
     _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
     _import_structure["embeddings"] = ["ImageProjection"]
@@ -102,6 +103,7 @@
         )
         from .controlnets import (
             ControlNetModel,
+            ControlNetUnionModel,
             ControlNetXSAdapter,
             FluxControlNetModel,
             FluxMultiControlNetModel,
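
Both registration points are needed because diffusers builds its public namespace lazily: the string entry in `_import_structure` tells the lazy module loader which submodule defines the symbol, while the `TYPE_CHECKING` import keeps static analyzers and IDEs aware of it. A rough sketch of the idea, simplified from the `_LazyModule` helper diffusers actually uses:

```py
import importlib
from typing import TYPE_CHECKING

# Simplified sketch of lazy re-export in an __init__.py; the real implementation is
# diffusers' _LazyModule helper, but the shape of the mapping is the same.
_import_structure = {"controlnets.controlnet_union": ["ControlNetUnionModel"]}

if TYPE_CHECKING:
    # Static type checkers see a normal import.
    from .controlnets.controlnet_union import ControlNetUnionModel  # noqa: F401
else:
    def __getattr__(name):
        # Import the defining submodule only when the symbol is first accessed.
        for module_path, exported in _import_structure.items():
            if name in exported:
                module = importlib.import_module(f".{module_path}", __name__)
                return getattr(module, name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```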

src/diffusers/models/attention_processor.py

Lines changed: 95 additions & 0 deletions

@@ -358,6 +358,14 @@ def set_use_memory_efficient_attention_xformers(
             self.processor,
             (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
         )
+        is_joint_processor = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (
+                JointAttnProcessor2_0,
+                XFormersJointAttnProcessor,
+            ),
+        )
+
         if use_memory_efficient_attention_xformers:
             if is_added_kv_processor and is_custom_diffusion:
                 raise NotImplementedError(
@@ -420,6 +428,8 @@
                 processor.to(
                     device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
                 )
+            elif is_joint_processor:
+                processor = XFormersJointAttnProcessor(attention_op=attention_op)
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
@@ -1685,6 +1695,91 @@ def __call__(
         return hidden_states, encoder_hidden_states
 
 
+class XFormersJointAttnProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.
+
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = attn.head_to_batch_dim(encoder_hidden_states_query_proj).contiguous()
+            encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj).contiguous()
+            encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj).contiguous()
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+            key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+            value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if encoder_hidden_states is not None:
+            # Split the attention outputs.
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, : residual.shape[1]],
+                hidden_states[:, residual.shape[1] :],
+            )
+            if not attn.context_pre_only:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class AllegroAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
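
The new processor mirrors `JointAttnProcessor2_0` — query/key/value from the sample stream, optional `add_*` projections from the context stream, one fused attention call, then a split back into the two streams — but routes the fused call through `xformers.ops.memory_efficient_attention`. Combined with the routing change in `set_use_memory_efficient_attention_xformers`, enabling xFormers on a model whose blocks use joint attention should now pick this processor up. A hedged sketch, assuming `xformers` is installed, a CUDA device is available, and using SD3 purely as an example of a joint-attention model:

```py
import torch

from diffusers import StableDiffusion3Pipeline
from diffusers.models.attention_processor import XFormersJointAttnProcessor

# Example model choice: SD3 uses JointAttnProcessor2_0 by default, so its attention blocks
# are candidates for the new xFormers joint processor. Assumes xformers + a CUDA GPU.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# With this commit, joint processors are recognized here and swapped for
# XFormersJointAttnProcessor instead of the non-joint XFormersAttnProcessor.
pipe.enable_xformers_memory_efficient_attention()

# Setting the processor explicitly on the transformer is an equivalent manual route.
pipe.transformer.set_attn_processor(XFormersJointAttnProcessor())
```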

src/diffusers/models/controlnets/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@
         SparseControlNetModel,
         SparseControlNetOutput,
     )
+    from .controlnet_union import ControlNetUnionInput, ControlNetUnionInputProMax, ControlNetUnionModel
     from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
     from .multicontrolnet import MultiControlNetModel

0 commit comments
