 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from functools import partial
-from typing import Any, Dict, Optional, Union
+from typing import Dict, Optional, Union
 
 import torch
 from torch import nn
     Attention,
     AttentionProcessor,
     AttnProcessor2_0,
-    SanaMultiscaleLinearAttention,
     SanaLinearAttnProcessor2_0,
 )
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
@@ -135,7 +133,7 @@ def __init__(
             mlp_ratio=mlp_ratio,
         )
 
-        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim ** 0.5)
+        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
 
     def forward(
         self,
@@ -152,7 +150,7 @@ def forward(
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
             self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
         ).chunk(6, dim=1)
-        
+
         # 2. Self Attention
         norm_hidden_states = self.norm1(hidden_states)
         norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
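
The hunk above covers the adaLN-style modulation in the transformer block: a learned (6, dim) table is added to the per-sample timestep embedding and split into shift/scale/gate triples for the attention and MLP branches, and the normalized hidden states are modulated as x * (1 + scale) + shift. A minimal, self-contained sketch of that pattern; the sizes and the plain LayerNorm are illustrative assumptions, not the model's exact modules:

import torch
from torch import nn

dim, batch_size, seq_len = 8, 2, 16                  # hypothetical sizes
scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
timestep = torch.randn(batch_size, 6 * dim)          # per-sample conditioning embedding
hidden_states = torch.randn(batch_size, seq_len, dim)

# Learned table + timestep embedding, split into six (batch, 1, dim) chunks.
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
    scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)

norm = nn.LayerNorm(dim, elementwise_affine=False)   # stand-in for the block's norm1
norm_hidden_states = norm(hidden_states) * (1 + scale_msa) + shift_msa
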
@@ -258,9 +256,7 @@ def __init__(
         )
 
         # 2. Caption Embedding
-        self.caption_projection = PixArtAlphaTextProjection(
-            in_features=caption_channels, hidden_size=inner_dim
-        )
+        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
         self.caption_norm = RMSNorm(inner_dim, eps=1e-5)
 
         # 3. Transformer blocks
@@ -285,7 +281,7 @@ def __init__(
 
         # 4. Output blocks
         self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
-        
+
         self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
 
@@ -401,12 +397,12 @@ def forward(
 
         encoder_hidden_states = self.caption_projection(encoder_hidden_states)
         encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-        
+
         encoder_hidden_states = self.caption_norm(encoder_hidden_states)
 
         # 2. Transformer blocks
         use_reentrant = is_torch_version("<=", "1.11.0")
-        
+
         def create_block_forward(block):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 return lambda *inputs: torch.utils.checkpoint.checkpoint(
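
This hunk also shows the gradient-checkpointing wrapper: use_reentrant is chosen from the installed torch version, and create_block_forward routes each block through torch.utils.checkpoint.checkpoint when checkpointing is enabled. A rough, self-contained sketch of that wrapper pattern; the TinyBlock module and the concrete sizes are hypothetical stand-ins for the real transformer block:

import torch
from torch import nn

class TinyBlock(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states):
        return self.proj(hidden_states)

gradient_checkpointing = True
use_reentrant = False  # the diff derives this from is_torch_version("<=", "1.11.0")

def create_block_forward(block):
    # During training, recompute the block's activations in the backward pass
    # instead of storing them, trading compute for memory.
    if torch.is_grad_enabled() and gradient_checkpointing:
        return lambda *inputs: torch.utils.checkpoint.checkpoint(
            block.forward, *inputs, use_reentrant=use_reentrant
        )
    return block.forward

block = TinyBlock()
x = torch.randn(2, 16, 8, requires_grad=True)
out = create_block_forward(block)(x)
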
@@ -430,16 +426,23 @@ def create_block_forward(block):
             self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
         ).chunk(2, dim=1)
         hidden_states = self.norm_out(hidden_states)
-        
+
         # 4. Modulation
         hidden_states = hidden_states * (1 + scale) + shift
         hidden_states = self.proj_out(hidden_states)
 
         # 5. Unpatchify
-        hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, self.config.patch_size, self.config.patch_size, -1)
+        hidden_states = hidden_states.reshape(
+            batch_size, post_patch_height, post_patch_width, self.config.patch_size, self.config.patch_size, -1
+        )
         hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
         output = hidden_states.reshape(
-            shape=(batch_size, -1, post_patch_height * self.config.patch_size, post_patch_width * self.config.patch_size)
+            shape=(
+                batch_size,
+                -1,
+                post_patch_height * self.config.patch_size,
+                post_patch_width * self.config.patch_size,
+            )
         )
 
         if not return_dict:
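
For reference, the reshape, permute, reshape sequence in the final hunk is the usual unpatchify step: per-patch channels are folded back into an image-shaped tensor. A small sketch with made-up sizes; the variable names mirror the diff, while the concrete numbers and standalone tensors are assumptions for illustration:

import torch

batch_size, patch_size, out_channels = 1, 2, 3
post_patch_height = post_patch_width = 4

# Shape after proj_out: (batch, num_patches, patch_size * patch_size * out_channels)
hidden_states = torch.randn(
    batch_size, post_patch_height * post_patch_width, patch_size * patch_size * out_channels
)

hidden_states = hidden_states.reshape(
    batch_size, post_patch_height, post_patch_width, patch_size, patch_size, -1
)
hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
output = hidden_states.reshape(
    batch_size, -1, post_patch_height * patch_size, post_patch_width * patch_size
)
assert output.shape == (batch_size, out_channels, 8, 8)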