@@ -776,17 +776,16 @@ def __init__(
776776        attn_scales : List [float ] = [],
777777        temperal_downsample : List [bool ] = [False , True , True ],
778778        dropout : float  = 0.0 ,
779+         latents_mean : List [float ] =  [- 0.7571 , - 0.7089 , - 0.9113 , 0.1075 , - 0.1745 , 0.9653 , - 0.1517 , 1.5508 , 
780+         0.4134 , - 0.0715 , 0.5517 , - 0.3632 , - 0.1922 , - 0.9497 , 0.2503 , - 0.2921 ,] ,
781+         latents_std : List [float ] =  [2.8184 , 1.4541 , 2.3275 , 2.6558 , 1.2196 , 1.7708 , 2.6052 , 2.0743 , 
782+         3.2687 , 2.1526 , 2.8652 , 1.5579 , 1.6382 , 1.1253 , 2.8251 , 1.9160 ,] ,
779783    ) ->  None :
780784        super ().__init__ ()
781785
782-         # channel-wise mean and std 
783-         mean  =  [- 0.7571 , - 0.7089 , - 0.9113 , 0.1075 , - 0.1745 , 0.9653 , - 0.1517 , 1.5508 , 
784-         0.4134 , - 0.0715 , 0.5517 , - 0.3632 , - 0.1922 , - 0.9497 , 0.2503 , - 0.2921 ,] 
785-         std  =  [2.8184 , 1.4541 , 2.3275 , 2.6558 , 1.2196 , 1.7708 , 2.6052 , 2.0743 , 
786-         3.2687 , 2.1526 , 2.8652 , 1.5579 , 1.6382 , 1.1253 , 2.8251 , 1.9160 ,] 
787786        # Store normalization parameters as tensors 
788-         self .mean  =  torch .tensor (mean )
789-         self .std  =  torch .tensor (std )
787+         self .mean  =  torch .tensor (latents_mean )
788+         self .std  =  torch .tensor (latents_std )
790789        self .scale  =  torch .stack ([self .mean , 1.0  /  self .std ])  # Shape: [2, C] 
791790
792791        self .z_dim  =  z_dim 
0 commit comments