|
29 | 29 | StableDiffusionXLControlNetPipeline, |
30 | 30 | ) |
31 | 31 | from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline |
32 | | -from .flux import FluxPipeline |
| 32 | +from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline |
33 | 33 | from .hunyuandit import HunyuanDiTPipeline |
34 | 34 | from .kandinsky import ( |
35 | 35 | KandinskyCombinedPipeline, |
|
108 | 108 | ("pixart-sigma-pag", PixArtSigmaPAGPipeline), |
109 | 109 | ("auraflow", AuraFlowPipeline), |
110 | 110 | ("flux", FluxPipeline), |
| 111 | + ("flux-controlnet", FluxControlNetPipeline), |
111 | 112 | ("lumina", LuminaText2ImgPipeline), |
112 | 113 | ] |
113 | 114 | ) |
|
126 | 127 | ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline), |
127 | 128 | ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline), |
128 | 129 | ("lcm", LatentConsistencyModelImg2ImgPipeline), |
| 130 | + ("flux", FluxImg2ImgPipeline), |
129 | 131 | ] |
130 | 132 | ) |
131 | 133 |
|
|
140 | 142 | ("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline), |
141 | 143 | ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline), |
142 | 144 | ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline), |
| 145 | + ("flux", FluxInpaintPipeline), |
143 | 146 | ] |
144 | 147 | ) |
145 | 148 |
|
@@ -660,12 +663,17 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): |
660 | 663 | config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) |
661 | 664 | orig_class_name = config["_class_name"] |
662 | 665 |
|
| 666 | + # the `orig_class_name` can be: |
| 667 | + # `- *Pipeline` (for regular text-to-image checkpoint) |
| 668 | + # `- *Img2ImgPipeline` (for refiner checkpoint) |
| 669 | + to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline" |
| 670 | + |
663 | 671 | if "controlnet" in kwargs: |
664 | | - orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") |
| 672 | + orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) |
665 | 673 | if "enable_pag" in kwargs: |
666 | 674 | enable_pag = kwargs.pop("enable_pag") |
667 | 675 | if enable_pag: |
668 | | - orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline") |
| 676 | + orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace) |
669 | 677 |
|
670 | 678 | image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name) |
671 | 679 |
|
@@ -952,14 +960,17 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): |
952 | 960 | config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) |
953 | 961 | orig_class_name = config["_class_name"] |
954 | 962 |
|
| 963 | + # The `orig_class_name`` can be: |
| 964 | + # `- *InpaintPipeline` (for inpaint-specific checkpoint) |
| 965 | + # - or *Pipeline (for regular text-to-image checkpoint) |
| 966 | + to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" |
| 967 | + |
955 | 968 | if "controlnet" in kwargs: |
956 | | - orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") |
| 969 | + orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) |
957 | 970 | if "enable_pag" in kwargs: |
958 | 971 | enable_pag = kwargs.pop("enable_pag") |
959 | 972 | if enable_pag: |
960 | | - to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" |
961 | | - orig_class_name = config["_class_name"].replace(to_replace, "PAG" + to_replace) |
962 | | - |
| 973 | + orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace) |
963 | 974 | inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name) |
964 | 975 |
|
965 | 976 | kwargs = {**load_config_kwargs, **kwargs} |
|
0 commit comments