11import math
22from typing import Any , Dict , List , Optional , Tuple
33
4- import einops
54import torch
65import torch .nn as nn
76import torch .nn .functional as F
@@ -756,21 +755,26 @@ def expand_timesteps(self, timesteps, batch_size, device):
756755
757756 def unpatchify (self , x : torch .Tensor , img_sizes : List [Tuple [int , int ]], is_training : bool ) -> List [torch .Tensor ]:
758757 if is_training :
759- x = einops .rearrange (
760- x , "B S (p1 p2 C) -> B C S (p1 p2)" , p1 = self .config .patch_size , p2 = self .config .patch_size
758+ B , S , F = x .shape
759+ C = F // (self .config .patch_size * self .config .patch_size )
760+ x = (
761+ x .reshape (B , S , self .config .patch_size , self .config .patch_size , C )
762+ .permute (0 , 4 , 1 , 2 , 3 )
763+ .reshape (B , C , S , self .config .patch_size * self .config .patch_size )
761764 )
762765 else :
763766 x_arr = []
767+ p1 = self .config .patch_size
768+ p2 = self .config .patch_size
764769 for i , img_size in enumerate (img_sizes ):
765770 pH , pW = img_size
766- x_arr .append (
767- einops .rearrange (
768- x [i , : pH * pW ].reshape (1 , pH , pW , - 1 ),
769- "B H W (p1 p2 C) -> B C (H p1) (W p2)" ,
770- p1 = self .config .patch_size ,
771- p2 = self .config .patch_size ,
772- )
773- )
771+ t = x [i , : pH * pW ].reshape (1 , pH , pW , - 1 )
772+ F_token = t .shape [- 1 ]
773+ C = F_token // (p1 * p2 )
774+ t = t .reshape (1 , pH , pW , p1 , p2 , C )
775+ t = t .permute (0 , 5 , 1 , 3 , 2 , 4 )
776+ t = t .reshape (1 , C , pH * p1 , pW * p2 )
777+ x_arr .append (t )
774778 x = torch .cat (x_arr , dim = 0 )
775779 return x
776780
@@ -789,12 +793,14 @@ def patchify(self, x, max_seq, img_sizes=None):
789793 if img_sizes is not None :
790794 for i , img_size in enumerate (img_sizes ):
791795 x_masks [i , 0 : img_size [0 ] * img_size [1 ]] = 1
792- x = einops .rearrange (x , "B C S p -> B S (p C)" , p = pz2 , C = C )
796+ B , C , S , _ = x .shape
797+ x = x .permute (0 , 2 , 3 , 1 ).reshape (B , S , pz2 * C )
793798 elif isinstance (x , torch .Tensor ):
794- pH , pW = x .shape [- 2 ] // self .config .patch_size , x .shape [- 1 ] // self .config .patch_size
795- x = einops .rearrange (
796- x , "B C (H p1) (W p2) -> B (H W) (p1 p2 C)" , p1 = self .config .patch_size , p2 = self .config .patch_size , C = C
797- )
799+ B , C , Hp1 , Wp2 = x .shape
800+ pH , pW = Hp1 // self .config .patch_size , Wp2 // self .config .patch_size
801+ x = x .reshape (B , C , pH , self .config .patch_size , pW , self .config .patch_size )
802+ x = x .permute (0 , 2 , 4 , 3 , 5 , 1 )
803+ x = x .reshape (B , pH * pW , self .config .patch_size * self .config .patch_size * C )
798804 img_sizes = [[pH , pW ]] * B
799805 x_masks = None
800806 else :
0 commit comments