Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/source/en/api/pipelines/controlnet_sd3.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,17 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
</Tip>

## StableDiffusion3ControlNetPipeline

[[autodoc]] StableDiffusion3ControlNetPipeline
- all
- __call__

## StableDiffusion3ControlNetInpaintingPipeline
[[autodoc]] pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet_inpainting.StableDiffusion3ControlNetInpaintingPipeline

[[autodoc]] StableDiffusion3ControlNetInpaintingPipeline
- all
- __call__

## StableDiffusion3PipelineOutput

[[autodoc]] pipelines.stable_diffusion_3.pipeline_output.StableDiffusion3PipelineOutput
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,7 @@
StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
StableDiffusion3ControlNetInpaintingPipeline,
StableDiffusion3ControlNetPipeline,
StableDiffusion3Img2ImgPipeline,
StableDiffusion3InpaintPipeline,
Expand Down
31 changes: 26 additions & 5 deletions src/diffusers/models/controlnet_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,28 @@ class SD3ControlNetOutput(BaseOutput):


class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
A ControlNet model based on the SD3 architecture.

Parameters:
sample_size (`int`, defaults to `128`):
The width of the latent images. This is fixed during training since it is used to learn a number of
position embeddings.
patch_size (`int`, defaults to `2`): Patch size to turn the input data into small patches.
in_channels (`int`, defaults to `16`): The number of channels in the input.
num_layers (`int`, defaults to `18`): The number of layers of Transformer blocks to use.
attention_head_dim (`int`, defaults to `64`): The number of channels in each head.
num_attention_heads (`int`, defaults to `18`): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, defaults to `4096`): Input dimension of `encoder_hidden_states` before projection.
caption_projection_dim (`int`, defaults to `1152`):
Output dimension when projecting the `encoder_hidden_states`.
pooled_projection_dim (`int`, defaults to `2048`): Output dimension when projecting the `pooled_projections`.
out_channels (`int`, *optional*, defaults to `16`): Number of output channels.
pos_embed_max_size (`int`, *optional*, defaults to `96`): Max size for positional embeddings.
extra_conditioning_channels (`int`, defaults to `0`):
Additional conditioning channels to use with different controlnet models.
"""

_supports_gradient_checkpointing = True

@register_to_config
Expand All @@ -53,13 +75,12 @@ def __init__(
joint_attention_dim: int = 4096,
caption_projection_dim: int = 1152,
pooled_projection_dim: int = 2048,
out_channels: int = 16,
pos_embed_max_size: int = 96,
out_channels: Optional[int] = 16,
pos_embed_max_size: Optional[int] = 96,
extra_conditioning_channels: int = 0,
):
super().__init__()
default_out_channels = in_channels
self.out_channels = out_channels if out_channels is not None else default_out_channels
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim

self.pos_embed = PatchEmbed(
Expand All @@ -82,7 +103,7 @@ def __init__(
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
attention_head_dim=attention_head_dim,
context_pre_only=False,
)
for i in range(num_layers)
Expand Down
22 changes: 11 additions & 11 deletions src/diffusers/models/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,17 +185,17 @@ class PatchEmbed(nn.Module):

def __init__(
self,
height=224,
width=224,
patch_size=16,
in_channels=3,
embed_dim=768,
layer_norm=False,
flatten=True,
bias=True,
interpolation_scale=1,
pos_embed_type="sincos",
pos_embed_max_size=None, # For SD3 cropping
height: int = 224,
width: int = 224,
patch_size: int = 16,
in_channels: int = 3,
embed_dim: int = 768,
layer_norm: bool = False,
flatten: bool = True,
bias: bool = True,
interpolation_scale: float = 1,
pos_embed_type: str = "sincos",
pos_embed_max_size: Optional[int] = None, # For SD3 cropping
):
super().__init__()

Expand Down
55 changes: 28 additions & 27 deletions src/diffusers/models/transformers/transformer_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,24 +33,26 @@


class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
"""
r"""
The Transformer model introduced in Stable Diffusion 3.

Reference: https://arxiv.org/abs/2403.03206

Parameters:
sample_size (`int`): The width of the latent images. This is fixed during training since
it is used to learn a number of position embeddings.
patch_size (`int`): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
out_channels (`int`, defaults to 16): Number of output channels.

sample_size (`int`, defaults to `128`):
The width of the latent images. This is fixed during training since it is used to learn a number of
position embeddings.
patch_size (`int`, defaults to `2`): Patch size to turn the input data into small patches.
in_channels (`int`, defaults to `16`): The number of channels in the input.
num_layers (`int`, defaults to `18`): The number of layers of Transformer blocks to use.
attention_head_dim (`int`, defaults to `64`): The number of channels in each head.
num_attention_heads (`int`, defaults to `18`): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, defaults to `4096`): Input dimension of `encoder_hidden_states` before projection.
caption_projection_dim (`int`, defaults to `1152`):
Output dimension when projecting the `encoder_hidden_states`.
pooled_projection_dim (`int`, defaults to `2048`): Output dimension when projecting the `pooled_projections`.
out_channels (`int`, *optional*, defaults to `16`): Number of output channels.
pos_embed_max_size (`int`, *optional*, defaults to `96`): Max size for positional embeddings.
"""

_supports_gradient_checkpointing = True
Expand All @@ -67,43 +69,42 @@ def __init__(
joint_attention_dim: int = 4096,
caption_projection_dim: int = 1152,
pooled_projection_dim: int = 2048,
out_channels: int = 16,
out_channels: Optional[int] = 16,
pos_embed_max_size: int = 96,
):
super().__init__()
default_out_channels = in_channels
self.out_channels = out_channels if out_channels is not None else default_out_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim

self.pos_embed = PatchEmbed(
height=self.config.sample_size,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think whether to use the config object inside models is more of a collective decision. We have it set up like this in a few places. Functionally, they are equivalent but we just need to commit to an approach. My personal preference is to use the variable directly as done here. cc: @yiyixuxu @sayakpaul

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah using a variable directly is what I have been following recently as well. During the forward() if the variable is accessible through self.config that takes priority, but if not (such as inner_dim) just assign then to self during __init__().

width=self.config.sample_size,
patch_size=self.config.patch_size,
in_channels=self.config.in_channels,
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=self.inner_dim,
pos_embed_max_size=pos_embed_max_size, # hard-code for now.
)
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

# `attention_head_dim` is doubled to account for the mixing.
# It needs to crafted when we get the actual checkpoints.
self.transformer_blocks = nn.ModuleList(
[
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
context_pre_only=i == num_layers - 1,
)
for i in range(self.config.num_layers)
for i in range(num_layers)
]
)

self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels, bias=True)

self.gradient_checkpointing = False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@
```py
>>> import torch
>>> from diffusers.utils import load_image, check_min_version
>>> from diffusers.pipelines import StableDiffusion3ControlNetInpaintingPipeline
>>> from diffusers.models.controlnet_sd3 import SD3ControlNetModel
>>> from diffusers import StableDiffusion3ControlNetInpaintingPipeline, SD3ControlNetModel

>>> controlnet = SD3ControlNetModel.from_pretrained(
... "alimama-creative/SD3-Controlnet-Inpainting", use_safetensors=True, extra_conditioning_channels=1
Expand Down
15 changes: 15 additions & 0 deletions src/diffusers/utils/dummy_torch_and_transformers_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class StableDiffusion3ControlNetInpaintingPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class StableDiffusion3ControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

Expand Down
30 changes: 16 additions & 14 deletions tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,44 +56,44 @@ class StableDiffusion3ControlInpaintNetPipelineFastTests(unittest.TestCase, Pipe
def get_dummy_components(self):
torch.manual_seed(0)
transformer = SD3Transformer2DModel(
sample_size=32,
sample_size=8,
patch_size=1,
in_channels=8,
num_layers=4,
attention_head_dim=8,
num_attention_heads=4,
num_layers=2,
attention_head_dim=4,
num_attention_heads=2,
joint_attention_dim=32,
caption_projection_dim=32,
pooled_projection_dim=64,
caption_projection_dim=8,
pooled_projection_dim=16,
out_channels=8,
)

torch.manual_seed(0)
controlnet = SD3ControlNetModel(
sample_size=32,
sample_size=8,
patch_size=1,
in_channels=8,
num_layers=1,
attention_head_dim=8,
num_attention_heads=4,
attention_head_dim=4,
num_attention_heads=2,
joint_attention_dim=32,
caption_projection_dim=32,
pooled_projection_dim=64,
caption_projection_dim=8,
pooled_projection_dim=16,
out_channels=8,
extra_conditioning_channels=1,
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
hidden_size=8,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
projection_dim=8,
)

torch.manual_seed(0)
Expand Down Expand Up @@ -163,6 +163,8 @@ def get_dummy_inputs(self, device, seed=0):

inputs = {
"prompt": "A painting of a squirrel eating a burger",
"height": 32,
"width": 32,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 7.0,
Expand All @@ -176,7 +178,7 @@ def get_dummy_inputs(self, device, seed=0):

def test_controlnet_inpaint_sd3(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusion3ControlNetInpaintingPipeline(**components)
sd_pipe = self.pipeline_class(**components)
sd_pipe = sd_pipe.to(torch_device, dtype=torch.float16)
sd_pipe.set_progress_bar_config(disable=None)

Expand Down
26 changes: 13 additions & 13 deletions tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,43 +60,43 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
def get_dummy_components(self):
torch.manual_seed(0)
transformer = SD3Transformer2DModel(
sample_size=32,
sample_size=8,
patch_size=1,
in_channels=8,
num_layers=4,
attention_head_dim=8,
num_attention_heads=4,
num_layers=2,
attention_head_dim=4,
num_attention_heads=2,
joint_attention_dim=32,
caption_projection_dim=32,
pooled_projection_dim=64,
caption_projection_dim=8,
pooled_projection_dim=16,
out_channels=8,
)

torch.manual_seed(0)
controlnet = SD3ControlNetModel(
sample_size=32,
sample_size=8,
patch_size=1,
in_channels=8,
num_layers=1,
attention_head_dim=8,
num_attention_heads=4,
attention_head_dim=4,
num_attention_heads=2,
joint_attention_dim=32,
caption_projection_dim=32,
pooled_projection_dim=64,
caption_projection_dim=8,
pooled_projection_dim=16,
out_channels=8,
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
hidden_size=8,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
projection_dim=8,
)

torch.manual_seed(0)
Expand Down
Loading
Loading