Skip to content

Commit 0fb5fb8

Browse files
committed
fix copies
1 parent 48a4d60 commit 0fb5fb8

File tree

3 files changed

+6
-18
lines changed

3 files changed

+6
-18
lines changed

src/diffusers/schedulers/scheduling_ddpm_parallel.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -639,16 +639,12 @@ def __len__(self):
639639

640640
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
641641
def previous_timestep(self, timestep):
642-
if self.custom_timesteps:
642+
if self.custom_timesteps or self.num_inference_steps:
643643
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
644644
if index == self.timesteps.shape[0] - 1:
645645
prev_t = torch.tensor(-1)
646646
else:
647647
prev_t = self.timesteps[index + 1]
648648
else:
649-
num_inference_steps = (
650-
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
651-
)
652-
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
653-
649+
prev_t = timestep - 1
654650
return prev_t

src/diffusers/schedulers/scheduling_lcm.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -643,16 +643,12 @@ def __len__(self):
643643

644644
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
645645
def previous_timestep(self, timestep):
646-
if self.custom_timesteps:
646+
if self.custom_timesteps or self.num_inference_steps:
647647
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
648648
if index == self.timesteps.shape[0] - 1:
649649
prev_t = torch.tensor(-1)
650650
else:
651651
prev_t = self.timesteps[index + 1]
652652
else:
653-
num_inference_steps = (
654-
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
655-
)
656-
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
657-
653+
prev_t = timestep - 1
658654
return prev_t

src/diffusers/schedulers/scheduling_tcd.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -680,16 +680,12 @@ def __len__(self):
680680

681681
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
682682
def previous_timestep(self, timestep):
683-
if self.custom_timesteps:
683+
if self.custom_timesteps or self.num_inference_steps:
684684
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
685685
if index == self.timesteps.shape[0] - 1:
686686
prev_t = torch.tensor(-1)
687687
else:
688688
prev_t = self.timesteps[index + 1]
689689
else:
690-
num_inference_steps = (
691-
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
692-
)
693-
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
694-
690+
prev_t = timestep - 1
695691
return prev_t

0 commit comments

Comments
 (0)