
Commit 8b03bce

up

committed
1 parent 12ba3fc commit 8b03bce

File tree

1 file changed (+6, -67 lines)


src/diffusers/models/transformers/transformer_hunyuanimage.py

Lines changed: 6 additions & 67 deletions
@@ -25,9 +25,9 @@
 from ...loaders import PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
+from ..attention import AttentionMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
-from ..attention_processor import Attention, AttentionProcessor
+from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import (
     CombinedTimestepTextProjEmbeddings,
@@ -616,7 +616,9 @@ def forward(
         return hidden_states, encoder_hidden_states


-class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class HunyuanImageTransformer2DModel(
+    ModelMixin, ConfigMixin, AttentionMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin
+):
     r"""
     The Transformer model used in [HunyuanImage-2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).

@@ -667,10 +669,7 @@ class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         "HunyuanImagePatchEmbed",
         "HunyuanImageTokenRefiner",
     ]
-    _repeated_blocks = [
-        "HunyuanImageTransformerBlock",
-        "HunyuanImageSingleTransformerBlock",
-    ]
+    _repeated_blocks = ["HunyuanImageTransformerBlock", "HunyuanImageSingleTransformerBlock"]

     @register_to_config
     def __init__(
@@ -743,66 +742,6 @@ def __init__(

         self.gradient_checkpointing = False

-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
     def forward(
         self,
         hidden_states: torch.Tensor,