24 | 24 | from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers |
25 | 25 | from ...utils.import_utils import is_torch_npu_available |
26 | 26 | from ...utils.torch_utils import maybe_allow_in_graph |
27 | | -from ..attention import FeedForward |
28 | | -from ..attention_processor import ( |
29 | | - Attention, |
30 | | - AttentionProcessor, |
31 | | - FluxAttnProcessor2_0, |
32 | | - FluxAttnProcessor2_0_NPU, |
33 | | - FusedFluxAttnProcessor2_0, |
34 | | -) |
| 27 | +from ..attention import AttentionMixin, FeedForward |
35 | 28 | from ..cache_utils import CacheMixin |
36 | 29 | from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding |
37 | 30 | from ..modeling_outputs import Transformer2DModelOutput |
38 | 31 | from ..modeling_utils import ModelMixin |
39 | 32 | from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm |
| 33 | +from .transformer_flux import FluxAttention, FluxAttnProcessor |
40 | 34 |
41 | 35 |
42 | 36 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
@@ -223,18 +217,19 @@ def __init__( |
223 | 217 | self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) |
224 | 218 |
225 | 219 | if is_torch_npu_available(): |
| 220 | + from ..attention_processor import FluxAttnProcessor2_0_NPU |
| 221 | + |
226 | 222 | deprecation_message = ( |
227 | 223 | "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors " |
228 | 224 | "should be set explicitly using the `set_attn_processor` method." |
229 | 225 | ) |
230 | 226 | deprecate("npu_processor", "0.34.0", deprecation_message) |
231 | 227 | processor = FluxAttnProcessor2_0_NPU() |
232 | 228 | else: |
233 | | - processor = FluxAttnProcessor2_0() |
| 229 | + processor = FluxAttnProcessor() |
234 | 230 |
235 | | - self.attn = Attention( |
| 231 | + self.attn = FluxAttention( |
236 | 232 | query_dim=dim, |
237 | | - cross_attention_dim=None, |
238 | 233 | dim_head=attention_head_dim, |
239 | 234 | heads=num_attention_heads, |
240 | 235 | out_dim=dim, |
@@ -292,16 +287,15 @@ def __init__( |
292 | 287 | self.norm1 = ChromaAdaLayerNormZeroPruned(dim) |
293 | 288 | self.norm1_context = ChromaAdaLayerNormZeroPruned(dim) |
294 | 289 |
295 | | - self.attn = Attention( |
| 290 | + self.attn = FluxAttention( |
296 | 291 | query_dim=dim, |
297 | | - cross_attention_dim=None, |
298 | 292 | added_kv_proj_dim=dim, |
299 | 293 | dim_head=attention_head_dim, |
300 | 294 | heads=num_attention_heads, |
301 | 295 | out_dim=dim, |
302 | 296 | context_pre_only=False, |
303 | 297 | bias=True, |
304 | | - processor=FluxAttnProcessor2_0(), |
| 298 | + processor=FluxAttnProcessor(), |
305 | 299 | qk_norm=qk_norm, |
306 | 300 | eps=eps, |
307 | 301 | ) |
@@ -376,7 +370,13 @@ def forward( |
376 | 370 |
377 | 371 |
378 | 372 | class ChromaTransformer2DModel( |
379 | | - ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin |
| 373 | + ModelMixin, |
| 374 | + ConfigMixin, |
| 375 | + PeftAdapterMixin, |
| 376 | + FromOriginalModelMixin, |
| 377 | + FluxTransformer2DLoadersMixin, |
| 378 | + CacheMixin, |
| 379 | + AttentionMixin, |
380 | 380 | ): |
381 | 381 | """ |
382 | 382 | The Transformer model introduced in Flux, modified for Chroma. |
@@ -475,106 +475,6 @@ def __init__( |
475 | 475 |
476 | 476 | self.gradient_checkpointing = False |
477 | 477 |
478 | | - @property |
479 | | - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors |
480 | | - def attn_processors(self) -> Dict[str, AttentionProcessor]: |
481 | | - r""" |
482 | | - Returns: |
483 | | - `dict` of attention processors: A dictionary containing all attention processors used in the model with |
484 | | - indexed by its weight name. |
485 | | - """ |
486 | | - # set recursively |
487 | | - processors = {} |
488 | | - |
489 | | - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): |
490 | | - if hasattr(module, "get_processor"): |
491 | | - processors[f"{name}.processor"] = module.get_processor() |
492 | | - |
493 | | - for sub_name, child in module.named_children(): |
494 | | - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) |
495 | | - |
496 | | - return processors |
497 | | - |
498 | | - for name, module in self.named_children(): |
499 | | - fn_recursive_add_processors(name, module, processors) |
500 | | - |
501 | | - return processors |
502 | | - |
503 | | - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor |
504 | | - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): |
505 | | - r""" |
506 | | - Sets the attention processor to use to compute attention. |
507 | | -
508 | | - Parameters: |
509 | | - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): |
510 | | - The instantiated processor class or a dictionary of processor classes that will be set as the processor |
511 | | - for **all** `Attention` layers. |
512 | | -
513 | | - If `processor` is a dict, the key needs to define the path to the corresponding cross attention |
514 | | - processor. This is strongly recommended when setting trainable attention processors. |
515 | | -
516 | | - """ |
517 | | - count = len(self.attn_processors.keys()) |
518 | | - |
519 | | - if isinstance(processor, dict) and len(processor) != count: |
520 | | - raise ValueError( |
521 | | - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" |
522 | | - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." |
523 | | - ) |
524 | | - |
525 | | - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): |
526 | | - if hasattr(module, "set_processor"): |
527 | | - if not isinstance(processor, dict): |
528 | | - module.set_processor(processor) |
529 | | - else: |
530 | | - module.set_processor(processor.pop(f"{name}.processor")) |
531 | | - |
532 | | - for sub_name, child in module.named_children(): |
533 | | - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) |
534 | | - |
535 | | - for name, module in self.named_children(): |
536 | | - fn_recursive_attn_processor(name, module, processor) |
537 | | - |
538 | | - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 |
539 | | - def fuse_qkv_projections(self): |
540 | | - """ |
541 | | - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) |
542 | | - are fused. For cross-attention modules, key and value projection matrices are fused. |
543 | | -
544 | | - <Tip warning={true}> |
545 | | -
546 | | - This API is 🧪 experimental. |
547 | | -
548 | | - </Tip> |
549 | | - """ |
550 | | - self.original_attn_processors = None |
551 | | - |
552 | | - for _, attn_processor in self.attn_processors.items(): |
553 | | - if "Added" in str(attn_processor.__class__.__name__): |
554 | | - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") |
555 | | - |
556 | | - self.original_attn_processors = self.attn_processors |
557 | | - |
558 | | - for module in self.modules(): |
559 | | - if isinstance(module, Attention): |
560 | | - module.fuse_projections(fuse=True) |
561 | | - |
562 | | - self.set_attn_processor(FusedFluxAttnProcessor2_0()) |
563 | | - |
564 | | - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections |
565 | | - def unfuse_qkv_projections(self): |
566 | | - """Disables the fused QKV projection if enabled. |
567 | | -
568 | | - <Tip warning={true}> |
569 | | -
570 | | - This API is 🧪 experimental. |
571 | | -
572 | | - </Tip> |
573 | | -
574 | | - """ |
575 | | - if self.original_attn_processors is not None: |
576 | | - self.set_attn_processor(self.original_attn_processors) |
577 | | - |
578 | 478 | def forward( |
579 | 479 | self, |
580 | 480 | hidden_states: torch.Tensor, |
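
The removed `attn_processors` property and `set_attn_processor` method are now inherited from `AttentionMixin`, which the diff adds to the class bases. Below is a minimal usage sketch, assuming the `diffusers` version that contains this change; the checkpoint id is a placeholder, and the absolute module path for `FluxAttnProcessor` is inferred from the relative import added in the diff.

```python
import torch
from diffusers import ChromaTransformer2DModel
from diffusers.models.transformers.transformer_flux import FluxAttnProcessor

# Placeholder checkpoint id -- substitute a real Chroma checkpoint repository.
transformer = ChromaTransformer2DModel.from_pretrained(
    "your-org/your-chroma-checkpoint",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

# Both calls now come from AttentionMixin rather than the removed per-model copies.
print({type(p).__name__ for p in transformer.attn_processors.values()})  # inspect current processors
transformer.set_attn_processor(FluxAttnProcessor())  # one processor instance applied to all attention layers
```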
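The experimental fused-QKV helpers removed above are likewise expected to be provided by `AttentionMixin`. Continuing the sketch above, and assuming the mixin exposes the same `fuse_qkv_projections` / `unfuse_qkv_projections` API that the removed per-model copies implemented:

```python
# Fuse the query/key/value projection matrices (experimental, per the removed docstring).
transformer.fuse_qkv_projections()

# ... run inference with the fused projections ...

# Restore the original, unfused attention processors.
transformer.unfuse_qkv_projections()
```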