@@ -507,10 +507,12 @@ def forward(
                 hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
             else:
                 hidden_states = resnet(hidden_states, temb, generator)
+        print(f" after resnets: {hidden_states.shape}, {hidden_states[0, 0, :3, :3, :3]}")

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)
+                print(f" after downsampler: {hidden_states.shape}, {hidden_states[0, 0, :3, :3, :3]}")

         return hidden_states

@@ -841,6 +843,8 @@ def __init__(
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         r"""The forward method of the `LTXVideoEncoder3d` class."""

+        print(f" inside LTXVideoEncoder3d")
+        print(f" hidden_states: {hidden_states.shape}, {hidden_states[0, 0, :3, :3, :3]}")
         p = self.patch_size
         p_t = self.patch_size_t

@@ -854,7 +858,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         )
         # Thanks for driving me insane with the weird patching order :(
         hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4)
+        print(f" before conv_in: {hidden_states.shape}, {hidden_states[0, 0, :3, :3, :3]}")
         hidden_states = self.conv_in(hidden_states)
+        print(f" after conv_in: {hidden_states.shape}, {hidden_states[0, 0, :3, :3, :3]}")

         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for down_block in self.down_blocks:
@@ -864,17 +870,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         else:
             for down_block in self.down_blocks:
                 hidden_states = down_block(hidden_states)
+                print(f" after down_block: {hidden_states.shape}, {hidden_states[0, 0, :3, :3, :3]}")

         hidden_states = self.mid_block(hidden_states)
+        print(f" after mid_block: {hidden_states.shape}, {hidden_states[0, 0, :3, :3, :3]}")

         hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+        print(f" before conv_act: {hidden_states.shape}, {hidden_states[0, 0, :3, :3, :3]}")
         hidden_states = self.conv_act(hidden_states)
+        print(f" after conv_act: {hidden_states.shape}, {hidden_states[0, 0, :3, :3, :3]}")
         hidden_states = self.conv_out(hidden_states)
+        print(f" after conv_out: {hidden_states.shape}, {hidden_states[0, 0, :3, :3, :3]}")

         last_channel = hidden_states[:, -1:]
         last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1)
         hidden_states = torch.cat([hidden_states, last_channel], dim=1)
-
+        print(f" output: {hidden_states.shape}, {hidden_states[0, 0, :3, :3, :3]}")
         return hidden_states

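For anyone reproducing this trace: the same shapes and slices can usually be captured without patching the library source, by registering forward hooks on the encoder's submodules. Below is a minimal sketch, assuming `vae` is an already-loaded LTX Video VAE whose `vae.encoder` attribute is the `LTXVideoEncoder3d` patched above; the `vae.encoder` path and the `video_tensor` input are assumptions for illustration, not part of this commit.

import torch

def make_debug_hook(name):
    # Mirror the inline prints above: output shape plus a small slice of the tensor.
    def hook(module, inputs, output):
        out = output[0] if isinstance(output, tuple) else output
        print(f" {name}: {out.shape}, {out[0, 0, :3, :3, :3]}")
    return hook

# Assumption: `vae.encoder` is the LTXVideoEncoder3d instance shown in the diff.
encoder = vae.encoder
watched = {
    "conv_in": encoder.conv_in,
    "mid_block": encoder.mid_block,
    "conv_act": encoder.conv_act,
    "conv_out": encoder.conv_out,
}
for i, block in enumerate(encoder.down_blocks):
    watched[f"down_blocks.{i}"] = block

handles = [m.register_forward_hook(make_debug_hook(n)) for n, m in watched.items()]

with torch.no_grad():
    _ = vae.encode(video_tensor)  # `video_tensor` is an assumed (B, C, F, H, W) float tensor

for handle in handles:
    handle.remove()

The hooks keep the debug output out of the library file and are easy to detach once the intermediate activations have been compared.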