22 changes: 14 additions & 8 deletions src/diffusers/pipelines/flux/pipeline_flux.py
@@ -197,7 +197,9 @@ def __init__(
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
-self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+# Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
+# by the patch size. So the vae scale factor is multiplied by the patch size to account for this.
+self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
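For intuition, a sketch of where the doubled factor comes from (illustrative values only, assuming the default Flux VAE with 8x spatial compression and a 2x2 packing patch):

```python
# Illustrative sketch: effective divisibility requirement for Flux inputs.
# Assumed values: default VAE (8x compression), Flux 2x2 latent packing.
vae_scale_factor = 8
patch_size = 2
divisor = vae_scale_factor * patch_size  # 16

for height, width in [(1024, 768), (72, 56)]:
    divisible = height % divisor == 0 and width % divisor == 0
    print(f"{height}x{width}: divisible by {divisor} -> {divisible}")
# 1024x768: divisible by 16 -> True
# 72x56: divisible by 16 -> False (the pipeline rounds these down)
```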
@@ -386,9 +388,9 @@ def check_inputs(
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
-if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
-raise ValueError(
-f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+logger.warning(
+f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly."
)

if callback_on_step_end_tensor_inputs is not None and not all(
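In effect, what used to raise a hard ValueError now only warns, and the dimensions are rounded down later in prepare_latents. A minimal sketch of the caller-visible behavior (not pipeline code; the vae_scale_factor of 8 is an assumption):

```python
# Sketch of the relaxed check; values and vae_scale_factor are assumptions.
vae_scale_factor = 8
multiple = vae_scale_factor * 2  # 16

height, width = 72, 56
if height % multiple != 0 or width % multiple != 0:
    print(
        f"`height` and `width` have to be divisible by {multiple} "
        f"but are {height} and {width}. Dimensions will be resized accordingly."
    )

# The size the pipeline actually produces after rounding down:
out_height = height - height % multiple  # 64
out_width = width - width % multiple     # 48
```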
@@ -451,8 +453,10 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape

-height = height // vae_scale_factor
-width = width // vae_scale_factor
+# VAE applies 8x compression on images but we must also account for packing which requires
+# latent height and width to be divisible by 2.
+height = int(height) // vae_scale_factor - ((int(height) // vae_scale_factor) % 2)
+width = int(width) // vae_scale_factor - ((int(width) // vae_scale_factor) % 2)

latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
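Traced by hand, the new rounding guarantees an even latent grid, so the 2x2 unpacking view above is always valid; a self-contained sketch with assumed values:

```python
# Worked example of the rounding in _unpack_latents (illustrative values).
vae_scale_factor = 8
height, width = 72, 56  # requested image size, not divisible by 16

lat_h = int(height) // vae_scale_factor  # 9 (odd: a 2x2 patch view would fail)
lat_h -= lat_h % 2                       # 8
lat_w = int(width) // vae_scale_factor   # 7
lat_w -= lat_w % 2                       # 6

# latents.view(batch, lat_h // 2, lat_w // 2, channels // 4, 2, 2) now works,
# and the decoded image comes out at (lat_h * 8, lat_w * 8) = (64, 48).
print(lat_h, lat_w)  # 8 6
```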
@@ -501,8 +505,10 @@ def prepare_latents(
generator,
latents=None,
):
-height = int(height) // self.vae_scale_factor
-width = int(width) // self.vae_scale_factor
+# VAE applies 8x compression on images but we must also account for packing which requires
+# latent height and width to be divisible by 2.
+height = int(height) // self.vae_scale_factor - ((int(height) // self.vae_scale_factor) % 2)
Collaborator (inline review comment on the line above):

Is this the same as this (basically reverting back to the original code, but divisible by self.vae_scale_factor * 2)? A little bit easier to understand, no?

    height = 2 * (int(height) // (self.vae_scale_factor * 2))
+width = int(width) // self.vae_scale_factor - ((int(width) // self.vae_scale_factor) % 2)

shape = (batch_size, num_channels_latents, height, width)
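The reviewer's formula is indeed equivalent, since h // s - ((h // s) % 2) == 2 * ((h // s) // 2) == 2 * (h // (2 * s)) for non-negative integers. A quick brute-force check (illustrative only, not part of the diff):

```python
# Brute-force check that the two rounding formulas agree.
vae_scale_factor = 8
for h in range(4096):
    a = int(h) // vae_scale_factor - ((int(h) // vae_scale_factor) % 2)
    b = 2 * (int(h) // (vae_scale_factor * 2))
    assert a == b, (h, a, b)
print("formulas agree on [0, 4096)")
```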

16 changes: 10 additions & 6 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -218,7 +218,9 @@ def __init__(
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
-self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+# Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
+# by the patch size. So the vae scale factor is multiplied by the patch size to account for this.
+self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
@@ -410,9 +412,9 @@ def check_inputs(
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
-if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
-raise ValueError(
-f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+logger.warning(
+f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly."
)

if callback_on_step_end_tensor_inputs is not None and not all(
@@ -500,8 +502,10 @@ def prepare_latents(
generator,
latents=None,
):
-height = int(height) // self.vae_scale_factor
-width = int(width) // self.vae_scale_factor
+# VAE applies 8x compression on images but we must also account for packing which requires
+# latent height and width to be divisible by 2.
+height = int(height) // self.vae_scale_factor - ((int(height) // self.vae_scale_factor) % 2)
+width = int(width) // self.vae_scale_factor - ((int(width) // self.vae_scale_factor) % 2)

shape = (batch_size, num_channels_latents, height, width)

src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
@@ -230,7 +230,9 @@ def __init__(
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
-self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+# Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
+# by the patch size. So the vae scale factor is multiplied by the patch size to account for this.
+self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
@@ -453,9 +455,9 @@ def check_inputs(
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

-if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
-raise ValueError(
-f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+logger.warning(
+f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly."
)

if callback_on_step_end_tensor_inputs is not None and not all(
@@ -551,8 +553,10 @@ def prepare_latents(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

-height = int(height) // self.vae_scale_factor
-width = int(width) // self.vae_scale_factor
+# VAE applies 8x compression on images but we must also account for packing which requires
+# latent height and width to be divisible by 2.
+height = int(height) // self.vae_scale_factor - ((int(height) // self.vae_scale_factor) % 2)
+width = int(width) // self.vae_scale_factor - ((int(width) // self.vae_scale_factor) % 2)

shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -233,9 +233,11 @@ def __init__(
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
-self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+# Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
+# by the patch size. So the vae scale factor is multiplied by the patch size to account for this.
+self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.mask_processor = VaeImageProcessor(
-vae_scale_factor=self.vae_scale_factor,
+vae_scale_factor=self.vae_scale_factor * 2,
vae_latent_channels=self.vae.config.latent_channels,
do_normalize=False,
do_binarize=True,
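Doubling the factor on mask_processor as well keeps the mask on the same rounded grid as the image latents, so the two can be concatenated later. A hedged sketch of the alignment (the vae_scale_factor of 8 and the input size are assumptions):

```python
# Sketch: mask and image latents must land on the same rounded latent grid.
vae_scale_factor = 8
height, width = 72, 56

# Both processors round inputs down to multiples of vae_scale_factor * 2 = 16,
# so the image (64x48) and the mask (64x48) map to the same 8x6 latent grid.
proc_h = height - height % (vae_scale_factor * 2)  # 64
proc_w = width - width % (vae_scale_factor * 2)    # 48
lat_h, lat_w = proc_h // vae_scale_factor, proc_w // vae_scale_factor
print(proc_h, proc_w, lat_h, lat_w)  # 64 48 8 6
```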
@@ -467,9 +469,9 @@ def check_inputs(
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

-if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
-raise ValueError(
-f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+logger.warning(
+f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly."
)

if callback_on_step_end_tensor_inputs is not None and not all(
@@ -578,8 +580,10 @@ def prepare_latents(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

-height = int(height) // self.vae_scale_factor
-width = int(width) // self.vae_scale_factor
+# VAE applies 8x compression on images but we must also account for packing which requires
+# latent height and width to be divisible by 2.
+height = int(height) // self.vae_scale_factor - ((int(height) // self.vae_scale_factor) % 2)
+width = int(width) // self.vae_scale_factor - ((int(width) // self.vae_scale_factor) % 2)

shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
@@ -624,8 +628,10 @@ def prepare_mask_latents(
device,
generator,
):
-height = int(height) // self.vae_scale_factor
-width = int(width) // self.vae_scale_factor
+# VAE applies 8x compression on images but we must also account for packing which requires
+# latent height and width to be divisible by 2.
+height = int(height) // self.vae_scale_factor - ((int(height) // self.vae_scale_factor) % 2)
+width = int(width) // self.vae_scale_factor - ((int(width) // self.vae_scale_factor) % 2)
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
17 changes: 10 additions & 7 deletions src/diffusers/pipelines/flux/pipeline_flux_img2img.py
@@ -214,7 +214,9 @@ def __init__(
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
-self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+# Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
+# by the patch size. So the vae scale factor is multiplied by the patch size to account for this.
+self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
@@ -407,7 +409,6 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)

image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
-
return image_latents

# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
@@ -437,9 +438,9 @@ def check_inputs(
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

-if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
-raise ValueError(
-f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+logger.warning(
+f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly."
)

if callback_on_step_end_tensor_inputs is not None and not all(
@@ -534,8 +535,10 @@ def prepare_latents(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

-height = int(height) // self.vae_scale_factor
-width = int(width) // self.vae_scale_factor
+# VAE applies 8x compression on images but we must also account for packing which requires
+# latent height and width to be divisible by 2.
+height = int(height) // self.vae_scale_factor - ((int(height) // self.vae_scale_factor) % 2)
+width = int(width) // self.vae_scale_factor - ((int(width) // self.vae_scale_factor) % 2)

shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
24 changes: 15 additions & 9 deletions src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
@@ -211,9 +211,11 @@ def __init__(
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
-self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+# Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
+# by the patch size. So the vae scale factor is multiplied by the patch size to account for this.
+self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.mask_processor = VaeImageProcessor(
-vae_scale_factor=self.vae_scale_factor,
+vae_scale_factor=self.vae_scale_factor * 2,
vae_latent_channels=self.vae.config.latent_channels,
do_normalize=False,
do_binarize=True,
@@ -445,9 +447,9 @@ def check_inputs(
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

-if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
-raise ValueError(
-f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+logger.warning(
+f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly."
)

if callback_on_step_end_tensor_inputs is not None and not all(
@@ -555,8 +557,10 @@ def prepare_latents(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

-height = int(height) // self.vae_scale_factor
-width = int(width) // self.vae_scale_factor
+# VAE applies 8x compression on images but we must also account for packing which requires
+# latent height and width to be divisible by 2.
+height = int(height) // self.vae_scale_factor - ((int(height) // self.vae_scale_factor) % 2)
+width = int(width) // self.vae_scale_factor - ((int(width) // self.vae_scale_factor) % 2)

shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
@@ -600,8 +604,10 @@ def prepare_mask_latents(
device,
generator,
):
-height = int(height) // self.vae_scale_factor
-width = int(width) // self.vae_scale_factor
+# VAE applies 8x compression on images but we must also account for packing which requires
+# latent height and width to be divisible by 2.
+height = int(height) // self.vae_scale_factor - ((int(height) // self.vae_scale_factor) % 2)
+width = int(width) // self.vae_scale_factor - ((int(width) // self.vae_scale_factor) % 2)
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
14 changes: 14 additions & 0 deletions tests/pipelines/controlnet_flux/test_controlnet_flux.py
@@ -181,6 +181,20 @@ def test_controlnet_flux(self):
def test_xformers_attention_forwardGenerator_pass(self):
pass

+def test_flux_image_output_shape(self):
+pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+inputs = self.get_dummy_inputs(torch_device)
+
+height_width_pairs = [(32, 32), (72, 56)]
+for height, width in height_width_pairs:
+expected_height = height - height % (pipe.vae_scale_factor * 2)
+expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+inputs.update({"height": height, "width": width})
+image = pipe(**inputs).images[0]
+output_height, output_width, _ = image.shape
+assert (output_height, output_width) == (expected_height, expected_width)
+

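Spelling out the expected-shape arithmetic for the (72, 56) pair (a sketch assuming a full-size VAE with vae_scale_factor == 8; the dummy test components may use a smaller factor, which the same formula handles):

```python
# Expected output size for one height/width pair; vae_scale_factor is assumed.
vae_scale_factor = 8
height, width = 72, 56
expected_height = height - height % (vae_scale_factor * 2)  # 72 - 72 % 16 = 64
expected_width = width - width % (vae_scale_factor * 2)     # 56 - 56 % 16 = 48
print(expected_height, expected_width)  # 64 48
```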
@slow
@require_big_gpu_with_torch_cuda
14 changes: 14 additions & 0 deletions tests/pipelines/flux/test_pipeline_flux.py
@@ -191,6 +191,20 @@ def test_fused_qkv_projections(self):
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."

+def test_flux_image_output_shape(self):
+pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+inputs = self.get_dummy_inputs(torch_device)
+
+height_width_pairs = [(32, 32), (72, 56)]
+for height, width in height_width_pairs:
+expected_height = height - height % (pipe.vae_scale_factor * 2)
+expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+inputs.update({"height": height, "width": width})
+image = pipe(**inputs).images[0]
+output_height, output_width, _ = image.shape
+assert (output_height, output_width) == (expected_height, expected_width)
+

@slow
@require_big_gpu_with_torch_cuda
14 changes: 14 additions & 0 deletions tests/pipelines/flux/test_pipeline_flux_img2img.py
@@ -147,3 +147,17 @@ def test_flux_prompt_embeds(self):

max_diff = np.abs(output_with_prompt - output_with_embeds).max()
assert max_diff < 1e-4

+def test_flux_image_output_shape(self):
+pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+inputs = self.get_dummy_inputs(torch_device)
+
+height_width_pairs = [(32, 32), (72, 56)]
+for height, width in height_width_pairs:
+expected_height = height - height % (pipe.vae_scale_factor * 2)
+expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+inputs.update({"height": height, "width": width})
+image = pipe(**inputs).images[0]
+output_height, output_width, _ = image.shape
+assert (output_height, output_width) == (expected_height, expected_width)
14 changes: 14 additions & 0 deletions tests/pipelines/flux/test_pipeline_flux_inpaint.py
@@ -149,3 +149,17 @@ def test_flux_inpaint_prompt_embeds(self):

max_diff = np.abs(output_with_prompt - output_with_embeds).max()
assert max_diff < 1e-4

+def test_flux_image_output_shape(self):
+pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+inputs = self.get_dummy_inputs(torch_device)
+
+height_width_pairs = [(32, 32), (72, 56)]
+for height, width in height_width_pairs:
+expected_height = height - height % (pipe.vae_scale_factor * 2)
+expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+inputs.update({"height": height, "width": width})
+image = pipe(**inputs).images[0]
+output_height, output_width, _ = image.shape
+assert (output_height, output_width) == (expected_height, expected_width)