Skip to content

Commit 6f9d368

Browse files
committed
maybe
1 parent 8823139 commit 6f9d368

File tree

2 files changed

+46
-47
lines changed

2 files changed

+46
-47
lines changed

src/diffusers/pipelines/flux/pipeline_flux_rfinversion_edit.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,27 @@ def __call__(
681681
generator,
682682
latents,
683683
)
684+
685+
# 5. Prepare timesteps
686+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
684687
image_seq_len = latents.shape[1]
688+
mu = calculate_shift(
689+
image_seq_len,
690+
self.scheduler.config.base_image_seq_len,
691+
self.scheduler.config.max_image_seq_len,
692+
self.scheduler.config.base_shift,
693+
self.scheduler.config.max_shift,
694+
)
695+
timesteps, num_inference_steps = retrieve_timesteps(
696+
self.scheduler,
697+
num_inference_steps,
698+
device,
699+
timesteps,
700+
sigmas,
701+
mu=mu,
702+
)
703+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
704+
self._num_timesteps = len(timesteps)
685705

686706
# handle guidance
687707
if self.transformer.config.guidance_embeds:
@@ -690,35 +710,18 @@ def __call__(
690710
else:
691711
guidance = None
692712

693-
import math
694-
def time_shift(mu: float, sigma: float, t: torch.Tensor):
695-
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
696-
697-
698-
def get_lin_function(
699-
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
700-
) -> Callable[[float], float]:
701-
m = (y2 - y1) / (x2 - x1)
702-
b = y1 - m * x1
703-
return lambda x: m * x + b
704-
705-
mu = get_lin_function()(image_seq_len)
706-
timesteps = torch.linspace(0, 1, num_inference_steps+1)
707-
timesteps = time_shift(mu, 1.0, timesteps).to(latents.device, latents.dtype)
708713
# 6. Denoising loop
709714
with self.progress_bar(total=num_inference_steps) as progress_bar:
710-
for i in range(num_inference_steps):
715+
for i, t in enumerate(timesteps):
711716
if self.interrupt:
712717
continue
713-
t = torch.tensor([timesteps[i]], device=latents.device, dtype=latents.dtype)
714-
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
715718
timestep = t.expand(latents.shape[0]).to(latents.dtype)
716719
timestep = 1-timestep
717720

718721
control_guidance = controller_guidance if i < stop_step else 0.0
719722
unconditional_vector_field = -self.transformer(
720723
hidden_states=latents,
721-
timestep=timestep,
724+
timestep=timestep / 1000,
722725
guidance=guidance,
723726
pooled_projections=pooled_prompt_embeds,
724727
encoder_hidden_states=prompt_embeds,
@@ -731,10 +734,7 @@ def get_lin_function(
731734
conditional_vector_field = (reference_image - latents) / timestep
732735
controlled_vector_field = unconditional_vector_field + control_guidance * (conditional_vector_field - unconditional_vector_field)
733736

734-
sigma = timesteps[i]
735-
sigma_next = timesteps[i+1]
736-
latents = latents + (sigma_next - sigma) * controlled_vector_field
737-
737+
latents = self.scheduler.step(controlled_vector_field, t, latents, return_dict=False)[0]
738738

739739
if output_type == "latent":
740740
image = latents

src/diffusers/pipelines/flux/pipeline_flux_rfinversion_noise.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -761,23 +761,26 @@ def __call__(
761761
max_sequence_length=max_sequence_length,
762762
lora_scale=lora_scale,
763763
)
764-
import math
765-
def time_shift(mu: float, sigma: float, t: torch.Tensor):
766-
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
767764

768-
769-
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
770-
def get_lin_function(
771-
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
772-
) -> Callable[[float], float]:
773-
m = (y2 - y1) / (x2 - x1)
774-
b = y1 - m * x1
775-
return lambda x: m * x + b
776-
777-
mu = get_lin_function()(image_seq_len)
778-
timesteps = torch.linspace(0, 1, num_inference_steps+1)
779-
timesteps = time_shift(mu, 1.0, timesteps).to("cuda", torch.bfloat16)
780765
# 4. Prepare timesteps
766+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
767+
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
768+
mu = calculate_shift(
769+
image_seq_len,
770+
self.scheduler.config.base_image_seq_len,
771+
self.scheduler.config.max_image_seq_len,
772+
self.scheduler.config.base_shift,
773+
self.scheduler.config.max_shift,
774+
)
775+
timesteps, num_inference_steps = retrieve_timesteps(
776+
self.scheduler,
777+
num_inference_steps,
778+
device,
779+
timesteps,
780+
sigmas,
781+
mu=mu,
782+
)
783+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
781784

782785
if num_inference_steps < 1:
783786
raise ValueError(
@@ -788,9 +791,10 @@ def get_lin_function(
788791

789792
# 5. Prepare latent variables
790793
num_channels_latents = self.transformer.config.in_channels // 4
794+
791795
latents, latent_image_ids = self.prepare_latents(
792796
init_image,
793-
timesteps,
797+
latent_timestep,
794798
batch_size * num_images_per_prompt,
795799
num_channels_latents,
796800
height,
@@ -815,15 +819,13 @@ def get_lin_function(
815819
y1 = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
816820
# 6. Denoising loop
817821
with self.progress_bar(total=num_inference_steps) as progress_bar:
818-
for i in range(num_inference_steps - stop_step):
822+
for i, t in enumerate(timesteps):
819823
if self.interrupt:
820824
continue
821-
t = torch.tensor([timesteps[i]], device=latents.device, dtype=latents.dtype)
822-
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
823825
timestep = t.expand(latents.shape[0]).to(latents.dtype)
824826
noise_pred = self.transformer(
825827
hidden_states=latents,
826-
timestep=timestep,
828+
timestep=timestep / 1000,
827829
guidance=guidance,
828830
pooled_projections=pooled_prompt_embeds,
829831
encoder_hidden_states=prompt_embeds,
@@ -837,10 +839,7 @@ def get_lin_function(
837839
conditional_vector_field = (y1-latents)/(1-timestep)
838840
controlled_vector_field = unconditional_vector_field + controller_guidance * (conditional_vector_field - unconditional_vector_field)
839841

840-
# Get the corresponding sigma values
841-
sigma = timesteps[i]
842-
sigma_next = timesteps[i+1]
843-
latents = latents + (sigma_next - sigma) * controlled_vector_field
842+
latents = self.scheduler.step(controlled_vector_field, t, latents, return_dict=False)[0]
844843

845844
if XLA_AVAILABLE:
846845
xm.mark_step()

0 commit comments

Comments
 (0)