Skip to content

Commit d4a4d3a

Browse files
committed
Add pred_original_sample to if not return_dict path
1 parent 164ec9f commit d4a4d3a

10 files changed

+40
-10
lines changed

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,10 @@ def step(
463463
prev_sample = prev_sample + variance
464464

465465
if not return_dict:
466-
return (prev_sample,)
466+
return (
467+
prev_sample,
468+
pred_original_sample,
469+
)
467470

468471
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
469472

src/diffusers/schedulers/scheduling_ddim_cogvideox.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,10 @@ def step(
394394
prev_sample = a_t * sample + b_t * pred_original_sample
395395

396396
if not return_dict:
397-
return (prev_sample,)
397+
return (
398+
prev_sample,
399+
pred_original_sample,
400+
)
398401

399402
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
400403

src/diffusers/schedulers/scheduling_ddim_parallel.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,10 @@ def step(
480480
prev_sample = prev_sample + variance
481481

482482
if not return_dict:
483-
return (prev_sample,)
483+
return (
484+
prev_sample,
485+
pred_original_sample,
486+
)
484487

485488
return DDIMParallelSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
486489

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,10 @@ def step(
492492
pred_prev_sample = pred_prev_sample + variance
493493

494494
if not return_dict:
495-
return (pred_prev_sample,)
495+
return (
496+
pred_prev_sample,
497+
pred_original_sample,
498+
)
496499

497500
return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
498501

src/diffusers/schedulers/scheduling_ddpm_parallel.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,10 @@ def step(
500500
pred_prev_sample = pred_prev_sample + variance
501501

502502
if not return_dict:
503-
return (pred_prev_sample,)
503+
return (
504+
pred_prev_sample,
505+
pred_original_sample,
506+
)
504507

505508
return DDPMParallelSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
506509

src/diffusers/schedulers/scheduling_edm_euler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,10 @@ def step(
360360
self._step_index += 1
361361

362362
if not return_dict:
363-
return (prev_sample,)
363+
return (
364+
prev_sample,
365+
pred_original_sample,
366+
)
364367

365368
return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
366369

src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,10 @@ def step(
435435
self._step_index += 1
436436

437437
if not return_dict:
438-
return (prev_sample,)
438+
return (
439+
prev_sample,
440+
pred_original_sample,
441+
)
439442

440443
return EulerAncestralDiscreteSchedulerOutput(
441444
prev_sample=prev_sample, pred_original_sample=pred_original_sample

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,10 @@ def step(
677677
self._step_index += 1
678678

679679
if not return_dict:
680-
return (prev_sample,)
680+
return (
681+
prev_sample,
682+
pred_original_sample,
683+
)
681684

682685
return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
683686

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,10 @@ def step(
507507
self._step_index += 1
508508

509509
if not return_dict:
510-
return (prev_sample,)
510+
return (
511+
prev_sample,
512+
pred_original_sample,
513+
)
511514

512515
return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
513516

src/diffusers/schedulers/scheduling_unclip.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,10 @@ def step(
320320
pred_prev_sample = pred_prev_sample + variance
321321

322322
if not return_dict:
323-
return (pred_prev_sample,)
323+
return (
324+
pred_prev_sample,
325+
pred_original_sample,
326+
)
324327

325328
return UnCLIPSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
326329

0 commit comments

Comments
 (0)