diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py
index 209aad93244e..ec9e0bfa0246 100644
--- a/src/diffusers/models/controlnets/controlnet_sd3.py
+++ b/src/diffusers/models/controlnets/controlnet_sd3.py
@@ -56,6 +56,10 @@ def __init__(
         out_channels: int = 16,
         pos_embed_max_size: int = 96,
         extra_conditioning_channels: int = 0,
+        dual_attention_layers: Tuple[
+            int, ...
+        ] = (),  # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
+        qk_norm: Optional[str] = None,
     ):
         super().__init__()
         default_out_channels = in_channels
@@ -84,6 +88,8 @@ def __init__(
                     num_attention_heads=num_attention_heads,
                     attention_head_dim=self.config.attention_head_dim,
                     context_pre_only=False,
+                    qk_norm=qk_norm,
+                    use_dual_attention=True if i in dual_attention_layers else False,
                 )
                 for i in range(num_layers)
             ]
@@ -243,7 +249,7 @@ def _set_gradient_checkpointing(self, module, value=False):
 
     @classmethod
     def from_transformer(
-        cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True
+        cls, transformer, num_layers=12, num_extra_conditioning_channels=0, load_weights_from_transformer=True
     ):
         config = transformer.config
         config["num_layers"] = num_layers or config.num_layers
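
As a quick sanity check of the new surface area, here is a minimal sketch of constructing the model with the added arguments. The tiny `num_layers=2` configuration and the `"rms_norm"` value are illustrative assumptions for a smoke test, not something this diff pins down:

```python
from diffusers import SD3ControlNetModel

# SD3.5-style: enable dual attention in both of the two blocks via the new
# per-layer toggle, and turn on Q/K normalization in the attention layers.
controlnet = SD3ControlNetModel(
    num_layers=2,
    dual_attention_layers=(0, 1),  # block indices that get dual attention
    qk_norm="rms_norm",  # assumption: the norm variant SD3.5 weights expect
)

# SD3.0-style behavior is unchanged: both new arguments default to "off"
# (empty tuple and None), so existing configs build exactly as before.
controlnet_sd3 = SD3ControlNetModel(num_layers=2)
```

With `num_extra_conditioning_channels` now defaulting to 0, `from_transformer` no longer adds an extra conditioning channel by default, so the resulting ControlNet's input layout matches the base transformer's; callers that relied on the previous default can still pass `num_extra_conditioning_channels=1` explicitly.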