@@ -452,7 +452,7 @@ def forward(
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: torch.FloatTensor = None,
         timestep: torch.LongTensor = None,
-        attention_kwargs: Optional[Dict[str, Any]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         height, width = hidden_states.shape[-2:]
@@ -465,18 +465,18 @@ def forward(
         encoder_hidden_states = torch.cat(
             [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1
         )
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
+        if joint_attention_kwargs is not None:
+            joint_attention_kwargs = joint_attention_kwargs.copy()
+            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
         else:
             lora_scale = 1.0
         if USE_PEFT_BACKEND:
             # weight the lora layers by setting `lora_scale` for each PEFT layer
             scale_lora_layers(self, lora_scale)
         else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                 logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                 )
         # MMDiT blocks.
         for index_block, block in enumerate(self.joint_transformer_blocks):
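Aside from the rename, note that the hunk copies `joint_attention_kwargs` before popping `"scale"`, so the caller's dictionary is never mutated between forward passes; with the PEFT backend enabled the popped value weights the LoRA layers, otherwise a warning is logged. A minimal sketch of that copy-then-pop pattern in isolation (the helper name `_pop_lora_scale` is made up for illustration and is not part of this PR):

```python
from typing import Any, Dict, Optional

def _pop_lora_scale(joint_attention_kwargs: Optional[Dict[str, Any]]) -> float:
    """Hypothetical helper mirroring the hunk above: extract the LoRA scale
    without mutating the dict the caller passed in."""
    if joint_attention_kwargs is not None:
        # Copy first so `pop` does not remove "scale" from the caller's dict.
        joint_attention_kwargs = joint_attention_kwargs.copy()
        return joint_attention_kwargs.pop("scale", 1.0)
    return 1.0

kwargs = {"scale": 0.8}
assert _pop_lora_scale(kwargs) == 0.8
assert kwargs == {"scale": 0.8}  # the original dict is left untouched
assert _pop_lora_scale(None) == 1.0  # defaults to a neutral scale
```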