@@ -160,7 +160,12 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
         self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
         self.ff = AuraFlowFeedForward(dim, dim * 4)
 
-    def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, attention_kwargs: Optional[Dict[str, Any]] = None):
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+    ):
         residual = hidden_states
         attention_kwargs = attention_kwargs or {}
 
@@ -224,7 +229,11 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
         self.ff_context = AuraFlowFeedForward(dim, dim * 4)
 
     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, attention_kwargs: Optional[Dict[str, Any]] = None,
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
         residual = hidden_states
         residual_context = encoder_hidden_states
@@ -238,7 +247,9 @@ def forward(
 
         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, **attention_kwargs,
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **attention_kwargs,
         )
 
         # Process attention outputs for the `hidden_states`.
@@ -492,7 +503,10 @@ def forward(
 
             else:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, attention_kwargs=attention_kwargs,
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    attention_kwargs=attention_kwargs,
                 )
 
         # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
@@ -509,7 +523,9 @@ def forward(
                 )
 
             else:
-                combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs)
+                combined_hidden_states = block(
+                    hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs
+                )
 
         hidden_states = combined_hidden_states[:, encoder_seq_len:]
 
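Taken together, these hunks reformat the `forward` signatures and call sites so that `attention_kwargs` is threaded from the transformer's `forward` down into each block, where it is expanded as `**attention_kwargs` into `self.attn(...)`. Below is a minimal sketch of calling a joint block directly with the new argument, assuming a diffusers build that includes this change; the tensor shapes, the import path, and the empty kwargs dict are illustrative assumptions rather than part of the diff.

```python
import torch

# Assumed import path for the block touched by this diff.
from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowJointTransformerBlock

# Small, arbitrary sizes for illustration; the model configures
# dim = num_attention_heads * attention_head_dim.
block = AuraFlowJointTransformerBlock(dim=64, num_attention_heads=4, attention_head_dim=16)

hidden_states = torch.randn(1, 16, 64)         # image tokens: (batch, seq_len, dim)
encoder_hidden_states = torch.randn(1, 8, 64)  # text tokens: (batch, text_seq_len, dim)
temb = torch.randn(1, 64)                      # timestep embedding: (batch, dim)

# attention_kwargs is forwarded as **attention_kwargs to the attention processor,
# so any keys must match what that processor accepts; an empty dict is always safe.
encoder_hidden_states, hidden_states = block(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    temb=temb,
    attention_kwargs={},
)
```

Keeping `attention_kwargs` optional (defaulting to `None` and coerced to `{}` inside the block, as in the first hunk) preserves backward compatibility for callers that do not pass it.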