
Commit 1ab57b6

sayakpaul and yiyixuxu committed
[LoRA] pop the LoRA scale so that it doesn't get propagated to the weeds (#7338)
* pop scale from the top-level unet instead of getting it.
* improve readability.
* Apply suggestions from code review
  Co-authored-by: YiYi Xu <[email protected]>
* fix a little bit.

---------

Co-authored-by: YiYi Xu <[email protected]>
1 parent cfa7c0a commit 1ab57b6

File tree

1 file changed (+9, -13 lines)

src/diffusers/models/unets/unet_2d_condition.py

Lines changed: 9 additions & 13 deletions
@@ -1081,25 +1081,15 @@ def forward(
                 A tuple of tensors that if specified are added to the residuals of down unet blocks.
             mid_block_additional_residual: (`torch.Tensor`, *optional*):
                 A tensor that if specified is added to the residual of the middle unet block.
+            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+                additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
             encoder_attention_mask (`torch.Tensor`):
                 A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                 `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                 which adds large negative values to the attention scores corresponding to "discard" tokens.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                 tuple.
-            cross_attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
-            added_cond_kwargs: (`dict`, *optional*):
-                A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
-                are passed along to the UNet blocks.
-            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
-                additional residuals to be added to UNet long skip connections from down blocks to up blocks for
-                example from ControlNet side model(s)
-            mid_block_additional_residual (`torch.Tensor`, *optional*):
-                additional residual to be added to UNet mid block output, for example from ControlNet side model
-            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
-                additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
 
         Returns:
             [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
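Aside: the `encoder_attention_mask` docstring kept above mentions a mask-to-bias conversion. A minimal sketch of what that typically looks like is given below; the helper name `mask_to_bias` and the -10000.0 constant are illustrative assumptions, not necessarily the exact internals of `unet_2d_condition.py`.

    import torch

    def mask_to_bias(encoder_attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        # keep (True) -> 0.0, discard (False) -> large negative value added to the attention scores
        bias = (1 - encoder_attention_mask.to(dtype)) * -10000.0
        # broadcast over heads/queries: (batch, seq_len) -> (batch, 1, seq_len)
        return bias.unsqueeze(1)

    # example: batch of 2 prompts, 4 tokens each, last token of the second prompt discarded
    mask = torch.tensor([[True, True, True, True], [True, True, True, False]])
    bias = mask_to_bias(mask, torch.float32)  # 0.0 where kept, -10000.0 where discarded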
@@ -1185,7 +1175,13 @@ def forward(
             cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
 
         # 3. down
-        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+        # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
+        # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
+        if cross_attention_kwargs is not None:
+            lora_scale = cross_attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
         if USE_PEFT_BACKEND:
             # weight the lora layers by setting `lora_scale` for each PEFT layer
             scale_lora_layers(self, lora_scale)
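To see why popping matters: `scale` is only consumed at the top of `forward` (to drive `scale_lora_layers` on the PEFT backend), so leaving it inside `cross_attention_kwargs` via `.get()` means every nested block and attention processor also receives a `scale` entry it no longer expects. A standalone sketch of the difference is below; `forward_top_level` and `inner_block` are hypothetical stand-ins, not the diffusers internals.

    # minimal sketch, assuming a kwargs dict that is handed down through nested blocks
    def inner_block(cross_attention_kwargs):
        # with .get() at the top level, "scale" would still be present here and every
        # downstream processor would have to warn about or silently ignore it
        assert "scale" not in cross_attention_kwargs

    def forward_top_level(cross_attention_kwargs=None):
        # consume `scale` once at the entry point so the nested blocks never see it
        if cross_attention_kwargs is not None:
            lora_scale = cross_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0
        print(f"applying lora_scale={lora_scale}")
        inner_block(cross_attention_kwargs or {})

    forward_top_level({"scale": 0.8, "other_kwarg": 123})  # other entries still propagate

In user code this corresponds to passing `cross_attention_kwargs={"scale": 0.8}` to a pipeline call: after this change the `scale` entry is consumed once at the UNet entry point instead of trickling down into the internal blocks and triggering deprecation warnings there.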

0 commit comments
