@@ -22,7 +22,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import JointTransformerBlock
+from ..attention import AttentionMixin, JointTransformerBlock
 from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -39,7 +39,7 @@ class SD3ControlNetOutput(BaseOutput):
     controlnet_block_samples: Tuple[torch.Tensor]


-class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+class SD3ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     r"""
     ControlNet model for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).

@@ -204,31 +204,6 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int
         for module in self.children():
             fn_recursive_feed_forward(module, chunk_size, dim)

-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
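With this change, `attn_processors` is inherited from `AttentionMixin` rather than redefined on the model class. A minimal sketch of what that looks like from the outside, assuming a diffusers build that includes this commit (the assertions below are illustrative, not part of the change):

# Minimal sketch, assuming a diffusers build that includes this change.
from diffusers import SD3ControlNetModel
from diffusers.models.attention import AttentionMixin

# The property is no longer defined on the class itself ...
assert "attn_processors" not in vars(SD3ControlNetModel)

# ... it now resolves through AttentionMixin in the MRO, so instances
# still expose the same weight-name -> processor mapping as before.
assert issubclass(SD3ControlNetModel, AttentionMixin)
assert SD3ControlNetModel.attn_processors is AttentionMixin.attn_processors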