Skip to content

Commit 83ae24c

Browse files
RuiningLiDragAPartsayakpaul
authored
Added get_velocity function to EulerDiscreteScheduler. (#7733)
* Added get_velocity function to EulerDiscreteScheduler. * Fix white space on blank lines * Added copied from statement * back to the original. --------- Co-authored-by: Ruining Li <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
1 parent 8af793b commit 83ae24c

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,5 +576,44 @@ def add_noise(
576576
noisy_samples = original_samples + noise * sigma
577577
return noisy_samples
578578

579+
def get_velocity(
580+
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.FloatTensor
581+
) -> torch.FloatTensor:
582+
if (
583+
isinstance(timesteps, int)
584+
or isinstance(timesteps, torch.IntTensor)
585+
or isinstance(timesteps, torch.LongTensor)
586+
):
587+
raise ValueError(
588+
(
589+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
590+
" `EulerDiscreteScheduler.get_velocity()` is not supported. Make sure to pass"
591+
" one of the `scheduler.timesteps` as a timestep."
592+
),
593+
)
594+
595+
if sample.device.type == "mps" and torch.is_floating_point(timesteps):
596+
# mps does not support float64
597+
schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
598+
timesteps = timesteps.to(sample.device, dtype=torch.float32)
599+
else:
600+
schedule_timesteps = self.timesteps.to(sample.device)
601+
timesteps = timesteps.to(sample.device)
602+
603+
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
604+
alphas_cumprod = self.alphas_cumprod.to(sample)
605+
sqrt_alpha_prod = alphas_cumprod[step_indices] ** 0.5
606+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
607+
while len(sqrt_alpha_prod.shape) < len(sample.shape):
608+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
609+
610+
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[step_indices]) ** 0.5
611+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
612+
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
613+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
614+
615+
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
616+
return velocity
617+
579618
def __len__(self):
580619
return self.config.num_train_timesteps

0 commit comments

Comments
 (0)