Skip to content

Commit be58660

Browse files
committed
remove fluxcontrolmixin
1 parent cf3053b commit be58660

7 files changed

+47
-49
lines changed

src/diffusers/pipelines/flux/pipeline_flux_control.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
)
3131
from ...utils.torch_utils import randn_tensor
3232
from ..pipeline_utils import DiffusionPipeline
33-
from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_timesteps
33+
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_timesteps
3434
from .pipeline_output import FluxPipelineOutput
3535

3636

@@ -81,7 +81,7 @@
8181

8282
class FluxControlPipeline(
8383
DiffusionPipeline,
84-
FluxControlMixin,
84+
FluxMixin,
8585
FluxLoraLoaderMixin,
8686
FromSingleFileMixin,
8787
TextualInversionLoaderMixin,
@@ -235,6 +235,41 @@ def prepare_latents(
235235

236236
return latents, latent_image_ids
237237

238+
# Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
239+
def prepare_image(
240+
self,
241+
image,
242+
width,
243+
height,
244+
batch_size,
245+
num_images_per_prompt,
246+
device,
247+
dtype,
248+
do_classifier_free_guidance=False,
249+
guess_mode=False,
250+
):
251+
if isinstance(image, torch.Tensor):
252+
pass
253+
else:
254+
image = self.image_processor.preprocess(image, height=height, width=width)
255+
256+
image_batch_size = image.shape[0]
257+
258+
if image_batch_size == 1:
259+
repeat_by = batch_size
260+
else:
261+
# image batch size is the same as prompt batch size
262+
repeat_by = num_images_per_prompt
263+
264+
image = image.repeat_interleave(repeat_by, dim=0)
265+
266+
image = image.to(device=device, dtype=dtype)
267+
268+
if do_classifier_free_guidance and not guess_mode:
269+
image = torch.cat([image] * 2)
270+
271+
return image
272+
238273
@torch.no_grad()
239274
@replace_example_docstring(EXAMPLE_DOC_STRING)
240275
def __call__(

src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from ...utils import is_torch_xla_available, logging, replace_example_docstring
2727
from ...utils.torch_utils import randn_tensor
2828
from ..pipeline_utils import DiffusionPipeline
29-
from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_latents, retrieve_timesteps
29+
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
3030
from .pipeline_output import FluxPipelineOutput
3131

3232

@@ -80,7 +80,7 @@
8080
"""
8181

8282

83-
class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin):
83+
class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxMixin, FluxLoraLoaderMixin, FromSingleFileMixin):
8484
r"""
8585
The Flux pipeline for image inpainting.
8686

src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from ...utils import is_torch_xla_available, logging, replace_example_docstring
3636
from ...utils.torch_utils import randn_tensor
3737
from ..pipeline_utils import DiffusionPipeline
38-
from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_latents, retrieve_timesteps
38+
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
3939
from .pipeline_output import FluxPipelineOutput
4040

4141

@@ -108,7 +108,7 @@
108108

109109
class FluxControlInpaintPipeline(
110110
DiffusionPipeline,
111-
FluxControlMixin,
111+
FluxMixin,
112112
FluxLoraLoaderMixin,
113113
FromSingleFileMixin,
114114
TextualInversionLoaderMixin,

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from ...utils import is_torch_xla_available, logging, replace_example_docstring
3535
from ...utils.torch_utils import randn_tensor
3636
from ..pipeline_utils import DiffusionPipeline
37-
from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_latents, retrieve_timesteps
37+
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
3838
from .pipeline_output import FluxPipelineOutput
3939

4040

@@ -80,7 +80,7 @@
8080

8181

8282
class FluxControlNetPipeline(
83-
DiffusionPipeline, FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin
83+
DiffusionPipeline, FluxMixin, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin
8484
):
8585
r"""
8686
The Flux pipeline for text-to-image generation.

src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from ...utils import is_torch_xla_available, logging, replace_example_docstring
1919
from ...utils.torch_utils import randn_tensor
2020
from ..pipeline_utils import DiffusionPipeline
21-
from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_latents, retrieve_timesteps
21+
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
2222
from .pipeline_output import FluxPipelineOutput
2323

2424

@@ -74,7 +74,7 @@
7474
"""
7575

7676

77-
class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin):
77+
class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxMixin, FluxLoraLoaderMixin, FromSingleFileMixin):
7878
r"""
7979
The Flux controlnet pipeline for image-to-image generation.
8080

src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ...utils import is_torch_xla_available, logging, replace_example_docstring
2020
from ...utils.torch_utils import randn_tensor
2121
from ..pipeline_utils import DiffusionPipeline
22-
from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_latents, retrieve_timesteps
22+
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
2323
from .pipeline_output import FluxPipelineOutput
2424

2525

@@ -76,7 +76,7 @@
7676
"""
7777

7878

79-
class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin):
79+
class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxMixin, FluxLoraLoaderMixin, FromSingleFileMixin):
8080
r"""
8181
The Flux controlnet pipeline for inpainting.
8282

src/diffusers/pipelines/flux/pipeline_flux_utils.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -394,40 +394,3 @@ def _get_clip_prompt_embeds(
394394
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
395395

396396
return prompt_embeds
397-
398-
399-
class FluxControlMixin(FluxMixin):
400-
# Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
401-
def prepare_image(
402-
self,
403-
image,
404-
width,
405-
height,
406-
batch_size,
407-
num_images_per_prompt,
408-
device,
409-
dtype,
410-
do_classifier_free_guidance=False,
411-
guess_mode=False,
412-
):
413-
if isinstance(image, torch.Tensor):
414-
pass
415-
else:
416-
image = self.image_processor.preprocess(image, height=height, width=width)
417-
418-
image_batch_size = image.shape[0]
419-
420-
if image_batch_size == 1:
421-
repeat_by = batch_size
422-
else:
423-
# image batch size is the same as prompt batch size
424-
repeat_by = num_images_per_prompt
425-
426-
image = image.repeat_interleave(repeat_by, dim=0)
427-
428-
image = image.to(device=device, dtype=dtype)
429-
430-
if do_classifier_free_guidance and not guess_mode:
431-
image = torch.cat([image] * 2)
432-
433-
return image

0 commit comments

Comments
 (0)