
Commit 8823139

Commit message: maybe

1 parent 54ca52b commit 8823139
File tree

2 files changed: +86 additions, -98 deletions


src/diffusers/pipelines/flux/pipeline_flux_rfinversion_edit.py

Lines changed: 27 additions & 50 deletions
@@ -539,6 +539,7 @@ def __call__(
         guidance_scale: float = 3.5,
         controller_guidance: float = 5.0,
         reference_image: Optional[torch.FloatTensor] = None,
+        stop_step: int = 28,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
@@ -680,27 +681,7 @@ def __call__(
             generator,
             latents,
         )
-
-        # 5. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
         image_seq_len = latents.shape[1]
-        mu = calculate_shift(
-            image_seq_len,
-            self.scheduler.config.base_image_seq_len,
-            self.scheduler.config.max_image_seq_len,
-            self.scheduler.config.base_shift,
-            self.scheduler.config.max_shift,
-        )
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            device,
-            timesteps,
-            sigmas,
-            mu=mu,
-        )
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-        self._num_timesteps = len(timesteps)

         # handle guidance
         if self.transformer.config.guidance_embeds:
@@ -709,18 +690,35 @@ def __call__(
         else:
             guidance = None

+        import math
+        def time_shift(mu: float, sigma: float, t: torch.Tensor):
+            return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+
+        def get_lin_function(
+            x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
+        ) -> Callable[[float], float]:
+            m = (y2 - y1) / (x2 - x1)
+            b = y1 - m * x1
+            return lambda x: m * x + b
+
+        mu = get_lin_function()(image_seq_len)
+        timesteps = torch.linspace(0, 1, num_inference_steps+1)
+        timesteps = time_shift(mu, 1.0, timesteps).to(latents.device, latents.dtype)
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
+            for i in range(num_inference_steps):
                 if self.interrupt:
                     continue
-
+                t = torch.tensor([timesteps[i]], device=latents.device, dtype=latents.dtype)
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
+                timestep = 1-timestep

-                noise_pred = self.transformer(
+                control_guidance = controller_guidance if i < stop_step else 0.0
+                unconditional_vector_field = -self.transformer(
                     hidden_states=latents,
-                    timestep=(1-(timestep / 1000)),
+                    timestep=timestep,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
                     encoder_hidden_states=prompt_embeds,
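The hunk above swaps the scheduler's `calculate_shift`/`retrieve_timesteps` machinery for an inline, resolution-dependent timestep shift. A minimal standalone sketch of that schedule follows; the step count and sequence length are illustrative values, not taken from the commit:

import math
import torch

def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    # Warps a uniform grid on [0, 1]; for mu > 0 points are pushed toward 1.
    # t = 0 maps to 0 (1/t overflows to inf, and the ratio collapses to 0).
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

def get_lin_function(x1: float = 256, y1: float = 0.5,
                     x2: float = 4096, y2: float = 1.15):
    # Linearly interpolates the shift mu from the image token count.
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b

num_inference_steps = 28  # illustrative
image_seq_len = 4096      # e.g. a 1024x1024 image packed into 64x64 tokens
mu = get_lin_function()(image_seq_len)
grid = torch.linspace(0, 1, num_inference_steps + 1)
sigmas = time_shift(mu, 1.0, grid)  # monotone warp of the uniform grid
print(sigmas[0].item(), sigmas[-1].item())  # 0.0 ... 1.0

For mu > 0 the warped grid is biased toward sigma = 1, which matches how Flux shifts its schedule for larger images.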
@@ -730,34 +728,13 @@ def __call__(
                     return_dict=False,
                 )[0]

-                unconditional_vector_field = -noise_pred
-                conditional_vector_field = (reference_image-unconditional_vector_field)/(1-(timestep / 1000))
-                controlled_vector_field = unconditional_vector_field + controller_guidance * (conditional_vector_field - unconditional_vector_field)
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents_dtype = latents.dtype
-                latents = self.scheduler.step(controlled_vector_field, t, latents, return_dict=False)[0]
-
-                if latents.dtype != latents_dtype:
-                    if torch.backends.mps.is_available():
-                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                        latents = latents.to(latents_dtype)
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                conditional_vector_field = (reference_image - latents) / timestep
+                controlled_vector_field = unconditional_vector_field + control_guidance * (conditional_vector_field - unconditional_vector_field)

-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
+                sigma = timesteps[i]
+                sigma_next = timesteps[i+1]
+                latents = latents + (sigma_next - sigma) * controlled_vector_field

-                if XLA_AVAILABLE:
-                    xm.mark_step()

         if output_type == "latent":
             image = latents
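With the scheduler `step` call and callback plumbing removed, each iteration reduces to a plain Euler update of a guided vector field, in the spirit of RF-Inversion's controlled ODE. A self-contained sketch of one such step, with dummy tensors standing in for the pipeline state:

import torch

def controlled_euler_step(latents: torch.Tensor,
                          unconditional_vf: torch.Tensor,
                          reference_image: torch.Tensor,
                          timestep: float,
                          control_guidance: float,
                          sigma: float, sigma_next: float) -> torch.Tensor:
    # Conditional field: points from the current latents toward the
    # reference latents, scaled by the remaining time.
    conditional_vf = (reference_image - latents) / timestep
    # Interpolate between the two fields; control_guidance = 0 recovers
    # the plain (unconditional) flow.
    controlled_vf = unconditional_vf + control_guidance * (conditional_vf - unconditional_vf)
    # One explicit Euler step over the shifted sigma grid.
    return latents + (sigma_next - sigma) * controlled_vf

# Usage with dummy tensors (shapes illustrative):
x = torch.randn(1, 4096, 64)
v = torch.randn_like(x)
ref = torch.randn_like(x)
x = controlled_euler_step(x, v, ref, timestep=0.9, control_guidance=5.0,
                          sigma=0.1, sigma_next=0.15)

Note that in the loop above `timestep = 1 - timesteps[i]`, so a small sigma corresponds to a timestep near 1, consistent with the example values here.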

src/diffusers/pipelines/flux/pipeline_flux_rfinversion_noise.py

Lines changed: 59 additions & 48 deletions
@@ -549,8 +549,11 @@ def prepare_latents(
             image_latents = torch.cat([image_latents], dim=0)

         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+        import numpy as np
+        sigma = timestep[0]
+        latents = sigma * noise + (1.0 - sigma) * image_latents
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+        np.save("reference_image_latent.npy", latents.detach().cpu().float().numpy())
        return latents, latent_image_ids

    @property
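The replaced `scale_noise` call is written out by hand as the rectified-flow interpolation x_sigma = sigma * noise + (1 - sigma) * x_0, and the packed result is dumped to `reference_image_latent.npy`, presumably to feed the edit pipeline's `reference_image` input. A minimal sketch of the interpolation, assuming `sigma` lies in [0, 1]:

import torch

def interpolate_to_sigma(image_latents: torch.Tensor,
                         noise: torch.Tensor, sigma: float) -> torch.Tensor:
    # sigma = 1.0 gives pure noise; sigma = 0.0 returns the clean latents.
    return sigma * noise + (1.0 - sigma) * image_latents

x0 = torch.randn(1, 16, 128, 128)  # illustrative unpacked latent shape
noise = torch.randn_like(x0)
xt = interpolate_to_sigma(x0, noise, sigma=0.75)
print(xt.shape)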
@@ -569,6 +572,35 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt

+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
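The four new methods are thin wrappers over the underlying autoencoder's existing slicing/tiling API. A hedged usage sketch against the bare VAE; the checkpoint id is an assumption for illustration, not from this commit:

import torch
from diffusers import AutoencoderKL

# Load the Flux VAE directly (checkpoint id assumed).
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
)
vae.enable_slicing()  # decode the batch one slice at a time
vae.enable_tiling()   # decode large images tile by tile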
@@ -582,7 +614,8 @@ def __call__(
         num_inference_steps: int = 28,
         timesteps: List[int] = None,
         guidance_scale: float = 7.0,
-        controller_guidance: float = 5.0,
+        controller_guidance: float = 0.5,
+        stop_step: int = 0,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
@@ -728,26 +761,23 @@ def __call__(
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
         )
+        import math
+        def time_shift(mu: float, sigma: float, t: torch.Tensor):
+            return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+

-        # 4.Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
         image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
-        mu = calculate_shift(
-            image_seq_len,
-            self.scheduler.config.base_image_seq_len,
-            self.scheduler.config.max_image_seq_len,
-            self.scheduler.config.base_shift,
-            self.scheduler.config.max_shift,
-        )
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            device,
-            timesteps,
-            sigmas,
-            mu=mu,
-        )
-        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+        def get_lin_function(
+            x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
+        ) -> Callable[[float], float]:
+            m = (y2 - y1) / (x2 - x1)
+            b = y1 - m * x1
+            return lambda x: m * x + b
+
+        mu = get_lin_function()(image_seq_len)
+        timesteps = torch.linspace(0, 1, num_inference_steps+1)
+        timesteps = time_shift(mu, 1.0, timesteps).to("cuda", torch.bfloat16)
+        # 4.Prepare timesteps

         if num_inference_steps < 1:
             raise ValueError(
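Note that the warped grid has `num_inference_steps + 1` entries so the Euler loop below can read `timesteps[i + 1]` on its last iteration. A quick standalone check of that bound (step count illustrative):

import torch

num_inference_steps = 28
timesteps = torch.linspace(0, 1, num_inference_steps + 1)  # n + 1 knots
# The loop consumes the pairs (timesteps[i], timesteps[i + 1]); the last
# valid pair is (timesteps[-2], timesteps[-1]), so no index overflows.
for i in range(num_inference_steps):
    sigma, sigma_next = timesteps[i], timesteps[i + 1]
assert sigma_next == timesteps[-1]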
@@ -758,10 +788,9 @@ def __call__(

         # 5. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
-
         latents, latent_image_ids = self.prepare_latents(
             init_image,
-            latent_timestep,
+            timesteps,
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
@@ -784,18 +813,17 @@ def __call__(

         # fix noise sample y1
         y1 = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
-
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
+            for i in range(num_inference_steps - stop_step):
                 if self.interrupt:
                     continue
-
+                t = torch.tensor([timesteps[i]], device=latents.device, dtype=latents.dtype)
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
                 noise_pred = self.transformer(
                     hidden_states=latents,
-                    timestep=timestep / 1000,
+                    timestep=timestep,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
                     encoder_hidden_states=prompt_embeds,
@@ -806,30 +834,13 @@ def __call__(
                 )[0]

                 unconditional_vector_field = noise_pred
-                conditional_vector_field = (y1-unconditional_vector_field)/(1-(timestep / 1000))
+                conditional_vector_field = (y1-latents)/(1-timestep)
                 controlled_vector_field = unconditional_vector_field + controller_guidance * (conditional_vector_field - unconditional_vector_field)

-                # compute the previous noisy sample x_t -> x_t-1
-                latents_dtype = latents.dtype
-                latents = self.scheduler.step(controlled_vector_field, t, latents, return_dict=False)[0]
-
-                if latents.dtype != latents_dtype:
-                    if torch.backends.mps.is_available():
-                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                        latents = latents.to(latents_dtype)
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
+                # Get the corresponding sigma values
+                sigma = timesteps[i]
+                sigma_next = timesteps[i+1]
+                latents = latents + (sigma_next - sigma) * controlled_vector_field

                 if XLA_AVAILABLE:
                     xm.mark_step()
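This loop mirrors the edit pipeline's but steers the latents toward the fixed noise sample `y1` rather than a reference image, and `stop_step` here shortens the loop rather than gating the guidance. A self-contained sketch of one inversion step with dummy tensors (shapes and values illustrative):

import torch

def inversion_euler_step(latents: torch.Tensor, noise_pred: torch.Tensor,
                         y1: torch.Tensor, timestep: float,
                         controller_guidance: float,
                         sigma: float, sigma_next: float) -> torch.Tensor:
    # Unconditional field: the transformer's velocity prediction.
    unconditional_vf = noise_pred
    # Conditional field drives latents toward the fixed noise sample y1.
    conditional_vf = (y1 - latents) / (1 - timestep)
    controlled_vf = unconditional_vf + controller_guidance * (conditional_vf - unconditional_vf)
    # Euler step over the shifted sigma grid (here timestep == sigma).
    return latents + (sigma_next - sigma) * controlled_vf

x = torch.randn(1, 4096, 64)
x = inversion_euler_step(x, torch.randn_like(x), torch.randn_like(x),
                         timestep=0.1, controller_guidance=0.5,
                         sigma=0.1, sigma_next=0.15)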
