From d4a4d3a27c50363bb19bf0ff60ee09fda309e7af Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 11 Oct 2024 12:19:14 +0100 Subject: [PATCH] Add pred_original_sample to `if not return_dict` path --- src/diffusers/schedulers/scheduling_ddim.py | 5 ++++- src/diffusers/schedulers/scheduling_ddim_cogvideox.py | 5 ++++- src/diffusers/schedulers/scheduling_ddim_parallel.py | 5 ++++- src/diffusers/schedulers/scheduling_ddpm.py | 5 ++++- src/diffusers/schedulers/scheduling_ddpm_parallel.py | 5 ++++- src/diffusers/schedulers/scheduling_edm_euler.py | 5 ++++- .../schedulers/scheduling_euler_ancestral_discrete.py | 5 ++++- src/diffusers/schedulers/scheduling_euler_discrete.py | 5 ++++- src/diffusers/schedulers/scheduling_lms_discrete.py | 5 ++++- src/diffusers/schedulers/scheduling_unclip.py | 5 ++++- 10 files changed, 40 insertions(+), 10 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 14356eafdaea..13c9b3b4a5e9 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -463,7 +463,10 @@ def step( prev_sample = prev_sample + variance if not return_dict: - return (prev_sample,) + return ( + prev_sample, + pred_original_sample, + ) return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py index ec5c5f3e1c5d..5c131752933c 100644 --- a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py @@ -394,7 +394,10 @@ def step( prev_sample = a_t * sample + b_t * pred_original_sample if not return_dict: - return (prev_sample,) + return ( + prev_sample, + pred_original_sample, + ) return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) diff --git a/src/diffusers/schedulers/scheduling_ddim_parallel.py b/src/diffusers/schedulers/scheduling_ddim_parallel.py index 0cf84b694db5..64412709ae90 100644 --- a/src/diffusers/schedulers/scheduling_ddim_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddim_parallel.py @@ -480,7 +480,10 @@ def step( prev_sample = prev_sample + variance if not return_dict: - return (prev_sample,) + return ( + prev_sample, + pred_original_sample, + ) return DDIMParallelSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 81a770edf635..468fdf61a9ef 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -492,7 +492,10 @@ def step( pred_prev_sample = pred_prev_sample + variance if not return_dict: - return (pred_prev_sample,) + return ( + pred_prev_sample, + pred_original_sample, + ) return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py index 5dfcf3c17a2f..f377ee6e8c93 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -500,7 +500,10 @@ def step( pred_prev_sample = pred_prev_sample + variance if not return_dict: - return (pred_prev_sample,) + return ( + pred_prev_sample, + pred_original_sample, + ) return DDPMParallelSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) diff --git a/src/diffusers/schedulers/scheduling_edm_euler.py b/src/diffusers/schedulers/scheduling_edm_euler.py index 4b823c0d281b..04a09b114d5b 100644 --- a/src/diffusers/schedulers/scheduling_edm_euler.py +++ b/src/diffusers/schedulers/scheduling_edm_euler.py @@ -360,7 +360,10 @@ def step( self._step_index += 1 if not return_dict: - return (prev_sample,) + return ( + prev_sample, + pred_original_sample, + ) return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 485e919e9cc5..4df43a160ce1 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -435,7 +435,10 @@ def step( self._step_index += 1 if not return_dict: - return (prev_sample,) + return ( + prev_sample, + pred_original_sample, + ) return EulerAncestralDiscreteSchedulerOutput( prev_sample=prev_sample, pred_original_sample=pred_original_sample diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 5c39583356ad..7083571a60c6 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -677,7 +677,10 @@ def step( self._step_index += 1 if not return_dict: - return (prev_sample,) + return ( + prev_sample, + pred_original_sample, + ) return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index aed8c5828c75..0a0900455488 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -507,7 +507,10 @@ def step( self._step_index += 1 if not return_dict: - return (prev_sample,) + return ( + prev_sample, + pred_original_sample, + ) return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) diff --git a/src/diffusers/schedulers/scheduling_unclip.py b/src/diffusers/schedulers/scheduling_unclip.py index 6e1580290f22..22a53b0e73b6 100644 --- a/src/diffusers/schedulers/scheduling_unclip.py +++ b/src/diffusers/schedulers/scheduling_unclip.py @@ -320,7 +320,10 @@ def step( pred_prev_sample = pred_prev_sample + variance if not return_dict: - return (pred_prev_sample,) + return ( + pred_prev_sample, + pred_original_sample, + ) return UnCLIPSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)