@@ -25,9 +25,9 @@
 from ...loaders import PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
+from ..attention import AttentionMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
-from ..attention_processor import Attention, AttentionProcessor
+from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import (
     CombinedTimestepTextProjEmbeddings,
@@ -616,7 +616,9 @@ def forward(
         return hidden_states, encoder_hidden_states


-class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class HunyuanImageTransformer2DModel(
+    ModelMixin, ConfigMixin, AttentionMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin
+):
     r"""
     The Transformer model used in [HunyuanImage-2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).

@@ -667,10 +669,7 @@ class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         "HunyuanImagePatchEmbed",
         "HunyuanImageTokenRefiner",
     ]
-    _repeated_blocks = [
-        "HunyuanImageTransformerBlock",
-        "HunyuanImageSingleTransformerBlock",
-    ]
+    _repeated_blocks = ["HunyuanImageTransformerBlock", "HunyuanImageSingleTransformerBlock"]

     @register_to_config
     def __init__(
@@ -743,66 +742,6 @@ def __init__(

         self.gradient_checkpointing = False

-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
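Net effect of the last hunk: the per-model copies of `attn_processors` and `set_attn_processor` (themselves marked "Copied from ... UNet2DConditionModel") are dropped, and the class now inherits that behavior from `AttentionMixin`. As a rough, self-contained sketch of what the consolidated behavior amounts to, based only on the removed code above (the names `AttentionMixinSketch`, `ToyAttention`, and `ToyModel` are invented for illustration and are not the actual diffusers implementation):

```python
from typing import Dict, Union

from torch import nn


class AttentionMixinSketch:
    """Illustrative only: the recursive traversal the removed methods performed."""

    @property
    def attn_processors(self) -> Dict[str, object]:
        processors: Dict[str, object] = {}

        def collect(name: str, module: nn.Module) -> None:
            # Any submodule exposing `get_processor` is treated as an attention layer.
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
            for sub_name, child in module.named_children():
                collect(f"{name}.{sub_name}", child)

        for name, module in self.named_children():
            collect(name, module)
        return processors

    def set_attn_processor(self, processor: Union[object, Dict[str, object]]) -> None:
        count = len(self.attn_processors)
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(f"Got {len(processor)} processors for {count} attention layers.")

        def apply(name: str, module: nn.Module) -> None:
            if hasattr(module, "set_processor"):
                if isinstance(processor, dict):
                    # Dict form: keys are module paths, mirroring attn_processors.
                    module.set_processor(processor.pop(f"{name}.processor"))
                else:
                    # A single processor is broadcast to every attention layer.
                    module.set_processor(processor)
            for sub_name, child in module.named_children():
                apply(f"{name}.{sub_name}", child)

        for name, module in self.named_children():
            apply(name, module)


class ToyAttention(nn.Module):
    """Stand-in for an attention layer exposing the get_processor/set_processor hooks."""

    def __init__(self):
        super().__init__()
        self.processor = "default"

    def get_processor(self):
        return self.processor

    def set_processor(self, processor):
        self.processor = processor


class ToyModel(AttentionMixinSketch, nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([ToyAttention(), ToyAttention()])


model = ToyModel()
model.set_attn_processor("fused")
print(model.attn_processors)  # {'blocks.0.processor': 'fused', 'blocks.1.processor': 'fused'}
```

Because the keys are derived from module paths (`blocks.0.processor`, ...), the dict form of `set_attn_processor` still targets individual layers exactly as the removed docstring describes.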