Skip to content

Commit 8cdcdd9

Browse files
authored
add flux inpaint + img2img + controlnet to auto pipeline (#9367)
1 parent d269cc8 commit 8cdcdd9

File tree

2 files changed

+51
-7
lines changed

2 files changed

+51
-7
lines changed

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
StableDiffusionXLControlNetPipeline,
3030
)
3131
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
32-
from .flux import FluxPipeline
32+
from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline
3333
from .hunyuandit import HunyuanDiTPipeline
3434
from .kandinsky import (
3535
KandinskyCombinedPipeline,
@@ -108,6 +108,7 @@
108108
("pixart-sigma-pag", PixArtSigmaPAGPipeline),
109109
("auraflow", AuraFlowPipeline),
110110
("flux", FluxPipeline),
111+
("flux-controlnet", FluxControlNetPipeline),
111112
("lumina", LuminaText2ImgPipeline),
112113
]
113114
)
@@ -126,6 +127,7 @@
126127
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
127128
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
128129
("lcm", LatentConsistencyModelImg2ImgPipeline),
130+
("flux", FluxImg2ImgPipeline),
129131
]
130132
)
131133

@@ -140,6 +142,7 @@
140142
("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
141143
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
142144
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
145+
("flux", FluxInpaintPipeline),
143146
]
144147
)
145148

@@ -660,12 +663,17 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
660663
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
661664
orig_class_name = config["_class_name"]
662665

666+
# the `orig_class_name` can be:
667+
# `- *Pipeline` (for regular text-to-image checkpoint)
668+
# `- *Img2ImgPipeline` (for refiner checkpoint)
669+
to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline"
670+
663671
if "controlnet" in kwargs:
664-
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
672+
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
665673
if "enable_pag" in kwargs:
666674
enable_pag = kwargs.pop("enable_pag")
667675
if enable_pag:
668-
orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")
676+
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
669677

670678
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
671679

@@ -952,14 +960,17 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
952960
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
953961
orig_class_name = config["_class_name"]
954962

963+
# The `orig_class_name`` can be:
964+
# `- *InpaintPipeline` (for inpaint-specific checkpoint)
965+
# - or *Pipeline (for regular text-to-image checkpoint)
966+
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
967+
955968
if "controlnet" in kwargs:
956-
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
969+
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
957970
if "enable_pag" in kwargs:
958971
enable_pag = kwargs.pop("enable_pag")
959972
if enable_pag:
960-
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
961-
orig_class_name = config["_class_name"].replace(to_replace, "PAG" + to_replace)
962-
973+
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
963974
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
964975

965976
kwargs = {**load_config_kwargs, **kwargs}

tests/pipelines/test_pipelines_auto.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,32 @@ def test_from_pretrained_img2img(self):
235235
pipe = AutoPipelineForImage2Image.from_pretrained(repo)
236236
assert pipe.__class__.__name__ == "StableDiffusionXLImg2ImgPipeline"
237237

238+
controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet")
239+
pipe_control = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet)
240+
assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetImg2ImgPipeline"
241+
242+
pipe_pag = AutoPipelineForImage2Image.from_pretrained(repo, enable_pag=True)
243+
assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline"
244+
245+
pipe_control_pag = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet, enable_pag=True)
246+
assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGImg2ImgPipeline"
247+
248+
def test_from_pretrained_img2img_refiner(self):
249+
repo = "hf-internal-testing/tiny-stable-diffusion-xl-refiner-pipe"
250+
251+
pipe = AutoPipelineForImage2Image.from_pretrained(repo)
252+
assert pipe.__class__.__name__ == "StableDiffusionXLImg2ImgPipeline"
253+
254+
controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet")
255+
pipe_control = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet)
256+
assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetImg2ImgPipeline"
257+
238258
pipe_pag = AutoPipelineForImage2Image.from_pretrained(repo, enable_pag=True)
239259
assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline"
240260

261+
pipe_control_pag = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet, enable_pag=True)
262+
assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGImg2ImgPipeline"
263+
241264
def test_from_pipe_pag_img2img(self):
242265
# test from tableDiffusionXLPAGImg2ImgPipeline
243266
pipe = AutoPipelineForImage2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
@@ -265,6 +288,16 @@ def test_from_pretrained_inpaint(self):
265288
pipe_pag = AutoPipelineForInpainting.from_pretrained(repo, enable_pag=True)
266289
assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline"
267290

291+
def test_from_pretrained_inpaint_from_inpaint(self):
292+
repo = "hf-internal-testing/tiny-stable-diffusion-xl-inpaint-pipe"
293+
294+
pipe = AutoPipelineForInpainting.from_pretrained(repo)
295+
assert pipe.__class__.__name__ == "StableDiffusionXLInpaintPipeline"
296+
297+
# make sure you can use pag with inpaint-specific pipeline
298+
pipe = AutoPipelineForInpainting.from_pretrained(repo, enable_pag=True)
299+
assert pipe.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline"
300+
268301
def test_from_pipe_pag_inpaint(self):
269302
# test from tableDiffusionXLPAGInpaintPipeline
270303
pipe = AutoPipelineForInpainting.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")

0 commit comments

Comments
 (0)