diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index 14356eafdaea..bd567babd7c7 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -231,7 +231,9 @@ def __init__(
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+        # TODO: discuss with YiYi why we have a .copy() here and if it's really needed. I've removed it for now
+        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64))
 
     def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
@@ -251,8 +253,12 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
         return sample
 
     def _get_variance(self, timestep, prev_timestep):
-        alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep)
+
+        safe_prev_timestep = torch.clamp(prev_timestep, min=0)
+        safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep)
+        alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod)
+
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
 
@@ -338,6 +344,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         )
 
         self.timesteps = torch.from_numpy(timesteps).to(device)
+        self.alphas_cumprod = self.alphas_cumprod.to(device)
+        self.final_alpha_cumprod = self.final_alpha_cumprod.to(device)
 
     def step(
         self,
@@ -402,8 +410,11 @@ def step(
         prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
 
         # 2. compute alphas, betas
-        alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep)
+
+        safe_prev_timestep = torch.clamp(prev_timestep, min=0)
+        safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep)
+        alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod)
 
         beta_prod_t = 1 - alpha_prod_t
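The scheduling_ddim_cogvideox.py diff below applies the same rewrite. As a side note, here is a minimal standalone sketch (illustrative schedule values, not the real betas; presumably the point is to keep the step traceable once `timestep` is a tensor rather than a Python int) of why the clamp + gather + where combination reproduces the old `if prev_timestep >= 0` branch:

import torch

# Toy stand-ins for the scheduler state (illustrative values only).
alphas_cumprod = torch.linspace(0.9999, 0.0001, 1000)
final_alpha_cumprod = torch.tensor(1.0)

# Last denoising step of a 50-step schedule: prev_timestep goes negative.
timestep = torch.tensor([19], dtype=torch.int64)
prev_timestep = timestep - 1000 // 50  # -> tensor([-1])

alpha_prod_t = torch.gather(alphas_cumprod, 0, timestep)

# A Python `if prev_timestep >= 0` would need a data-dependent host-side check,
# so instead: clamp the index into range, gather, then substitute the final
# alpha wherever the previous timestep is actually negative.
safe_prev_timestep = torch.clamp(prev_timestep, min=0)
safe_alpha_prod_t_prev = torch.gather(alphas_cumprod, 0, safe_prev_timestep)
alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, final_alpha_cumprod)

assert alpha_prod_t_prev.item() == final_alpha_cumprod.item()  # falls back to 1.0

The extra `.to(device)` calls added in `set_timesteps` fit the same picture: `torch.gather` and `torch.where` need `alphas_cumprod`, `final_alpha_cumprod`, and the timestep tensors to live on the same device.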
diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py
index ec5c5f3e1c5d..20a5f1a75f33 100644
--- a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py
+++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py
@@ -228,11 +228,17 @@ def __init__(
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+        # TODO: discuss with YiYi why we have a .copy() here and if it's really needed. I've removed it for now
+        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64))
 
     def _get_variance(self, timestep, prev_timestep):
-        alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep)
+
+        safe_prev_timestep = torch.clamp(prev_timestep, min=0)
+        safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep)
+        alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod)
+
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
 
@@ -301,6 +307,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         )
 
         self.timesteps = torch.from_numpy(timesteps).to(device)
+        self.alphas_cumprod = self.alphas_cumprod.to(device)
+        self.final_alpha_cumprod = self.final_alpha_cumprod.to(device)
 
     def step(
         self,
@@ -365,8 +373,11 @@ def step(
         prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
 
         # 2. compute alphas, betas
-        alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep)
+
+        safe_prev_timestep = torch.clamp(prev_timestep, min=0)
+        safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep)
+        alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod)
 
         beta_prod_t = 1 - alpha_prod_t
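On the `.copy()` question raised in the TODO: a small check (my reading, not something settled in this diff) suggests the copy is redundant in this particular chain, because `.astype()` already returns a fresh contiguous array, which is what `torch.from_numpy` needs once the `[::-1]` view has a negative stride:

import numpy as np
import torch

num_train_timesteps = 1000

# [::-1] is a view with a negative stride, which torch.from_numpy rejects.
reversed_view = np.arange(0, num_train_timesteps)[::-1]
assert reversed_view.strides[0] < 0

# .astype() (copy=True by default) allocates a new contiguous array,
# so from_numpy works even without the explicit .copy().
as_int64 = reversed_view.astype(np.int64)
assert as_int64.flags["C_CONTIGUOUS"]

timesteps = torch.from_numpy(as_int64)
assert timesteps[0].item() == num_train_timesteps - 1

If the `.astype()` call were ever removed from that chain, the `.copy()` would become necessary again.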