@@ -715,11 +715,6 @@ def __init__(
715715 ) -> None :
716716 super ().__init__ ()
717717
718- # Store normalization parameters as tensors
719- self .mean = torch .tensor (latents_mean )
720- self .std = torch .tensor (latents_std )
721- self .scale = torch .stack ([self .mean , 1.0 / self .std ]) # Shape: [2, C]
722-
723718 self .z_dim = z_dim
724719 self .temperal_downsample = temperal_downsample
725720 self .temperal_upsample = temperal_downsample [::- 1 ]
@@ -751,7 +746,6 @@ def _count_conv3d(model):
751746 self ._enc_feat_map = [None ] * self ._enc_conv_num
752747
753748 def _encode (self , x : torch .Tensor ) -> torch .Tensor :
754- scale = self .scale .type_as (x )
755749 self .clear_cache ()
756750 ## cache
757751 t = x .shape [2 ]
@@ -770,8 +764,6 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
770764
771765 enc = self .quant_conv (out )
772766 mu , logvar = enc [:, : self .z_dim , :, :, :], enc [:, self .z_dim :, :, :, :]
773- mu = (mu - scale [0 ].view (1 , self .z_dim , 1 , 1 , 1 )) * scale [1 ].view (1 , self .z_dim , 1 , 1 , 1 )
774- logvar = (logvar - scale [0 ].view (1 , self .z_dim , 1 , 1 , 1 )) * scale [1 ].view (1 , self .z_dim , 1 , 1 , 1 )
775767 enc = torch .cat ([mu , logvar ], dim = 1 )
776768 self .clear_cache ()
777769 return enc
@@ -798,10 +790,8 @@ def encode(
798790 return (posterior ,)
799791 return AutoencoderKLOutput (latent_dist = posterior )
800792
801- def _decode (self , z : torch .Tensor , scale , return_dict : bool = True ) -> Union [DecoderOutput , torch .Tensor ]:
793+ def _decode (self , z : torch .Tensor , return_dict : bool = True ) -> Union [DecoderOutput , torch .Tensor ]:
802794 self .clear_cache ()
803- # z: [b,c,t,h,w]
804- z = z / scale [1 ].view (1 , self .z_dim , 1 , 1 , 1 ) + scale [0 ].view (1 , self .z_dim , 1 , 1 , 1 )
805795
806796 iter_ = z .shape [2 ]
807797 x = self .post_quant_conv (z )
@@ -835,8 +825,7 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
835825 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
836826 returned.
837827 """
838- scale = self .scale .type_as (z )
839- decoded = self ._decode (z , scale ).sample
828+ decoded = self ._decode (z ).sample
840829 if not return_dict :
841830 return (decoded ,)
842831
0 commit comments