Commit a73cb39

refactor chroma as well

1 parent d9c1683

File tree: 1 file changed, +15 −115 lines

src/diffusers/models/transformers/transformer_chroma.py

Lines changed: 15 additions & 115 deletions
```diff
@@ -24,19 +24,13 @@
 from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import (
-    Attention,
-    AttentionProcessor,
-    FluxAttnProcessor2_0,
-    FluxAttnProcessor2_0_NPU,
-    FusedFluxAttnProcessor2_0,
-)
+from ..attention import AttentionMixin, FeedForward
 from ..cache_utils import CacheMixin
 from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
+from .transformer_flux import FluxAttention, FluxAttnProcessor
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
```
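The net effect on the import surface: the generic `Attention` module and the versioned Flux processors no longer come from `attention_processor`; the shared `AttentionMixin` comes from `..attention`, and the Flux attention classes are reused directly from the sibling `transformer_flux` module. In absolute terms, assuming the relative in-repo paths map onto the installed package layout as usual:

```python
# Equivalent absolute imports after this commit (paths inferred from the
# relative imports in the diff above):
from diffusers.models.attention import AttentionMixin, FeedForward
from diffusers.models.transformers.transformer_flux import FluxAttention, FluxAttnProcessor
```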
```diff
@@ -223,18 +217,19 @@ def __init__(
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
         if is_torch_npu_available():
+            from ..attention_processor import FluxAttnProcessor2_0_NPU
+
             deprecation_message = (
                 "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
                 "should be set explicitly using the `set_attn_processor` method."
             )
             deprecate("npu_processor", "0.34.0", deprecation_message)
             processor = FluxAttnProcessor2_0_NPU()
         else:
-            processor = FluxAttnProcessor2_0()
+            processor = FluxAttnProcessor()
 
-        self.attn = Attention(
+        self.attn = FluxAttention(
             query_dim=dim,
-            cross_attention_dim=None,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
```
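Note that `FluxAttnProcessor2_0_NPU` is now imported lazily, only on the deprecated NPU path, and the deprecation message tells users to set processors explicitly via `set_attn_processor`. A minimal sketch of doing that (assuming the top-level `ChromaTransformer2DModel` export; the checkpoint path is a placeholder, not a real repo id):

```python
import torch
from diffusers import ChromaTransformer2DModel
from diffusers.models.transformers.transformer_flux import FluxAttnProcessor

# Placeholder path; substitute a real Chroma checkpoint.
transformer = ChromaTransformer2DModel.from_pretrained(
    "path/to/chroma", subfolder="transformer", torch_dtype=torch.bfloat16
)
# Passing a single processor instance applies it to every attention layer.
transformer.set_attn_processor(FluxAttnProcessor())
```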
```diff
@@ -292,16 +287,15 @@ def __init__(
         self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
         self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)
 
-        self.attn = Attention(
+        self.attn = FluxAttention(
             query_dim=dim,
-            cross_attention_dim=None,
             added_kv_proj_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             context_pre_only=False,
             bias=True,
-            processor=FluxAttnProcessor2_0(),
+            processor=FluxAttnProcessor(),
             qk_norm=qk_norm,
             eps=eps,
         )
```
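In both blocks the dropped `cross_attention_dim=None` was already the default for the old generic `Attention` class and has no counterpart on `FluxAttention`, which is self-attention only. What distinguishes this dual-stream block from the single-stream one above is `added_kv_proj_dim=dim`, which gives the text stream its own projections. A standalone construction sketch with illustrative values (the real values come from the model config; `qk_norm` and `eps` defaults are assumptions here):

```python
from diffusers.models.transformers.transformer_flux import FluxAttention, FluxAttnProcessor

dim, num_heads, head_dim = 3072, 24, 128  # illustrative values only

attn = FluxAttention(
    query_dim=dim,
    added_kv_proj_dim=dim,   # extra Q/K/V projections for the text stream
    dim_head=head_dim,
    heads=num_heads,
    out_dim=dim,
    context_pre_only=False,  # text stream also gets its own output projection
    bias=True,
    processor=FluxAttnProcessor(),
    qk_norm="rms_norm",
    eps=1e-6,
)
```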
```diff
@@ -376,7 +370,13 @@ def forward(
 
 
 class ChromaTransformer2DModel(
-    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
+    ModelMixin,
+    ConfigMixin,
+    PeftAdapterMixin,
+    FromOriginalModelMixin,
+    FluxTransformer2DLoadersMixin,
+    CacheMixin,
+    AttentionMixin,
 ):
     """
     The Transformer model introduced in Flux, modified for Chroma.
```
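Adding `AttentionMixin` to the bases is what pays for the deletions in the next hunk: the processor plumbing that was previously copy-pasted into each model class (with `# Copied from ...` markers) is now inherited. Roughly what the mixin centralizes, as a simplified sketch of the recursive pattern, not the exact diffusers implementation:

```python
import torch

class AttentionMixinSketch:
    """Simplified sketch; meant to be mixed into a torch.nn.Module subclass."""

    @property
    def attn_processors(self):
        # Walk the module tree and collect every attention processor,
        # keyed by the qualified name of the module that owns it.
        processors = {}

        def collect(name: str, module: torch.nn.Module):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
            for sub_name, child in module.named_children():
                collect(f"{name}.{sub_name}", child)

        for name, module in self.named_children():
            collect(name, module)
        return processors

    def set_attn_processor(self, processor):
        # Accepts a single processor (applied everywhere) or a dict keyed
        # the same way attn_processors is.
        def assign(name: str, module: torch.nn.Module):
            if hasattr(module, "set_processor"):
                if isinstance(processor, dict):
                    module.set_processor(processor.pop(f"{name}.processor"))
                else:
                    module.set_processor(processor)
            for sub_name, child in module.named_children():
                assign(f"{name}.{sub_name}", child)

        for name, module in self.named_children():
            assign(name, module)
```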
```diff
@@ -475,106 +475,6 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
-    def fuse_qkv_projections(self):
-        """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-        are fused. For cross-attention modules, key and value projection matrices are fused.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-        """
-        self.original_attn_processors = None
-
-        for _, attn_processor in self.attn_processors.items():
-            if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-        self.original_attn_processors = self.attn_processors
-
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
-
-        self.set_attn_processor(FusedFluxAttnProcessor2_0())
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    def unfuse_qkv_projections(self):
-        """Disables the fused QKV projection if enabled.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-        """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
```
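None of the deleted functionality disappears: `attn_processors`, `set_attn_processor`, and the experimental fuse/unfuse helpers are presumably now inherited from `AttentionMixin` rather than duplicated per model. A usage sketch (assuming the class constructs with its config defaults; a real checkpoint via `from_pretrained` works the same way):

```python
from diffusers import ChromaTransformer2DModel
from diffusers.models.transformers.transformer_flux import FluxAttnProcessor

model = ChromaTransformer2DModel()  # default config, for illustration

print(len(model.attn_processors))   # processors keyed by weight name, via the mixin
model.set_attn_processor(FluxAttnProcessor())
model.fuse_qkv_projections()        # still experimental, now mixin-provided
model.unfuse_qkv_projections()
```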
