[feat] Add strength in flux_fill pipeline (denoising strength for FluxFill) #10603
Merged
Commits (15)
14c452a [feat]add strength in flux_fill pipeline (Suprhimp)
a7e1501 Update src/diffusers/pipelines/flux/pipeline_flux_fill.py (Suprhimp)
cf60e52 Update src/diffusers/pipelines/flux/pipeline_flux_fill.py (Suprhimp)
25fa97c Update src/diffusers/pipelines/flux/pipeline_flux_fill.py (Suprhimp)
5d6b78c [refactor] refactor after review (Suprhimp)
3a1ea2e [fix] change comment (Suprhimp)
a3acf58 Merge branch 'main' into main (Suprhimp)
f42f90f Merge branch 'main' into main (asomoza)
a6737f1 Apply style fixes (github-actions[bot])
c489b57 empty (asomoza)
e87c9eb fix (asomoza)
cb43412 update prepare_latents from flux.img2img pipeline (Suprhimp)
f0fac62 style (asomoza)
a0ffed1 Update src/diffusers/pipelines/flux/pipeline_flux_fill.py (hlky)
b448707 Merge branch 'main' into main (asomoza)
src/diffusers/pipelines/flux/pipeline_flux_fill.py
```diff
@@ -493,10 +493,38 @@ def encode_prompt(
         return prompt_embeds, pooled_prompt_embeds, text_ids

+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+        return image_latents
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+        t_start = int(max(num_inference_steps - init_timestep, 0))
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+        return timesteps, num_inference_steps - t_start
+
     def check_inputs(
         self,
         prompt,
         prompt_2,
+        strength,
         height,
         width,
         prompt_embeds=None,
```
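A quick sanity check on the copied get_timesteps logic: with strength below 1.0 the schedule is truncated from the front, so roughly strength * num_inference_steps steps are actually run. A plain-Python sketch of that arithmetic (the list stands in for self.scheduler.timesteps, and scheduler.order is assumed to be 1):

```python
# Plain-Python sketch of get_timesteps: trim the schedule according to strength.
# A list stands in for self.scheduler.timesteps; scheduler.order is assumed to be 1.
num_inference_steps = 50
strength = 0.6

init_timestep = min(num_inference_steps * strength, num_inference_steps)  # 30.0
t_start = int(max(num_inference_steps - init_timestep, 0))                # 20

full_schedule = list(range(num_inference_steps))  # stand-in timesteps 0..49
timesteps = full_schedule[t_start:]               # the 30 steps that will actually run

print(len(timesteps), num_inference_steps - t_start)  # 30 30
```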
```diff
@@ -507,6 +535,9 @@ def check_inputs(
         mask_image=None,
         masked_image_latents=None,
     ):
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
         if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
             logger.warning(
                 f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
```
```diff
@@ -624,9 +655,11 @@ def disable_vae_tiling(self):
         """
         self.vae.disable_tiling()

-    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
+    # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
     def prepare_latents(
         self,
+        image,
+        timestep,
         batch_size,
         num_channels_latents,
         height,
```
```diff
@@ -636,28 +669,38 @@ def prepare_latents(
         generator,
         latents=None,
     ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         # VAE applies 8x compression on images but we must also account for packing which requires
         # latent height and width to be divisible by 2.
         height = 2 * (int(height) // (self.vae_scale_factor * 2))
         width = 2 * (int(width) // (self.vae_scale_factor * 2))

         shape = (batch_size, num_channels_latents, height, width)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
             return latents.to(device=device, dtype=dtype), latent_image_ids

-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
+        image = image.to(device=device, dtype=dtype)
+        image_latents = self._encode_vae_image(image=image, generator=generator)
+        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+            # expand init_latents for batch_size
+            additional_image_per_prompt = batch_size // image_latents.shape[0]
+            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+            raise ValueError(
+                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+            )
+        else:
+            image_latents = torch.cat([image_latents], dim=0)

-        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = self.scheduler.scale_noise(image_latents, timestep, noise)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
         return latents, latent_image_ids

     @property
```
```diff
@@ -687,6 +730,7 @@ def __call__(
         masked_image_latents: Optional[torch.FloatTensor] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        strength: float = 1.0,
         num_inference_steps: int = 50,
         sigmas: Optional[List[float]] = None,
         guidance_scale: float = 30.0,
```
```diff
@@ -731,6 +775,12 @@
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            strength (`float`, *optional*, defaults to 1.0):
+                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+                essentially ignores `image`.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
```
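For reference, a minimal usage sketch of the new argument (the checkpoint name follows the existing FluxFill docs; the image and mask paths, the prompt, and the parameter values are illustrative assumptions, not part of this PR):

```python
# Minimal sketch: calling FluxFillPipeline with the new `strength` argument.
import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")

init_image = load_image("input.png")  # image to inpaint (placeholder path)
mask_image = load_image("mask.png")   # white = region to regenerate (placeholder path)

result = pipe(
    prompt="a small wooden table",
    image=init_image,
    mask_image=mask_image,
    height=1024,
    width=1024,
    strength=0.6,            # < 1.0 keeps more structure from `image`; 1.0 matches the previous behavior
    num_inference_steps=50,  # with strength=0.6, roughly 30 steps are actually run
    guidance_scale=30.0,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
result.save("flux_fill_strength.png")
```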
```diff
@@ -794,6 +844,7 @@ def __call__(
         self.check_inputs(
             prompt,
             prompt_2,
+            strength,
             height,
             width,
             prompt_embeds=prompt_embeds,
```
```diff
@@ -809,6 +860,9 @@ def __call__(
         self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False

+        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        init_image = init_image.to(dtype=torch.float32)
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
```
```diff
@@ -838,9 +892,37 @@ def __call__(
             lora_scale=lora_scale,
         )

-        # 4. Prepare latent variables
+        # 4. Prepare timesteps
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+        mu = calculate_shift(
+            image_seq_len,
+            self.scheduler.config.base_image_seq_len,
+            self.scheduler.config.max_image_seq_len,
+            self.scheduler.config.base_shift,
+            self.scheduler.config.max_shift,
+        )
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler,
+            num_inference_steps,
+            device,
+            sigmas=sigmas,
+            mu=mu,
+        )
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+        if num_inference_steps < 1:
+            raise ValueError(
+                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+            )
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 5. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
         latents, latent_image_ids = self.prepare_latents(
+            init_image,
+            latent_timestep,
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
```
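For orientation, calculate_shift maps the packed latent sequence length to the scheduler shift mu by linear interpolation between the base and max shift values. A standalone sketch of that interpolation (the default numbers mirror the fallback values visible in the block removed further down, and vae_scale_factor=8 is assumed):

```python
# Standalone sketch of the calculate_shift interpolation (not the diffusers helper itself).
# Defaults mirror the fallback values shown in the removed timestep block further down.
def calculate_shift_sketch(
    image_seq_len: int,
    base_image_seq_len: int = 256,
    max_image_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> float:
    m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
    b = base_shift - m * base_image_seq_len
    return image_seq_len * m + b

# A 1024x1024 image with vae_scale_factor=8 gives a 128x128 latent,
# which packs to a 64 * 64 = 4096-token sequence.
image_seq_len = (1024 // 8 // 2) * (1024 // 8 // 2)
print(calculate_shift_sketch(image_seq_len))  # 1.15 at the max sequence length
```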
```diff
@@ -851,17 +933,16 @@ def __call__(
             latents,
         )

-        # 5. Prepare mask and masked image latents
+        # 6. Prepare mask and masked image latents
         if masked_image_latents is not None:
             masked_image_latents = masked_image_latents.to(latents.device)
         else:
-            image = self.image_processor.preprocess(image, height=height, width=width)
             mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)

-            masked_image = image * (1 - mask_image)
+            masked_image = init_image * (1 - mask_image)
             masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)

-            height, width = image.shape[-2:]
+            height, width = init_image.shape[-2:]
             mask, masked_image_latents = self.prepare_mask_latents(
                 mask_image,
                 masked_image,
```
```diff
@@ -876,23 +957,6 @@ def __call__(
             )
             masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)

-        # 6. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
-        image_seq_len = latents.shape[1]
-        mu = calculate_shift(
-            image_seq_len,
-            self.scheduler.config.get("base_image_seq_len", 256),
-            self.scheduler.config.get("max_image_seq_len", 4096),
-            self.scheduler.config.get("base_shift", 0.5),
-            self.scheduler.config.get("max_shift", 1.15),
-        )
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            device,
-            sigmas=sigmas,
-            mu=mu,
-        )
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
```
This part isn't the same as the prepare_latents in FluxImg2ImgPipeline, but you're commenting that this function was copied from it.
Thanks for the detailed review. Yes, since FluxImg2ImgPipeline was updated two weeks ago, this had to be changed, so I updated prepare_latents to match the version in FluxImg2ImgPipeline. I also tested it with my code and it works well (honestly, I think it works even better than before).
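For anyone following along, the practical effect of syncing prepare_latents with FluxImg2ImgPipeline is that the initial latents are no longer pure noise: the VAE-encoded init image is noised with the scheduler's flow-matching interpolation at the strength-derived starting timestep. A toy sketch of that interpolation (stand-in tensors and a hand-rolled helper that only mimics FlowMatchEulerDiscreteScheduler.scale_noise conceptually):

```python
# Toy sketch of the flow-matching noising now done in prepare_latents.
# sigma=1.0 returns pure noise (strength=1.0, the old behavior); smaller sigma keeps more of the image.
import torch

def scale_noise_sketch(sample: torch.Tensor, sigma: float, noise: torch.Tensor) -> torch.Tensor:
    return sigma * noise + (1.0 - sigma) * sample

image_latents = torch.randn(1, 16, 128, 128)  # stand-in for the VAE-encoded init image
noise = torch.randn_like(image_latents)

for sigma in (1.0, 0.6, 0.3):  # roughly what different strength values select as the starting sigma
    start_latents = scale_noise_sketch(image_latents, sigma, noise)
    # similarity to the clean latents grows as sigma (and strength) shrinks
    sim = torch.nn.functional.cosine_similarity(
        start_latents.flatten(), image_latents.flatten(), dim=0
    )
    print(f"sigma={sigma:.1f} -> similarity to init latents: {sim.item():.2f}")
```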