Commit bb5ee5d

Warlord-Khameerabbasi authored and committed
Change attention_kwargs->joint_attention_kwargs
1 parent da1c686 commit bb5ee5d

File tree

2 files changed: +7 −7 lines


src/diffusers/loaders/lora_pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -2125,6 +2125,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):

     @classmethod
     @validate_hf_hub_args
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -2429,7 +2430,6 @@ def fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )

-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
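For reference, a minimal sketch of how the `lora_state_dict` classmethod shown above is typically invoked, assuming the AuraFlow pipeline exposes this mixin; the pipeline class usage and the repository id are illustrative assumptions, not part of this diff.

from diffusers import AuraFlowPipeline

# `lora_state_dict` is a classmethod provided by AuraFlowLoraLoaderMixin.
# The repo id below is a hypothetical LoRA checkpoint, used only for illustration.
state_dict = AuraFlowPipeline.lora_state_dict("your-username/aura-flow-lora")

# The returned dict maps parameter names to LoRA weight tensors.
print(f"{len(state_dict)} LoRA tensors loaded")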

src/diffusers/models/transformers/auraflow_transformer_2d.py

Lines changed: 6 additions & 6 deletions
@@ -452,7 +452,7 @@ def forward(
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: torch.FloatTensor = None,
         timestep: torch.LongTensor = None,
-        attention_kwargs: Optional[Dict[str, Any]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         height, width = hidden_states.shape[-2:]
@@ -465,18 +465,18 @@ def forward(
         encoder_hidden_states = torch.cat(
             [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1
         )
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
+        if joint_attention_kwargs is not None:
+            joint_attention_kwargs = joint_attention_kwargs.copy()
+            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
         else:
             lora_scale = 1.0
         if USE_PEFT_BACKEND:
             # weight the lora layers by setting `lora_scale` for each PEFT layer
             scale_lora_layers(self, lora_scale)
         else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                 logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                 )
         # MMDiT blocks.
         for index_block, block in enumerate(self.joint_transformer_blocks):
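A minimal sketch of the renamed keyword at the call site of the forward pass shown above; the checkpoint id, batch size, and token count are assumptions for illustration, not part of this diff.

import torch
from diffusers import AuraFlowTransformer2DModel

# Assumed checkpoint location; any AuraFlow transformer checkpoint works here.
transformer = AuraFlowTransformer2DModel.from_pretrained("fal/AuraFlow", subfolder="transformer")

# Shapes are derived from the model config; the batch size and the
# 256-token text sequence length are placeholders.
cfg = transformer.config
hidden_states = torch.randn(1, cfg.in_channels, cfg.sample_size, cfg.sample_size)
encoder_hidden_states = torch.randn(1, 256, cfg.joint_attention_dim)
timestep = torch.tensor([999])

output = transformer(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    timestep=timestep,
    # Renamed in this commit from `attention_kwargs`. With the PEFT backend
    # enabled, "scale" rescales the LoRA layers for this forward pass only.
    joint_attention_kwargs={"scale": 0.8},
)
print(output.sample.shape)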
