@@ -19,8 +19,13 @@ def __call__(self, num_samples: int, device: torch.device):
1919
2020
2121class UniformDistribution (Distribution ):
22+ def __init__ (self , vmin : float = 0.0 , vmax : float = 1.0 ):
23+ super ().__init__ ()
24+ self .vmin , self .vmax = vmin , vmax
25+
2226 def __call__ (self , num_samples : int , device : torch .device = torch .device ("cpu" )):
23- return torch .rand (num_samples , device = device )
27+ vmax , vmin = self .vmax , self .vmin
28+ return (vmax - vmin ) * torch .rand (num_samples , device = device ) + vmin
2429
2530
2631""" Diffusion Methods """
@@ -132,8 +137,12 @@ def forward(self, num_steps: int, device: torch.device) -> Tensor:
132137
133138
134139class LinearSchedule (Schedule ):
140+ def __init__ (self , start : float = 1.0 , end : float = 0.0 ):
141+ super ().__init__ ()
142+ self .start , self .end = start , end
143+
135144 def forward (self , num_steps : int , device : Any ) -> Tensor :
136- return torch .linspace (1.0 , 0.0 , num_steps , device = device )
145+ return torch .linspace (self . start , self . end , num_steps , device = device )
137146
138147
139148""" Samplers """
@@ -158,14 +167,13 @@ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
158167 return alpha , beta
159168
160169 def forward ( # type: ignore
161- self , noise : Tensor , num_steps : int , show_progress : bool = False , ** kwargs
170+ self , x_noisy : Tensor , num_steps : int , show_progress : bool = False , ** kwargs
162171 ) -> Tensor :
163- b = noise .shape [0 ]
164- sigmas = self .schedule (num_steps + 1 , device = noise .device )
172+ b = x_noisy .shape [0 ]
173+ sigmas = self .schedule (num_steps + 1 , device = x_noisy .device )
165174 sigmas = repeat (sigmas , "i -> i b" , b = b )
166- sigmas_batch = extend_dim (sigmas , dim = noise .ndim + 1 )
175+ sigmas_batch = extend_dim (sigmas , dim = x_noisy .ndim + 1 )
167176 alphas , betas = self .get_alpha_beta (sigmas_batch )
168- x_noisy = noise * sigmas_batch [0 ]
169177 progress_bar = tqdm (range (num_steps ), disable = not show_progress )
170178
171179 for i in progress_bar :
0 commit comments