Skip to content

Commit be67dbd

Browse files
committed
Merge branch 'main' into flux-new
2 parents 2829679 + cd6ca9d commit be67dbd

File tree

4 files changed

+22
-8
lines changed

4 files changed

+22
-8
lines changed

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,20 @@ def calculate_shift(
9797
return mu
9898

9999

100+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
101+
def retrieve_latents(
102+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
103+
):
104+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
105+
return encoder_output.latent_dist.sample(generator)
106+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
107+
return encoder_output.latent_dist.mode()
108+
elif hasattr(encoder_output, "latents"):
109+
return encoder_output.latents
110+
else:
111+
raise AttributeError("Could not access latents of provided encoder_output")
112+
113+
100114
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
101115
def retrieve_timesteps(
102116
scheduler,
@@ -512,7 +526,7 @@ def prepare_latents(
512526
shape = (batch_size, num_channels_latents, height, width)
513527

514528
if latents is not None:
515-
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
529+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
516530
return latents.to(device=device, dtype=dtype), latent_image_ids
517531

518532
if isinstance(generator, list) and len(generator) != batch_size:
@@ -773,7 +787,7 @@ def __call__(
773787
controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
774788
if self.controlnet.input_hint_block is None:
775789
# vae encode
776-
control_image = self.vae.encode(control_image).latent_dist.sample()
790+
control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
777791
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
778792

779793
# pack
@@ -811,7 +825,7 @@ def __call__(
811825

812826
if self.controlnet.nets[0].input_hint_block is None:
813827
# vae encode
814-
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
828+
control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
815829
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
816830

817831
# pack

src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -801,7 +801,7 @@ def __call__(
801801
)
802802
height, width = control_image.shape[-2:]
803803

804-
control_image = self.vae.encode(control_image).latent_dist.sample()
804+
control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
805805
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
806806

807807
height_control_image, width_control_image = control_image.shape[2:]
@@ -832,7 +832,7 @@ def __call__(
832832
)
833833
height, width = control_image_.shape[-2:]
834834

835-
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
835+
control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
836836
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
837837

838838
height_control_image, width_control_image = control_image_.shape[2:]

src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -942,7 +942,7 @@ def __call__(
942942
controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
943943
if self.controlnet.input_hint_block is None:
944944
# vae encode
945-
control_image = self.vae.encode(control_image).latent_dist.sample()
945+
control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
946946
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
947947

948948
# pack
@@ -979,7 +979,7 @@ def __call__(
979979

980980
if self.controlnet.nets[0].input_hint_block is None:
981981
# vae encode
982-
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
982+
control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
983983
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
984984

985985
# pack

tests/pipelines/controlnet_flux/test_controlnet_flux.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -170,7 +170,7 @@ def test_controlnet_flux(self):
170170
assert image.shape == (1, 32, 32, 3)
171171

172172
expected_slice = np.array(
173-
[0.7348633, 0.41333008, 0.6621094, 0.5444336, 0.47607422, 0.5859375, 0.44677734, 0.4506836, 0.40454102]
173+
[0.47387695, 0.63134766, 0.5605469, 0.61621094, 0.7207031, 0.7089844, 0.70410156, 0.6113281, 0.64160156]
174174
)
175175

176176
assert (

0 commit comments

Comments
 (0)