Skip to content

Commit 094e216

Browse files
committed
maybe
1 parent 8823139 commit 094e216

File tree

2 files changed

+47
-49
lines changed

2 files changed

+47
-49
lines changed

src/diffusers/pipelines/flux/pipeline_flux_rfinversion_edit.py

Lines changed: 23 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -681,7 +681,27 @@ def __call__(
681681
generator,
682682
latents,
683683
)
684+
685+
# 5. Prepare timesteps
686+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
684687
image_seq_len = latents.shape[1]
688+
mu = calculate_shift(
689+
image_seq_len,
690+
self.scheduler.config.base_image_seq_len,
691+
self.scheduler.config.max_image_seq_len,
692+
self.scheduler.config.base_shift,
693+
self.scheduler.config.max_shift,
694+
)
695+
timesteps, num_inference_steps = retrieve_timesteps(
696+
self.scheduler,
697+
num_inference_steps,
698+
device,
699+
timesteps,
700+
sigmas,
701+
mu=mu,
702+
)
703+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
704+
self._num_timesteps = len(timesteps)
685705

686706
# handle guidance
687707
if self.transformer.config.guidance_embeds:
@@ -690,35 +710,18 @@ def __call__(
690710
else:
691711
guidance = None
692712

693-
import math
694-
def time_shift(mu: float, sigma: float, t: torch.Tensor):
695-
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
696-
697-
698-
def get_lin_function(
699-
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
700-
) -> Callable[[float], float]:
701-
m = (y2 - y1) / (x2 - x1)
702-
b = y1 - m * x1
703-
return lambda x: m * x + b
704-
705-
mu = get_lin_function()(image_seq_len)
706-
timesteps = torch.linspace(0, 1, num_inference_steps+1)
707-
timesteps = time_shift(mu, 1.0, timesteps).to(latents.device, latents.dtype)
708713
# 6. Denoising loop
709714
with self.progress_bar(total=num_inference_steps) as progress_bar:
710-
for i in range(num_inference_steps):
715+
for i, t in enumerate(timesteps):
711716
if self.interrupt:
712717
continue
713-
t = torch.tensor([timesteps[i]], device=latents.device, dtype=latents.dtype)
714-
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
715718
timestep = t.expand(latents.shape[0]).to(latents.dtype)
716719
timestep = 1-timestep
717720

718721
control_guidance = controller_guidance if i < stop_step else 0.0
719722
unconditional_vector_field = -self.transformer(
720723
hidden_states=latents,
721-
timestep=timestep,
724+
timestep=timestep / 1000,
722725
guidance=guidance,
723726
pooled_projections=pooled_prompt_embeds,
724727
encoder_hidden_states=prompt_embeds,
@@ -731,10 +734,7 @@ def get_lin_function(
731734
conditional_vector_field = (reference_image - latents) / timestep
732735
controlled_vector_field = unconditional_vector_field + control_guidance * (conditional_vector_field - unconditional_vector_field)
733736

734-
sigma = timesteps[i]
735-
sigma_next = timesteps[i+1]
736-
latents = latents + (sigma_next - sigma) * controlled_vector_field
737-
737+
latents = self.scheduler.step(controlled_vector_field, t, latents, return_dict=False)[0]
738738

739739
if output_type == "latent":
740740
image = latents

src/diffusers/pipelines/flux/pipeline_flux_rfinversion_noise.py

Lines changed: 24 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -550,8 +550,7 @@ def prepare_latents(
550550

551551
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
552552
import numpy as np
553-
sigma = timestep[0]
554-
latents = sigma * noise + (1.0 - sigma) * image_latents
553+
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
555554
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
556555
np.save("reference_image_latent.npy", latents.detach().cpu().float().numpy())
557556
return latents, latent_image_ids
@@ -761,23 +760,26 @@ def __call__(
761760
max_sequence_length=max_sequence_length,
762761
lora_scale=lora_scale,
763762
)
764-
import math
765-
def time_shift(mu: float, sigma: float, t: torch.Tensor):
766-
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
767763

768-
769-
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
770-
def get_lin_function(
771-
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
772-
) -> Callable[[float], float]:
773-
m = (y2 - y1) / (x2 - x1)
774-
b = y1 - m * x1
775-
return lambda x: m * x + b
776-
777-
mu = get_lin_function()(image_seq_len)
778-
timesteps = torch.linspace(0, 1, num_inference_steps+1)
779-
timesteps = time_shift(mu, 1.0, timesteps).to("cuda", torch.bfloat16)
780764
# 4.Prepare timesteps
765+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
766+
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
767+
mu = calculate_shift(
768+
image_seq_len,
769+
self.scheduler.config.base_image_seq_len,
770+
self.scheduler.config.max_image_seq_len,
771+
self.scheduler.config.base_shift,
772+
self.scheduler.config.max_shift,
773+
)
774+
timesteps, num_inference_steps = retrieve_timesteps(
775+
self.scheduler,
776+
num_inference_steps,
777+
device,
778+
timesteps,
779+
sigmas,
780+
mu=mu,
781+
)
782+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
781783

782784
if num_inference_steps < 1:
783785
raise ValueError(
@@ -788,9 +790,10 @@ def get_lin_function(
788790

789791
# 5. Prepare latent variables
790792
num_channels_latents = self.transformer.config.in_channels // 4
793+
791794
latents, latent_image_ids = self.prepare_latents(
792795
init_image,
793-
timesteps,
796+
latent_timestep,
794797
batch_size * num_images_per_prompt,
795798
num_channels_latents,
796799
height,
@@ -815,15 +818,13 @@ def get_lin_function(
815818
y1 = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
816819
# 6. Denoising loop
817820
with self.progress_bar(total=num_inference_steps) as progress_bar:
818-
for i in range(num_inference_steps - stop_step):
821+
for i, t in enumerate(timesteps):
819822
if self.interrupt:
820823
continue
821-
t = torch.tensor([timesteps[i]], device=latents.device, dtype=latents.dtype)
822-
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
823824
timestep = t.expand(latents.shape[0]).to(latents.dtype)
824825
noise_pred = self.transformer(
825826
hidden_states=latents,
826-
timestep=timestep,
827+
timestep=timestep / 1000,
827828
guidance=guidance,
828829
pooled_projections=pooled_prompt_embeds,
829830
encoder_hidden_states=prompt_embeds,
@@ -837,10 +838,7 @@ def get_lin_function(
837838
conditional_vector_field = (y1-latents)/(1-timestep)
838839
controlled_vector_field = unconditional_vector_field + controller_guidance * (conditional_vector_field - unconditional_vector_field)
839840

840-
# Get the corresponding sigma values
841-
sigma = timesteps[i]
842-
sigma_next = timesteps[i+1]
843-
latents = latents + (sigma_next - sigma) * controlled_vector_field
841+
latents = self.scheduler.step(controlled_vector_field, t, latents, return_dict=False)[0]
844842

845843
if XLA_AVAILABLE:
846844
xm.mark_step()

0 commit comments

Comments
 (0)