@@ -74,51 +74,3 @@ class CausalEncoderOutput(NamedTuple):
7474
7575class CausalDecoderOutput (NamedTuple ):
7676 sample : torch .Tensor
77-
78-
79- class DecoderOutput :
80- """Output of decoding method - matches diffusers.models.autoencoders.vae.DecoderOutput"""
81- def __init__ (self , sample : torch .Tensor , commit_loss : Optional [torch .Tensor ] = None ):
82- self .sample = sample
83- self .commit_loss = commit_loss
84-
85-
86- class DiagonalGaussianDistribution :
87- """Matches diffusers.models.autoencoders.vae.DiagonalGaussianDistribution exactly."""
88- def __init__ (self , parameters : torch .Tensor , deterministic : bool = False ):
89- self .parameters = parameters
90- self .mean , self .logvar = torch .chunk (parameters , 2 , dim = 1 )
91- self .logvar = torch .clamp (self .logvar , - 30.0 , 20.0 )
92- self .deterministic = deterministic
93- self .std = torch .exp (0.5 * self .logvar )
94- self .var = torch .exp (self .logvar )
95- if self .deterministic :
96- self .var = self .std = torch .zeros_like (
97- self .mean , device = self .parameters .device , dtype = self .parameters .dtype
98- )
99-
100- def sample (self , generator : Optional [torch .Generator ] = None ) -> torch .Tensor :
101- if self .deterministic :
102- return self .mode ()
103- sample = torch .randn (
104- self .mean .shape ,
105- generator = generator ,
106- device = self .parameters .device ,
107- dtype = self .parameters .dtype ,
108- )
109- return self .mean + self .std * sample
110-
111- def mode (self ) -> torch .Tensor :
112- return self .mean
113-
114- def kl (self , other : Optional ["DiagonalGaussianDistribution" ] = None ) -> torch .Tensor :
115- if other is None :
116- return 0.5 * torch .sum (
117- self .mean .pow (2 ) + self .var - 1.0 - self .logvar ,
118- dim = [1 , 2 , 3 ],
119- )
120- return 0.5 * torch .sum (
121- (self .mean - other .mean ).pow (2 ) / other .var
122- + self .var / other .var - 1.0 - self .logvar + other .logvar ,
123- dim = [1 , 2 , 3 ],
124- )
0 commit comments