From 6aaa26856b981a79f03e67e900ade33039adef0d Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 11 Dec 2024 23:23:28 +0200 Subject: [PATCH 1/3] add decay --- examples/community/pipeline_flux_rf_inversion.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index f09160c4571d..00ee637dd2cc 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -880,12 +880,10 @@ def __call__( v_t = -noise_pred v_t_cond = (y_0 - latents) / (1 - t_i) eta_t = eta if start_timestep <= i < stop_timestep else 0.0 - if start_timestep <= i < stop_timestep: - # controlled vector field - v_hat_t = v_t + eta * (v_t_cond - v_t) - - else: - v_hat_t = v_t + if decay_eta: + eta_t = eta_t * (1 - i / num_inference_steps) # Decay eta over the loop + # eta_t = eta * (1 - i / num_inference_steps) ** 2 + v_hat_t = v_t + eta_t * (v_t_cond - v_t) # SDE Eq: 17 from https://arxiv.org/pdf/2410.10792 latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1]) From baae6e759d8e23e7f5d4038c6503843737c3d56b Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 12 Dec 2024 14:32:06 +0200 Subject: [PATCH 2/3] add decay --- examples/community/pipeline_flux_rf_inversion.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 00ee637dd2cc..d5ca9a538af9 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -648,6 +648,8 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, eta: float = 1.0, + decay_eta: Optional[bool] = False, + eta_decay_power: Optional[float] = 1., strength: float = 1.0, start_timestep: float = 0, stop_timestep: float = 0.25, @@ -881,8 +883,7 @@ def __call__( v_t_cond = (y_0 - latents) / (1 - t_i) eta_t = eta if start_timestep <= i < stop_timestep else 0.0 if decay_eta: - eta_t = eta_t * (1 - i / num_inference_steps) # Decay eta over the loop - # eta_t = eta * (1 - i / num_inference_steps) ** 2 + eta_t = eta_t * (1 - i / num_inference_steps) ** eta_decay_power # Decay eta over the loop v_hat_t = v_t + eta_t * (v_t_cond - v_t) # SDE Eq: 17 from https://arxiv.org/pdf/2410.10792 From 3e352fe3ff452e2fbcc9b21a7e727110f0886a37 Mon Sep 17 00:00:00 2001 From: Linoy Date: Fri, 13 Dec 2024 08:51:14 +0000 Subject: [PATCH 3/3] style --- examples/community/pipeline_flux_rf_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index d5ca9a538af9..c8a87a426dc0 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -649,7 +649,7 @@ def __call__( width: Optional[int] = None, eta: float = 1.0, decay_eta: Optional[bool] = False, - eta_decay_power: Optional[float] = 1., + eta_decay_power: Optional[float] = 1.0, strength: float = 1.0, start_timestep: float = 0, stop_timestep: float = 0.25,