|
29 | 29 | StableDiffusionXLControlNetPipeline,
|
30 | 30 | )
|
31 | 31 | from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
|
32 |
| -from .flux import FluxPipeline |
| 32 | +from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline |
33 | 33 | from .hunyuandit import HunyuanDiTPipeline
|
34 | 34 | from .kandinsky import (
|
35 | 35 | KandinskyCombinedPipeline,
|
|
108 | 108 | ("pixart-sigma-pag", PixArtSigmaPAGPipeline),
|
109 | 109 | ("auraflow", AuraFlowPipeline),
|
110 | 110 | ("flux", FluxPipeline),
|
| 111 | + ("flux-controlnet", FluxControlNetPipeline), |
111 | 112 | ("lumina", LuminaText2ImgPipeline),
|
112 | 113 | ]
|
113 | 114 | )
|
|
126 | 127 | ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
|
127 | 128 | ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
|
128 | 129 | ("lcm", LatentConsistencyModelImg2ImgPipeline),
|
| 130 | + ("flux", FluxImg2ImgPipeline), |
129 | 131 | ]
|
130 | 132 | )
|
131 | 133 |
|
|
140 | 142 | ("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
|
141 | 143 | ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
|
142 | 144 | ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
|
| 145 | + ("flux", FluxInpaintPipeline), |
143 | 146 | ]
|
144 | 147 | )
|
145 | 148 |
|
@@ -660,12 +663,17 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
|
660 | 663 | config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
|
661 | 664 | orig_class_name = config["_class_name"]
|
662 | 665 |
|
| 666 | + # the `orig_class_name` can be: |
| 667 | + # `- *Pipeline` (for regular text-to-image checkpoint) |
| 668 | + # `- *Img2ImgPipeline` (for refiner checkpoint) |
| 669 | + to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline" |
| 670 | + |
663 | 671 | if "controlnet" in kwargs:
|
664 |
| - orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") |
| 672 | + orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) |
665 | 673 | if "enable_pag" in kwargs:
|
666 | 674 | enable_pag = kwargs.pop("enable_pag")
|
667 | 675 | if enable_pag:
|
668 |
| - orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline") |
| 676 | + orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace) |
669 | 677 |
|
670 | 678 | image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
|
671 | 679 |
|
@@ -952,14 +960,17 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
|
952 | 960 | config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
|
953 | 961 | orig_class_name = config["_class_name"]
|
954 | 962 |
|
| 963 | + # The `orig_class_name`` can be: |
| 964 | + # `- *InpaintPipeline` (for inpaint-specific checkpoint) |
| 965 | + # - or *Pipeline (for regular text-to-image checkpoint) |
| 966 | + to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" |
| 967 | + |
955 | 968 | if "controlnet" in kwargs:
|
956 |
| - orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") |
| 969 | + orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) |
957 | 970 | if "enable_pag" in kwargs:
|
958 | 971 | enable_pag = kwargs.pop("enable_pag")
|
959 | 972 | if enable_pag:
|
960 |
| - to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" |
961 |
| - orig_class_name = config["_class_name"].replace(to_replace, "PAG" + to_replace) |
962 |
| - |
| 973 | + orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace) |
963 | 974 | inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
|
964 | 975 |
|
965 | 976 | kwargs = {**load_config_kwargs, **kwargs}
|
|
0 commit comments