Skip to content

Commit 6e12368

Browse files
authored
Remove unused parameters and fixed FutureWarning (#6317)
* Remove unused parameters and fixed `FutureWarning` * Fixed wrong config instance * update unittest for `DDIMInverseScheduler`
1 parent f0a588b commit 6e12368

File tree

2 files changed

+4
-7
lines changed

2 files changed

+4
-7
lines changed

src/diffusers/schedulers/scheduling_ddim_inverse.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,6 @@ def step(
293293
model_output: torch.FloatTensor,
294294
timestep: int,
295295
sample: torch.FloatTensor,
296-
eta: float = 0.0,
297-
use_clipped_model_output: bool = False,
298-
variance_noise: Optional[torch.FloatTensor] = None,
299296
return_dict: bool = True,
300297
) -> Union[DDIMSchedulerOutput, Tuple]:
301298
"""
@@ -332,7 +329,7 @@ def step(
332329
# 1. get previous step value (=t+1)
333330
prev_timestep = timestep
334331
timestep = min(
335-
timestep - self.config.num_train_timesteps // self.num_inference_steps, self.num_train_timesteps - 1
332+
timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1
336333
)
337334

338335
# 2. compute alphas, betas

tests/schedulers/test_scheduler_ddim_inverse.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
class DDIMInverseSchedulerTest(SchedulerCommonTest):
99
scheduler_classes = (DDIMInverseScheduler,)
10-
forward_default_kwargs = (("eta", 0.0), ("num_inference_steps", 50))
10+
forward_default_kwargs = (("num_inference_steps", 50),)
1111

1212
def get_scheduler_config(self, **kwargs):
1313
config = {
@@ -26,7 +26,7 @@ def full_loop(self, **config):
2626
scheduler_config = self.get_scheduler_config(**config)
2727
scheduler = scheduler_class(**scheduler_config)
2828

29-
num_inference_steps, eta = 10, 0.0
29+
num_inference_steps = 10
3030

3131
model = self.dummy_model()
3232
sample = self.dummy_sample_deter
@@ -35,7 +35,7 @@ def full_loop(self, **config):
3535

3636
for t in scheduler.timesteps:
3737
residual = model(sample, t)
38-
sample = scheduler.step(residual, t, sample, eta).prev_sample
38+
sample = scheduler.step(residual, t, sample).prev_sample
3939

4040
return sample
4141

0 commit comments

Comments
 (0)