1 change: 1 addition & 0 deletions src/diffusers/__init__.py
@@ -754,6 +754,7 @@
StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
StableDiffusion3ControlNetInpaintingPipeline,
StableDiffusion3ControlNetPipeline,
StableDiffusion3Img2ImgPipeline,
StableDiffusion3InpaintPipeline,
31 changes: 26 additions & 5 deletions src/diffusers/models/controlnet_sd3.py
@@ -39,6 +39,28 @@ class SD3ControlNetOutput(BaseOutput):


class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
A ControlNet model based on the SD3 architecture.

Parameters:
sample_size (`int`, defaults to `128`):
The width of the latent images. This is fixed during training since it is used to learn a number of
position embeddings.
patch_size (`int`, defaults to `2`): Patch size to turn the input data into small patches.
in_channels (`int`, defaults to `16`): The number of channels in the input.
num_layers (`int`, defaults to `18`): The number of layers of Transformer blocks to use.
attention_head_dim (`int`, defaults to `64`): The number of channels in each head.
num_attention_heads (`int`, defaults to `18`): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, defaults to `4096`): Input dimension of `encoder_hidden_states` before projection.
caption_projection_dim (`int`, defaults to `1152`):
Output dimension when projecting the `encoder_hidden_states`.
pooled_projection_dim (`int`, defaults to `2048`): Output dimension when projecting the `pooled_projections`.
out_channels (`int`, *optional*, defaults to `16`): Number of output channels.
pos_embed_max_size (`int`, *optional*, defaults to `96`): Max size for positional embeddings.
extra_conditioning_channels (`int`, defaults to `0`):
Additional conditioning channels to use with different controlnet models.
"""

_supports_gradient_checkpointing = True

@register_to_config
@@ -53,13 +75,12 @@ def __init__(
joint_attention_dim: int = 4096,
caption_projection_dim: int = 1152,
pooled_projection_dim: int = 2048,
out_channels: int = 16,
pos_embed_max_size: int = 96,
out_channels: Optional[int] = 16,
pos_embed_max_size: Optional[int] = 96,
extra_conditioning_channels: int = 0,
):
super().__init__()
default_out_channels = in_channels
self.out_channels = out_channels if out_channels is not None else default_out_channels
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim

self.pos_embed = PatchEmbed(
@@ -82,7 +103,7 @@ def __init__(
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
attention_head_dim=attention_head_dim,
context_pre_only=False,
)
for i in range(num_layers)
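Editor's note: as a quick illustration of the parameters documented in the new docstring, here is a minimal sketch that instantiates a tiny SD3ControlNetModel with the new `extra_conditioning_channels` argument. The values mirror the dummy configuration used in the tests later in this PR, not a real checkpoint.

```python
from diffusers import SD3ControlNetModel

# Tiny, illustrative configuration (same dummy values as the tests in this PR);
# real checkpoints use the documented defaults (sample_size=128, 18 layers, ...).
controlnet = SD3ControlNetModel(
    sample_size=8,
    patch_size=1,
    in_channels=8,
    num_layers=1,
    attention_head_dim=4,
    num_attention_heads=2,
    joint_attention_dim=32,
    caption_projection_dim=8,
    pooled_projection_dim=16,
    out_channels=8,
    extra_conditioning_channels=1,  # e.g. an extra mask channel for inpainting-style conditioning
)
print(sum(p.numel() for p in controlnet.parameters()))  # small parameter count for the toy config
```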
22 changes: 11 additions & 11 deletions src/diffusers/models/embeddings.py
@@ -185,17 +185,17 @@ class PatchEmbed(nn.Module):

def __init__(
self,
height=224,
width=224,
patch_size=16,
in_channels=3,
embed_dim=768,
layer_norm=False,
flatten=True,
bias=True,
interpolation_scale=1,
pos_embed_type="sincos",
pos_embed_max_size=None, # For SD3 cropping
height: int = 224,
width: int = 224,
patch_size: int = 16,
in_channels: int = 3,
embed_dim: int = 768,
layer_norm: bool = False,
flatten: bool = True,
bias: bool = True,
interpolation_scale: float = 1,
pos_embed_type: str = "sincos",
pos_embed_max_size: Optional[int] = None, # For SD3 cropping
):
super().__init__()

55 changes: 28 additions & 27 deletions src/diffusers/models/transformers/transformer_sd3.py
@@ -33,24 +33,26 @@


class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
"""
r"""
The Transformer model introduced in Stable Diffusion 3.

Reference: https://arxiv.org/abs/2403.03206

Parameters:
sample_size (`int`): The width of the latent images. This is fixed during training since
it is used to learn a number of position embeddings.
patch_size (`int`): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
out_channels (`int`, defaults to 16): Number of output channels.

sample_size (`int`, defaults to `128`):
The width of the latent images. This is fixed during training since it is used to learn a number of
position embeddings.
patch_size (`int`, defaults to `2`): Patch size to turn the input data into small patches.
in_channels (`int`, defaults to `16`): The number of channels in the input.
num_layers (`int`, defaults to `18`): The number of layers of Transformer blocks to use.
attention_head_dim (`int`, defaults to `64`): The number of channels in each head.
num_attention_heads (`int`, defaults to `18`): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, defaults to `4096`): Input dimension of `encoder_hidden_states` before projection.
caption_projection_dim (`int`, defaults to `1152`):
Output dimension when projecting the `encoder_hidden_states`.
pooled_projection_dim (`int`, defaults to `2048`): Output dimension when projecting the `pooled_projections`.
out_channels (`int`, *optional*, defaults to `16`): Number of output channels.
pos_embed_max_size (`int`, *optional*, defaults to `96`): Max size for positional embeddings.
"""

_supports_gradient_checkpointing = True
@@ -67,43 +69,42 @@ def __init__(
joint_attention_dim: int = 4096,
caption_projection_dim: int = 1152,
pooled_projection_dim: int = 2048,
out_channels: int = 16,
out_channels: Optional[int] = 16,
pos_embed_max_size: int = 96,
):
super().__init__()
default_out_channels = in_channels
self.out_channels = out_channels if out_channels is not None else default_out_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim

self.pos_embed = PatchEmbed(
height=self.config.sample_size,
Review comment (Collaborator): I think whether to use the config object inside models is more of a collective decision. We have it set up like this in a few places. Functionally, they are equivalent, but we just need to commit to an approach. My personal preference is to use the variable directly, as done here. cc: @yiyixuxu @sayakpaul

Review comment (Member): Yeah, using a variable directly is what I have been following recently as well. During forward(), if the variable is accessible through self.config, that takes priority; but if not (such as inner_dim), just assign them to self during __init__().

width=self.config.sample_size,
patch_size=self.config.patch_size,
in_channels=self.config.in_channels,
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=self.inner_dim,
pos_embed_max_size=pos_embed_max_size, # hard-code for now.
)
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

# `attention_head_dim` is doubled to account for the mixing.
# It needs to be crafted when we get the actual checkpoints.
self.transformer_blocks = nn.ModuleList(
[
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
context_pre_only=i == num_layers - 1,
)
for i in range(self.config.num_layers)
for i in range(num_layers)
]
)

self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels, bias=True)

self.gradient_checkpointing = False

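Editor's note: to make the config-vs-direct-variable discussion from the review thread above concrete, here is a minimal sketch assuming the standard diffusers ConfigMixin behaviour: `@register_to_config` records the `__init__` arguments on `self.config` before the body runs, so inside `__init__` the local argument and the config attribute hold the same value, which is why this PR can drop the `self.config.` prefix.

```python
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config


class TinyModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, num_attention_heads: int = 2, attention_head_dim: int = 4):
        super().__init__()
        # The decorator has already registered the arguments, so
        # self.config.num_attention_heads == num_attention_heads here.
        # Derived values such as inner_dim are not part of the registered config,
        # so they are assigned to self directly (as the review comment notes).
        self.inner_dim = num_attention_heads * attention_head_dim


model = TinyModel()
assert model.inner_dim == model.config.num_attention_heads * model.config.attention_head_dim
```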
@@ -57,8 +57,7 @@
```py
>>> import torch
>>> from diffusers.utils import load_image, check_min_version
>>> from diffusers.pipelines import StableDiffusion3ControlNetInpaintingPipeline
>>> from diffusers.models.controlnet_sd3 import SD3ControlNetModel
>>> from diffusers import StableDiffusion3ControlNetInpaintingPipeline, SD3ControlNetModel

>>> controlnet = SD3ControlNetModel.from_pretrained(
... "alimama-creative/SD3-Controlnet-Inpainting", use_safetensors=True, extra_conditioning_channels=1
15 changes: 15 additions & 0 deletions src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -1112,6 +1112,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class StableDiffusion3ControlNetInpaintingPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class StableDiffusion3ControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

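Editor's note: for readers unfamiliar with this file, the dummy class exists so that importing StableDiffusion3ControlNetInpaintingPipeline from diffusers still succeeds when torch/transformers are not installed, with a helpful error raised only when the class is actually used. Below is a simplified, self-contained sketch of the mechanism; the real requires_backends and DummyObject live in diffusers.utils, and this version is illustrative only.

```python
import importlib.util


def requires_backends(obj, backends):
    # Simplified stand-in for diffusers.utils.requires_backends:
    # raise a helpful ImportError when a required backend is missing.
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
        raise ImportError(f"{name} requires the missing backends: {', '.join(missing)}")


class DummyObject(type):
    # Simplified stand-in for diffusers.utils.DummyObject: attribute access on the
    # placeholder class re-checks the backends and raises if they are unavailable.
    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)


class FakePipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])


# With both backends installed this builds a (useless) placeholder instance;
# without them, the same call raises an ImportError naming the missing packages.
try:
    FakePipeline()
except ImportError as err:
    print(err)
```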
30 changes: 16 additions & 14 deletions tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
@@ -56,44 +56,44 @@ class StableDiffusion3ControlInpaintNetPipelineFastTests(unittest.TestCase, Pipe
def get_dummy_components(self):
torch.manual_seed(0)
transformer = SD3Transformer2DModel(
sample_size=32,
sample_size=8,
patch_size=1,
in_channels=8,
num_layers=4,
attention_head_dim=8,
num_attention_heads=4,
num_layers=2,
attention_head_dim=4,
num_attention_heads=2,
joint_attention_dim=32,
caption_projection_dim=32,
pooled_projection_dim=64,
caption_projection_dim=8,
pooled_projection_dim=16,
out_channels=8,
)

torch.manual_seed(0)
controlnet = SD3ControlNetModel(
sample_size=32,
sample_size=8,
patch_size=1,
in_channels=8,
num_layers=1,
attention_head_dim=8,
num_attention_heads=4,
attention_head_dim=4,
num_attention_heads=2,
joint_attention_dim=32,
caption_projection_dim=32,
pooled_projection_dim=64,
caption_projection_dim=8,
pooled_projection_dim=16,
out_channels=8,
extra_conditioning_channels=1,
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
hidden_size=8,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
projection_dim=8,
)

torch.manual_seed(0)
@@ -163,6 +163,8 @@ def get_dummy_inputs(self, device, seed=0):

inputs = {
"prompt": "A painting of a squirrel eating a burger",
"height": 32,
"width": 32,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 7.0,
@@ -176,7 +178,7 @@

def test_controlnet_inpaint_sd3(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusion3ControlNetInpaintingPipeline(**components)
sd_pipe = self.pipeline_class(**components)
sd_pipe = sd_pipe.to(torch_device, dtype=torch.float16)
sd_pipe.set_progress_bar_config(disable=None)

26 changes: 13 additions & 13 deletions tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
@@ -60,43 +60,43 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
def get_dummy_components(self):
torch.manual_seed(0)
transformer = SD3Transformer2DModel(
sample_size=32,
sample_size=8,
patch_size=1,
in_channels=8,
num_layers=4,
attention_head_dim=8,
num_attention_heads=4,
num_layers=2,
attention_head_dim=4,
num_attention_heads=2,
joint_attention_dim=32,
caption_projection_dim=32,
pooled_projection_dim=64,
caption_projection_dim=8,
pooled_projection_dim=16,
out_channels=8,
)

torch.manual_seed(0)
controlnet = SD3ControlNetModel(
sample_size=32,
sample_size=8,
patch_size=1,
in_channels=8,
num_layers=1,
attention_head_dim=8,
num_attention_heads=4,
attention_head_dim=4,
num_attention_heads=2,
joint_attention_dim=32,
caption_projection_dim=32,
pooled_projection_dim=64,
caption_projection_dim=8,
pooled_projection_dim=16,
out_channels=8,
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
hidden_size=8,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
projection_dim=8,
)

torch.manual_seed(0)
@@ -38,29 +38,29 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
def get_dummy_components(self):
torch.manual_seed(0)
transformer = SD3Transformer2DModel(
sample_size=32,
sample_size=8,
patch_size=1,
in_channels=4,
num_layers=1,
attention_head_dim=8,
num_attention_heads=4,
caption_projection_dim=32,
attention_head_dim=4,
num_attention_heads=2,
joint_attention_dim=32,
pooled_projection_dim=64,
caption_projection_dim=8,
pooled_projection_dim=16,
out_channels=4,
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
hidden_size=8,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
projection_dim=8,
)

torch.manual_seed(0)