@@ -40,6 +40,48 @@ class SD3ControlNetOutput(BaseOutput):
 
 
 class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+    r"""
+    ControlNet model for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
+
+    Parameters:
+        sample_size (`int`, defaults to `128`):
+            The width/height of the latents. This is fixed during training since it is used to learn a number of
+            position embeddings.
+        patch_size (`int`, defaults to `2`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `16`):
+            The number of latent channels in the input.
+        num_layers (`int`, defaults to `18`):
+            The number of layers of transformer blocks to use.
+        attention_head_dim (`int`, defaults to `64`):
+            The number of channels in each head.
+        num_attention_heads (`int`, defaults to `18`):
+            The number of heads to use for multi-head attention.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The embedding dimension to use for joint text-image attention.
+        caption_projection_dim (`int`, defaults to `1152`):
+            The embedding dimension of caption embeddings.
+        pooled_projection_dim (`int`, defaults to `2048`):
+            The embedding dimension of pooled text projections.
+        out_channels (`int`, defaults to `16`):
+            The number of latent channels in the output.
+        pos_embed_max_size (`int`, defaults to `96`):
+            The maximum latent height/width of positional embeddings.
+        extra_conditioning_channels (`int`, defaults to `0`):
+            The number of extra channels to use for conditioning for patch embedding.
+        dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
+            The indices of the transformer blocks that use dual-stream attention.
+        qk_norm (`str`, *optional*, defaults to `None`):
+            The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
+        pos_embed_type (`str`, defaults to `"sincos"`):
+            The type of positional embedding to use. Choose between `"sincos"` and `None`.
+        use_pos_embed (`bool`, defaults to `True`):
+            Whether to use positional embeddings.
+        force_zeros_for_pooled_projection (`bool`, defaults to `True`):
+            Whether to force zeros for pooled projection embeddings. This is handled in the pipelines by reading the
+            config value of the ControlNet model.
+    """
+
     _supports_gradient_checkpointing = True
 
     @register_to_config
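As a quick orientation for the new docstring, a minimal, hedged sketch of constructing a deliberately tiny `SD3ControlNetModel` whose keyword arguments mirror the documented parameters. The small sizes are illustrative only, not the documented defaults, and the sketch assumes the public `diffusers` export plus the usual `@register_to_config` behavior of copying constructor arguments onto `model.config`.

```python
from diffusers import SD3ControlNetModel

# Tiny, illustrative configuration (not the defaults) so construction is cheap.
controlnet = SD3ControlNetModel(
    sample_size=32,
    patch_size=2,
    in_channels=16,
    num_layers=2,
    attention_head_dim=64,
    num_attention_heads=4,
    joint_attention_dim=4096,
    caption_projection_dim=256,  # kept equal to num_attention_heads * attention_head_dim
    pooled_projection_dim=2048,
    out_channels=16,
    pos_embed_max_size=96,
    extra_conditioning_channels=0,
)

# @register_to_config stores the constructor arguments on the config object.
assert controlnet.config.attention_head_dim == 64
```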
@@ -93,7 +135,7 @@ def __init__(
                 JointTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    attention_head_dim=attention_head_dim,
                     context_pre_only=False,
                     qk_norm=qk_norm,
                     use_dual_attention=True if i in dual_attention_layers else False,
@@ -108,7 +150,7 @@ def __init__(
                 SD3SingleTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    attention_head_dim=attention_head_dim,
                 )
                 for _ in range(num_layers)
             ]
@@ -297,28 +339,28 @@ def from_transformer(
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
+        hidden_states: torch.Tensor,
         controlnet_cond: torch.Tensor,
         conditioning_scale: float = 1.0,
-        encoder_hidden_states: torch.FloatTensor = None,
-        pooled_projections: torch.FloatTensor = None,
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.
 
         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
                 Input `hidden_states`.
             controlnet_cond (`torch.Tensor`):
                 The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
             conditioning_scale (`float`, defaults to `1.0`):
                 The scale factor for ControlNet outputs.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                 from the embeddings of input conditions.
             timestep (`torch.LongTensor`):
                 Used to indicate denoising step.
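A hedged sketch of calling this forward method with dummy tensors, reusing the tiny `controlnet` built in the earlier sketch. Shapes follow the Args section above and the attribute name `controlnet_block_samples` comes from `SD3ControlNetOutput`; the values are illustrative and not verified against a specific diffusers release.

```python
import torch

batch, channels, height, width = 1, 16, 32, 32
hidden_states = torch.randn(batch, channels, height, width)    # noisy image latents
controlnet_cond = torch.randn(batch, channels, height, width)  # conditioning image latents
encoder_hidden_states = torch.randn(batch, 77, 4096)           # prompt embeddings (joint_attention_dim)
pooled_projections = torch.randn(batch, 2048)                  # pooled text embeddings
timestep = torch.tensor([500], dtype=torch.long)                # current denoising step

out = controlnet(
    hidden_states=hidden_states,
    controlnet_cond=controlnet_cond,
    conditioning_scale=1.0,
    encoder_hidden_states=encoder_hidden_states,
    pooled_projections=pooled_projections,
    timestep=timestep,
    return_dict=True,
)

# One residual per ControlNet block; the SD3 transformer adds these to its own
# hidden states, already scaled by `conditioning_scale`.
print(len(out.controlnet_block_samples))
```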
@@ -437,11 +479,11 @@ def __init__(self, controlnets):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
+        hidden_states: torch.Tensor,
         controlnet_cond: List[torch.tensor],
         conditioning_scale: List[float],
-        pooled_projections: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
+        pooled_projections: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
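The multi-ControlNet wrapper takes one conditioning tensor and one scale per wrapped model and sums the per-block residuals across models. A simplified, hedged sketch of that combination logic follows; `multi_controlnet_forward` is a hypothetical helper, not the library's exact code.

```python
from typing import List

import torch


def multi_controlnet_forward(
    nets: List[torch.nn.Module],
    hidden_states: torch.Tensor,
    controlnet_cond: List[torch.Tensor],
    conditioning_scale: List[float],
    **kwargs,
) -> List[torch.Tensor]:
    # Run each ControlNet with its own conditioning image and scale, then sum
    # the per-block residuals elementwise across models.
    control_block_samples = None
    for net, cond, scale in zip(nets, controlnet_cond, conditioning_scale):
        samples = net(
            hidden_states=hidden_states,
            controlnet_cond=cond,
            conditioning_scale=scale,
            return_dict=False,
            **kwargs,
        )[0]
        if control_block_samples is None:
            control_block_samples = list(samples)
        else:
            control_block_samples = [prev + new for prev, new in zip(control_block_samples, samples)]
    return control_block_samples
```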