Commit 0293703

Warlord-Khameerabbasi authored and committed
Add comments, remove qkv fusion
1 parent 49186b8 commit 0293703

4 files changed (+17 −13 lines changed)

src/diffusers/loaders/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -89,10 +89,10 @@ def text_encoder_attn_modules(text_encoder):
         AmusedLoraLoaderMixin,
         CogVideoXLoraLoaderMixin,
         FluxLoraLoaderMixin,
+        AuraFlowLoraLoaderMixin,
         LoraLoaderMixin,
         Mochi1LoraLoaderMixin,
         SD3LoraLoaderMixin,
-        AuraFlowLoraLoaderMixin,
         StableDiffusionLoraLoaderMixin,
         StableDiffusionXLLoraLoaderMixin,
     )
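
This hunk only reorders the re-export of `AuraFlowLoraLoaderMixin` inside `diffusers.loaders`; nothing about the class itself changes. A minimal sketch of what the re-export gives downstream code, assuming a diffusers build that contains this branch and that `AuraFlowPipeline` already inherits the mixin (the `assert` and `print` are illustration only):

```python
from diffusers import AuraFlowPipeline
from diffusers.loaders import AuraFlowLoraLoaderMixin  # re-export touched by this hunk

# The pipeline gets its LoRA API by inheriting the mixin, and the mixin
# only targets the transformer (no text-encoder LoRA for AuraFlow).
assert issubclass(AuraFlowPipeline, AuraFlowLoraLoaderMixin)
print(AuraFlowLoraLoaderMixin._lora_loadable_modules)  # ['transformer']
```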

src/diffusers/loaders/lora_pipeline.py

Lines changed: 6 additions & 1 deletion
@@ -1649,9 +1649,9 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
 
     _lora_loadable_modules = ["transformer"]
     transformer_name = TRANSFORMER_NAME
-    text_encoder_name = TEXT_ENCODER_NAME
 
     @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
     @validate_hf_hub_args
     def lora_state_dict(
         cls,
@@ -1742,6 +1742,7 @@ def lora_state_dict(
 
         return state_dict
 
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_weights
     def load_lora_weights(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
     ):
@@ -1788,6 +1789,7 @@ def load_lora_weights(
 
 
     @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
     def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1866,6 +1868,7 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None,
         # Unsafe code />
 
     @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.save_lora_weights
     def save_lora_weights(
         cls,
         save_directory: Union[str, os.PathLike],
@@ -1913,6 +1916,7 @@ def save_lora_weights(
             safe_serialization=safe_serialization,
         )
 
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -1956,6 +1960,7 @@ def fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )
 
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
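
The hunks in this file add `# Copied from ... SD3LoraLoaderMixin.*` markers, which the diffusers `make fix-copies` tooling uses to keep these methods in sync with their SD3 originals, and drop the unused `text_encoder_name` attribute because AuraFlow LoRAs only target the transformer. A hedged usage sketch of the resulting mixin API on the pipeline; the LoRA repository id is hypothetical and `fal/AuraFlow` is assumed as the base checkpoint:

```python
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

# load_lora_weights / fuse_lora / unfuse_lora come from AuraFlowLoraLoaderMixin.
pipe.load_lora_weights("some-user/auraflow-lora", adapter_name="style")  # hypothetical repo id

# Fuse the adapter into the transformer weights for adapter-free inference,
# then reverse the fusion to restore the original weights.
pipe.fuse_lora(components=["transformer"], lora_scale=1.0)
image = pipe("a watercolor fox in a snowy forest", num_inference_steps=25).images[0]
pipe.unfuse_lora(components=["transformer"])
```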

src/diffusers/models/transformers/auraflow_transformer_2d.py

Lines changed: 9 additions & 10 deletions
@@ -20,7 +20,8 @@
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention_processor import (
     Attention,
@@ -32,8 +33,6 @@
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormZero, FP32LayerNorm
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
-from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -344,8 +343,8 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-    @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    @property
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
         Returns:
@@ -453,7 +452,7 @@ def forward(
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: torch.FloatTensor = None,
         timestep: torch.LongTensor = None,
-        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         height, width = hidden_states.shape[-2:]
@@ -466,18 +465,18 @@ def forward(
         encoder_hidden_states = torch.cat(
             [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1
         )
-        if joint_attention_kwargs is not None:
-            joint_attention_kwargs = joint_attention_kwargs.copy()
-            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
         else:
             lora_scale = 1.0
         if USE_PEFT_BACKEND:
             # weight the lora layers by setting `lora_scale` for each PEFT layer
             scale_lora_layers(self, lora_scale)
         else:
-            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                 logger.warning(
-                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                 )
         # MMDiT blocks.
         for index_block, block in enumerate(self.joint_transformer_blocks):
src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py

Lines changed: 1 addition & 1 deletion
@@ -18,13 +18,13 @@
 from transformers import T5Tokenizer, UMT5EncoderModel
 
 from ...image_processor import VaeImageProcessor
+from ...loaders import AuraFlowLoraLoaderMixin
 from ...models import AuraFlowTransformer2DModel, AutoencoderKL
 from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-from ...loaders import AuraFlowLoraLoaderMixin
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
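
This last hunk only moves the `AuraFlowLoraLoaderMixin` import into sorted position; the pipeline class defined later in this file is what actually inherits it. A short hedged sketch of inspecting a checkpoint through the inherited classmethod; the LoRA repository id is hypothetical:

```python
from diffusers import AuraFlowPipeline

# lora_state_dict is the classmethod copied from SD3LoraLoaderMixin: it fetches the
# checkpoint and returns a flat {parameter name: tensor} mapping without loading it.
state_dict = AuraFlowPipeline.lora_state_dict("some-user/auraflow-lora")  # hypothetical repo id
print(len(state_dict), sorted(state_dict)[:2])  # keys are expected to carry a "transformer." prefix
```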
