@@ -90,7 +90,7 @@ def __init__(
             nn.SiLU(),
             nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)),
         )
-        
+
     @staticmethod
     def _pad_temporal_dim(hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2)
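Note: `_pad_temporal_dim` pads the front of the time axis (dim=2) by replicating the first frame, so the temporal convolutions that follow still see a full window at the clip start (only this first cat is visible in the hunk). A minimal standalone sketch of that padding, with tensor shapes invented for the demo:

    import torch

    x = torch.randn(1, 4, 8, 16, 16)              # (batch, channels, frames, height, width)
    padded = torch.cat((x[:, :, 0:1], x), dim=2)  # prepend a copy of frame 0
    print(padded.shape)                           # torch.Size([1, 4, 9, 16, 16])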
@@ -118,10 +118,10 @@ def forward(self, hidden_states: torch.Tensor, batch_size: int) -> torch.Tensor:
 
         hidden_states = self._pad_temporal_dim(hidden_states)
         hidden_states = self.conv2(hidden_states)
-        
+
         hidden_states = self._pad_temporal_dim(hidden_states)
         hidden_states = self.conv3(hidden_states)
-        
+
         hidden_states = self._pad_temporal_dim(hidden_states)
         hidden_states = self.conv4(hidden_states)
 
@@ -200,7 +200,7 @@ def __init__(
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size = hidden_states.shape[0]
-        
+
         hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
 
         for resnet, temp_conv in zip(self.resnets, self.temp_convs):
@@ -213,7 +213,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)
-            
+
         hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
         return hidden_states
 
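Note: these blocks share one layout convention: time is folded into the batch dimension while the per-frame 2D resnets and samplers run, then unfolded afterwards. A minimal sketch of the round trip (shapes invented):

    import torch

    x = torch.randn(2, 8, 5, 32, 32)             # (batch, channels, frames, height, width)
    batch_size = x.shape[0]
    x = x.permute(0, 2, 1, 3, 4).flatten(0, 1)   # -> (batch * frames, channels, height, width)
    # ...per-frame 2D resnets / downsamplers would run here...
    x = x.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
    print(x.shape)                               # torch.Size([2, 8, 5, 32, 32])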
@@ -282,7 +282,7 @@ def __init__(
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size = hidden_states.shape[0]
-        
+
         hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
 
         for resnet, temp_conv in zip(self.resnets, self.temp_convs):
@@ -295,7 +295,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
                 hidden_states = upsampler(hidden_states)
-            
+
         hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
         return hidden_states
 
@@ -399,7 +399,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
         hidden_states = self.resnets[0](hidden_states, temb=None)
-        
+
         hidden_states = self.temp_convs[0](hidden_states, batch_size=batch_size)
 
         for attn, resnet, temp_conv in zip(self.attentions, self.resnets[1:], self.temp_convs[1:]):
@@ -532,15 +532,15 @@ def custom_forward(*inputs):
         sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
         sample = self.conv_norm_out(sample)
         sample = self.conv_act(sample)
-        
+
         sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
         residual = sample
         sample = self.temp_conv_out(sample)
         sample = sample + residual
-        
+
         sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
         sample = self.conv_out(sample)
-        
+
         sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
         return sample
 
@@ -674,15 +674,15 @@ def custom_forward(*inputs):
         sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
         sample = self.conv_norm_out(sample)
         sample = self.conv_act(sample)
-        
+
         sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
         residual = sample
         sample = self.temp_conv_out(sample)
         sample = sample + residual
 
         sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
         sample = self.conv_out(sample)
-        
+
         sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
         return sample
 
@@ -804,7 +804,7 @@ def __init__(
         chunk_len = 24
         t_over = 8
         tile_overlap = (120, 80)
-        
+
         self.latent_chunk_len = chunk_len // 4
         self.latent_t_over = t_over // 4
         self.kernel = (chunk_len, sample_size, sample_size)  # (24, 256, 256)
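Note: with `chunk_len = 24` and the `(24, 256, 256)` pixel-space kernel above, the derived latent tile is a factor 4 smaller in time and 8 in space, consistent with the `// 4` / `// 8` divisions used later in `tiled_encode`. A quick arithmetic check (sample_size assumed to be 256, per the comment):

    chunk_len, sample_size = 24, 256
    kernel = (chunk_len, sample_size, sample_size)
    latent_kernel = (kernel[0] // 4, kernel[1] // 8, kernel[2] // 8)
    print(latent_kernel)                         # (6, 32, 32)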
@@ -817,7 +817,7 @@ def __init__(
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)):
             module.gradient_checkpointing = value
-    
+
     def enable_tiling(
         self,
         # tile_sample_min_height: Optional[int] = None,
@@ -876,17 +876,19 @@ def disable_slicing(self) -> None:
         decoding in one step.
         """
         self.use_slicing = False
-    
+
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         # TODO(aryan)
         # if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
         if self.use_tiling:
             return self.tiled_encode(x)
-        
+
         raise NotImplementedError("Encoding without tiling has not been implemented yet.")
-    
+
     @apply_forward_hook
-    def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+    def encode(
+        self, x: torch.Tensor, return_dict: bool = True
+    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
         r"""
         Encode a batch of videos into latents.
 
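Note: the reflowed `encode` signature is purely cosmetic (line length), not behavioral. For reference, a hedged usage sketch of the diffusers-style return value (the `vae` instance and `video` tensor are assumed, not part of this diff):

    posterior = vae.encode(video).latent_dist    # AutoencoderKLOutput when return_dict=True
    latents = posterior.sample()                 # draw from the DiagonalGaussianDistribution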
@@ -919,7 +921,7 @@ def _decode(self, z: torch.Tensor) -> torch.Tensor:
             return self.tiled_decode(z)
 
         raise NotImplementedError("Decoding without tiling has not been implemented yet.")
-    
+
     @apply_forward_hook
     def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         """
@@ -946,12 +948,10 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
             return (decoded,)
         return DecoderOutput(sample=decoded)
 
-    def tiled_encode(
-        self, x: torch.Tensor
-    ) -> torch.Tensor:
+    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         # TODO(aryan): parameterize this in enable_tiling
         local_batch_size = 1
-        
+
         # TODO(aryan): rewrite to encode and tiled_encode
         KERNEL = self.kernel
         STRIDE = self.stride
@@ -972,9 +972,7 @@ def tiled_encode(
             device=x.device,
             dtype=x.dtype,
         )
-        vae_batch_input = torch.zeros(
-            (LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=x.device, dtype=x.dtype
-        )
+        vae_batch_input = torch.zeros((LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=x.device, dtype=x.dtype)
 
         for i in range(out_n):
             for j in range(out_h):
@@ -1002,9 +1000,7 @@ def tiled_encode(
         ## flatten the batched out latent to videos and supress the overlapped parts
         B, C, N, H, W = x.shape
 
-        out_video_cube = torch.zeros(
-            (B, OUT_C, N // 4, H // 8, W // 8), device=x.device, dtype=x.dtype
-        )
+        out_video_cube = torch.zeros((B, OUT_C, N // 4, H // 8, W // 8), device=x.device, dtype=x.dtype)
         OUT_KERNEL = KERNEL[0] // 4, KERNEL[1] // 8, KERNEL[2] // 8
         OUT_STRIDE = STRIDE[0] // 4, STRIDE[1] // 8, STRIDE[2] // 8
         OVERLAP = OUT_KERNEL[0] - OUT_STRIDE[0], OUT_KERNEL[1] - OUT_STRIDE[1], OUT_KERNEL[2] - OUT_STRIDE[2]
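Note: `tiled_encode` slides a 3D window of size KERNEL every STRIDE steps, so adjacent tiles share `KERNEL - STRIDE` positions, which the stitching pass has to suppress. A hypothetical one-axis illustration (the STRIDE value and the tile-count formula here are assumptions for the demo, not taken from this diff):

    KERNEL, STRIDE, N = 24, 16, 88               # temporal axis only, demo numbers
    out_n = (N - KERNEL) // STRIDE + 1           # windows that fit along the axis
    for i in range(out_n):
        n_start, n_end = i * STRIDE, i * STRIDE + KERNEL
        print(i, (n_start, n_end))               # neighbours overlap by KERNEL - STRIDE = 8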
@@ -1030,9 +1026,7 @@ def tiled_encode(
 
         return out_video_cube
 
-    def tiled_decode(
-        self, z: torch.Tensor
-    ) -> torch.Tensor:
+    def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
         # TODO(aryan): parameterize this in enable_tiling
         local_batch_size = 1
 
@@ -1092,9 +1086,7 @@ def tiled_decode(
                     num += 1
         B, C, N, H, W = z.shape
 
-        out_video = torch.zeros(
-            (B, OUT_C, N * 4, H * 8, W * 8), device=z.device, dtype=z.dtype
-        )
+        out_video = torch.zeros((B, OUT_C, N * 4, H * 8, W * 8), device=z.device, dtype=z.dtype)
         OVERLAP = KERNEL[0] - STRIDE[0], KERNEL[1] - STRIDE[1], KERNEL[2] - STRIDE[2]
         for i in range(out_n):
             n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0]
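Note: the decode path allocates its output at the inverse scale factors of the encoder (4x in time, 8x in space), matching the `(B, OUT_C, N * 4, H * 8, W * 8)` allocation above. A shape check (the latent shape and the 3-channel output are invented for the demo):

    import torch

    z = torch.randn(1, 4, 6, 32, 32)             # (B, latent_C, N, H, W)
    B, C, N, H, W = z.shape
    out_video = torch.zeros((B, 3, N * 4, H * 8, W * 8))
    print(out_video.shape)                       # torch.Size([1, 3, 24, 256, 256])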