Skip to content

Commit a3ed155

Browse files
committed
maybe
1 parent d78598e commit a3ed155

File tree

2 files changed

+24
-57
lines changed

2 files changed

+24
-57
lines changed

src/diffusers/pipelines/flux/pipeline_flux_rfinversion_edit.py

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -627,10 +627,6 @@ def __call__(
627627
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
628628
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
629629
max_sequence_length: int = 512,
630-
sigmas = None,
631-
flip_schedule = False,
632-
even_timesteps = None,
633-
divide_timestep = True,
634630
):
635631
r"""
636632
Function invoked when calling the pipeline for generation.
@@ -767,6 +763,7 @@ def __call__(
767763
)
768764

769765
# 4. Prepare timesteps
766+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
770767
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
771768
mu = calculate_shift(
772769
image_seq_len,
@@ -783,9 +780,9 @@ def __call__(
783780
sigmas,
784781
mu=mu,
785782
)
786-
if flip_schedule:
787-
self.scheduler.sigmas = self.scheduler.sigmas.flip(0)
788-
self.scheduler.timesteps = self.scheduler.timesteps.flip(0)
783+
self.scheduler.sigmas = self.scheduler.sigmas.flip(0)
784+
self.scheduler.timesteps = self.scheduler.timesteps.flip(0)
785+
self.scheduler.sigmas[0] += 1e-6
789786
print(f"self.scheduler.sigmas {self.scheduler.sigmas}")
790787
print(f"self.scheduler.timesteps {self.scheduler.timesteps}")
791788
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
@@ -840,21 +837,14 @@ def __call__(
840837

841838
# 6. Denoising loop
842839
with self.progress_bar(total=num_inference_steps) as progress_bar:
843-
for i, t in enumerate(timesteps):
840+
for i in range(self._num_timesteps - 1):
844841
if self.interrupt:
845842
continue
846-
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
847-
if even_timesteps is None:
848-
timestep = t.expand(latents.shape[0]).to(latents.dtype)
849-
if divide_timestep:
850-
timestep = timestep / 1000
851-
else:
852-
timestep = torch.tensor([even_timesteps[i]], device=latents.device, dtype=latents.dtype)
843+
timestep = torch.tensor([self.scheduler.sigmas[i]], device=latents.device, dtype=latents.dtype)
853844
# Unconditional vector field: $v_{t_i}(X_{t_i}) = -u(X_{t_i}, 1 - t_i, \Phi(\text{prompt}); \phi)$
854-
timestep = 1-timestep
855845
unconditional_vector_field = -self.transformer(
856846
hidden_states=latents,
857-
timestep=timestep if divide_timestep else timestep / 1000,
847+
timestep=timestep,
858848
guidance=guidance,
859849
pooled_projections=pooled_prompt_embeds,
860850
encoder_hidden_states=prompt_embeds,
@@ -865,31 +855,24 @@ def __call__(
865855
)[0]
866856

867857
# consider a time-varying controller guidance schedule: $\eta_t = \eta$ for all $t \le \tau$, and $\eta_t = 0$ otherwise
868-
control_guidance = controller_guidance if i < stopping_time else 0.0
869858
# Conditional vector field: $v_{t_i}(X_{t_i} | y_0) = \frac{y_0 - X_{t_i}}{1 - t_i}$
870-
conditional_vector_field = (reference_image - latents) / timestep
859+
t_i = i / self._num_timesteps
860+
conditional_vector_field = (reference_image - latents) / (1-t_i)
871861
# Controlled vector field: $\hat{v}_{t_i}(X_{t_i}) = v_{t_i}(X_{t_i}) + \eta \left( v_{t_i}(X_{t_i} | y_0) - v_{t_i}(X_{t_i}) \right)$
872-
controlled_vector_field = unconditional_vector_field + control_guidance * (conditional_vector_field - unconditional_vector_field)
862+
controlled_vector_field = unconditional_vector_field
863+
if i < stopping_time:
864+
controlled_vector_field = unconditional_vector_field + controller_guidance * (conditional_vector_field - unconditional_vector_field)
873865

874866
# compute the previous noisy sample x_t -> x_t-1
875867
latents_dtype = latents.dtype
876868
# Next state: $X_{t_{i+1}} = X_{t_i} + \hat{v}_{t_i}(X_{t_i}) \cdot (\sigma(t_{i+1}) - \sigma(t_i))$
877-
latents = self.scheduler.step(controlled_vector_field, t, latents, return_dict=False)[0]
869+
latents = latents + controlled_vector_field * (self.scheduler.sigmas[i+1] - self.scheduler.sigmas[i])
878870

879871
if latents.dtype != latents_dtype:
880872
if torch.backends.mps.is_available():
881873
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
882874
latents = latents.to(latents_dtype)
883875

884-
if callback_on_step_end is not None:
885-
callback_kwargs = {}
886-
for k in callback_on_step_end_tensor_inputs:
887-
callback_kwargs[k] = locals()[k]
888-
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
889-
890-
latents = callback_outputs.pop("latents", latents)
891-
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
892-
893876
# advance the progress bar on the last step, or (once warmup is over) on every `scheduler.order`-th step
894877
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
895878
progress_bar.update()

src/diffusers/pipelines/flux/pipeline_flux_rfinversion_noise.py

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -626,10 +626,6 @@ def __call__(
626626
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
627627
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
628628
max_sequence_length: int = 512,
629-
sigmas = None,
630-
flip_schedule = False,
631-
even_timesteps = None,
632-
divide_timestep = True
633629
):
634630
r"""
635631
Function invoked when calling the pipeline for generation.
@@ -767,7 +763,7 @@ def __call__(
767763
)
768764

769765
# 4. Prepare timesteps
770-
# Flux noise scheduler $\sigma : [0, 1] \to \mathbb{R}$
766+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
771767
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
772768
mu = calculate_shift(
773769
image_seq_len,
@@ -784,9 +780,9 @@ def __call__(
784780
sigmas,
785781
mu=mu,
786782
)
787-
if flip_schedule:
788-
self.scheduler.sigmas = self.scheduler.sigmas.flip(0)
789-
self.scheduler.timesteps = self.scheduler.timesteps.flip(0)
783+
self.scheduler.sigmas = self.scheduler.sigmas.flip(0)
784+
self.scheduler.timesteps = self.scheduler.timesteps.flip(0)
785+
self.scheduler.sigmas[0] += 1e-6
790786
print(f"self.scheduler.sigmas {self.scheduler.sigmas}")
791787
print(f"self.scheduler.timesteps {self.scheduler.timesteps}")
792788
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
@@ -828,22 +824,18 @@ def __call__(
828824
y1 = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
829825
# 6. Denoising loop
830826
with self.progress_bar(total=num_inference_steps) as progress_bar:
831-
for i, t in enumerate(timesteps):
827+
for i in range(self._num_timesteps - 1):
832828
# starting time s ∈ [0, 1] is defined as the time at which our controlled reverse ODE (15) is initialized.
833829
# The initial state $X_s = y_{1-s}$ is obtained by integrating the controlled forward ODE (8) from $0 \to 1 - s$.
834830
if i > self._num_timesteps - starting_time:
835831
continue
836832
if self.interrupt:
837833
continue
838-
if even_timesteps is None:
839-
timestep = t.expand(latents.shape[0]).to(latents.dtype)
840-
if divide_timestep:
841-
timestep = timestep / 1000
842-
else:
843-
timestep = torch.tensor([even_timesteps[i]], device=latents.device, dtype=latents.dtype)
834+
timestep = torch.tensor([self.scheduler.sigmas[i]], device=latents.device, dtype=latents.dtype)
835+
844836
noise_pred = self.transformer(
845837
hidden_states=latents,
846-
timestep=timestep if divide_timestep else timestep / 1000,
838+
timestep=timestep,
847839
guidance=guidance,
848840
pooled_projections=pooled_prompt_embeds,
849841
encoder_hidden_states=prompt_embeds,
@@ -856,28 +848,20 @@ def __call__(
856848
# Unconditional vector field: $u_{t_i}(Y_{t_i}) = u(Y_{t_i}, t_i, \Phi(\text{""}); \phi)$
857849
unconditional_vector_field = noise_pred
858850
# Conditional vector field: $u_{t_i}(Y_{t_i} | y_1) = \frac{y_1 - Y_{t_i}}{1 - t_i}$
859-
conditional_vector_field = (y1-latents)/(1-timestep)
851+
t_i = i / self._num_timesteps
852+
conditional_vector_field = (y1-latents)/(1-t_i)
860853
# Controlled vector field: $\hat{u}_{t_i}(Y_{t_i}) = u_{t_i}(Y_{t_i}) + \gamma \left( u_{t_i}(Y_{t_i} | y_1) - u_{t_i}(Y_{t_i}) \right)$
861854
controlled_vector_field = unconditional_vector_field + controller_guidance * (conditional_vector_field - unconditional_vector_field)
862855

863856
latents_dtype = latents.dtype
864857
# Next state: $Y_{t_{i+1}} = Y_{t_i} + \hat{u}_{t_i}(Y_{t_i}) \cdot (\sigma(t_{i+1}) - \sigma(t_i))$
865-
latents = self.scheduler.step(controlled_vector_field, t, latents, return_dict=False)[0]
858+
latents = latents + controlled_vector_field * (self.scheduler.sigmas[i] - self.scheduler.sigmas[i+1])
866859

867860
if latents.dtype != latents_dtype:
868861
if torch.backends.mps.is_available():
869862
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
870863
latents = latents.to(latents_dtype)
871864

872-
if callback_on_step_end is not None:
873-
callback_kwargs = {}
874-
for k in callback_on_step_end_tensor_inputs:
875-
callback_kwargs[k] = locals()[k]
876-
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
877-
878-
latents = callback_outputs.pop("latents", latents)
879-
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
880-
881865
# advance the progress bar on the last step, or (once warmup is over) on every `scheduler.order`-th step
882866
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
883867
progress_bar.update()

0 commit comments

Comments
 (0)