@@ -1081,25 +1081,15 @@ def forward(
A tuple of tensors that if specified are added to the residuals of down unet blocks.
mid_block_additional_residual: (`torch.Tensor`, *optional*):
A tensor that if specified is added to the residual of the middle unet block.
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
- cross_attention_kwargs (`dict`, *optional*):
- A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
- added_cond_kwargs: (`dict`, *optional*):
- A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
- are passed along to the UNet blocks.
- down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
- additional residuals to be added to UNet long skip connections from down blocks to up blocks for
- example from ControlNet side model(s)
- mid_block_additional_residual (`torch.Tensor`, *optional*):
- additional residual to be added to UNet mid block output, for example from ControlNet side model
- down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
- additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
Returns:
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
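For context, the first hunk's docstring changes distinguish the residuals injected on the long skip connections (ControlNet-style) from the new `down_intrablock_additional_residuals`, which are added inside the down blocks (T2I-Adapter-style). Below is a minimal sketch of that distinction using plain tensors; the variable names and shapes are illustrative only, not the diffusers UNet internals.

```python
import torch

# Illustrative activation only; the shape is arbitrary for this sketch.
hidden_states = torch.randn(1, 320, 64, 64)

# ControlNet-style (`down_block_additional_residuals` / `mid_block_additional_residual`):
# residuals are added to the skip-connection tensors that the down blocks hand
# to the up blocks, and to the mid-block output.
skip_connection = hidden_states.clone()
controlnet_residual = torch.randn_like(skip_connection)
skip_connection = skip_connection + controlnet_residual  # consumed later by the up blocks

# T2I-Adapter-style (`down_intrablock_additional_residuals`, documented in this diff):
# residuals are added to the activations *inside* the down blocks, so they also
# flow through the remainder of the down path, not only through the long skips.
adapter_residual = torch.randn_like(hidden_states)
hidden_states = hidden_states + adapter_residual
```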
@@ -1185,7 +1175,13 @@ def forward(
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
# 3. down
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+ # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
+ # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
+ if cross_attention_kwargs is not None:
+ lora_scale = cross_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
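The switch from `.get("scale", 1.0)` to `.pop("scale", 1.0)` matters because `cross_attention_kwargs` is forwarded to the inner blocks and attention processors; leaving `scale` in the dict is what triggers the deprecation warning mentioned in the added comment. A self-contained sketch of the behavioral difference follows; the helper function is illustrative, not the actual diffusers call path.

```python
def run_attention(cross_attention_kwargs):
    # Stand-in for the inner blocks: a leftover "scale" entry here is what
    # would raise the deprecation warning referenced in the diff comment.
    assert "scale" not in cross_attention_kwargs, "scale leaked into the attention call"

cross_attention_kwargs = {"scale": 0.7}

# .get() would read the value but leave the key in place, so it would still be
# propagated to the inner blocks. .pop() reads it *and* removes it, so only the
# remaining kwargs travel onward.
lora_scale = cross_attention_kwargs.pop("scale", 1.0)

run_attention(cross_attention_kwargs)  # passes: "scale" has been removed
print(lora_scale)  # 0.7, used to weight the PEFT/LoRA layers via scale_lora_layers
```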