3535)
3636from .deepfloyd_if import IFImg2ImgPipeline , IFInpaintingPipeline , IFPipeline
3737from .flux import (
38+ FluxControlImg2ImgPipeline ,
39+ FluxControlInpaintPipeline ,
3840 FluxControlNetImg2ImgPipeline ,
3941 FluxControlNetInpaintPipeline ,
4042 FluxControlNetPipeline ,
43+ FluxControlPipeline ,
4144 FluxImg2ImgPipeline ,
4245 FluxInpaintPipeline ,
4346 FluxPipeline ,
125128 ("pixart-sigma-pag" , PixArtSigmaPAGPipeline ),
126129 ("auraflow" , AuraFlowPipeline ),
127130 ("flux" , FluxPipeline ),
131+ ("flux-control" , FluxControlPipeline ),
128132 ("flux-controlnet" , FluxControlNetPipeline ),
129133 ("lumina" , LuminaText2ImgPipeline ),
130134 ("cogview3" , CogView3PlusPipeline ),
150154 ("lcm" , LatentConsistencyModelImg2ImgPipeline ),
151155 ("flux" , FluxImg2ImgPipeline ),
152156 ("flux-controlnet" , FluxControlNetImg2ImgPipeline ),
157+ ("flux-control" , FluxControlImg2ImgPipeline ),
153158 ]
154159)
155160
168173 ("stable-diffusion-xl-pag" , StableDiffusionXLPAGInpaintPipeline ),
169174 ("flux" , FluxInpaintPipeline ),
170175 ("flux-controlnet" , FluxControlNetInpaintPipeline ),
176+ ("flux-control" , FluxControlInpaintPipeline ),
171177 ("stable-diffusion-pag" , StableDiffusionPAGInpaintPipeline ),
172178 ]
173179)
@@ -401,16 +407,20 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
401407
402408 config = cls .load_config (pretrained_model_or_path , ** load_config_kwargs )
403409 orig_class_name = config ["_class_name" ]
410+ if "ControlPipeline" in orig_class_name :
411+ to_replace = "ControlPipeline"
412+ else :
413+ to_replace = "Pipeline"
404414
405415 if "controlnet" in kwargs :
406416 if isinstance (kwargs ["controlnet" ], ControlNetUnionModel ):
407- orig_class_name = config ["_class_name" ].replace ("Pipeline" , "ControlNetUnionPipeline" )
417+ orig_class_name = config ["_class_name" ].replace (to_replace , "ControlNetUnionPipeline" )
408418 else :
409- orig_class_name = config ["_class_name" ].replace ("Pipeline" , "ControlNetPipeline" )
419+ orig_class_name = config ["_class_name" ].replace (to_replace , "ControlNetPipeline" )
410420 if "enable_pag" in kwargs :
411421 enable_pag = kwargs .pop ("enable_pag" )
412422 if enable_pag :
413- orig_class_name = orig_class_name .replace ("Pipeline" , "PAGPipeline" )
423+ orig_class_name = orig_class_name .replace (to_replace , "PAGPipeline" )
414424
415425 text_2_image_cls = _get_task_class (AUTO_TEXT2IMAGE_PIPELINES_MAPPING , orig_class_name )
416426
@@ -694,8 +704,14 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
694704
695705 # the `orig_class_name` can be:
696706 # `- *Pipeline` (for regular text-to-image checkpoint)
707+ # - `*ControlPipeline` (for Flux tools specific checkpoint)
697708 # `- *Img2ImgPipeline` (for refiner checkpoint)
698- to_replace = "Img2ImgPipeline" if "Img2Img" in config ["_class_name" ] else "Pipeline"
709+ if "Img2Img" in orig_class_name :
710+ to_replace = "Img2ImgPipeline"
711+ elif "ControlPipeline" in orig_class_name :
712+ to_replace = "ControlPipeline"
713+ else :
714+ to_replace = "Pipeline"
699715
700716 if "controlnet" in kwargs :
701717 if isinstance (kwargs ["controlnet" ], ControlNetUnionModel ):
@@ -707,6 +723,9 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
707723 if enable_pag :
708724 orig_class_name = orig_class_name .replace (to_replace , "PAG" + to_replace )
709725
726+ if to_replace == "ControlPipeline" :
727+ orig_class_name = orig_class_name .replace (to_replace , "ControlImg2ImgPipeline" )
728+
710729 image_2_image_cls = _get_task_class (AUTO_IMAGE2IMAGE_PIPELINES_MAPPING , orig_class_name )
711730
712731 kwargs = {** load_config_kwargs , ** kwargs }
@@ -994,8 +1013,14 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
9941013
9951014 # The `orig_class_name`` can be:
9961015 # `- *InpaintPipeline` (for inpaint-specific checkpoint)
1016+ # - `*ControlPipeline` (for Flux tools specific checkpoint)
9971017 # - or *Pipeline (for regular text-to-image checkpoint)
998- to_replace = "InpaintPipeline" if "Inpaint" in config ["_class_name" ] else "Pipeline"
1018+ if "Inpaint" in orig_class_name :
1019+ to_replace = "InpaintPipeline"
1020+ elif "ControlPipeline" in orig_class_name :
1021+ to_replace = "ControlPipeline"
1022+ else :
1023+ to_replace = "Pipeline"
9991024
10001025 if "controlnet" in kwargs :
10011026 if isinstance (kwargs ["controlnet" ], ControlNetUnionModel ):
@@ -1006,6 +1031,8 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
10061031 enable_pag = kwargs .pop ("enable_pag" )
10071032 if enable_pag :
10081033 orig_class_name = orig_class_name .replace (to_replace , "PAG" + to_replace )
1034+ if to_replace == "ControlPipeline" :
1035+ orig_class_name = orig_class_name .replace (to_replace , "ControlInpaintPipeline" )
10091036 inpainting_cls = _get_task_class (AUTO_INPAINT_PIPELINES_MAPPING , orig_class_name )
10101037
10111038 kwargs = {** load_config_kwargs , ** kwargs }
0 commit comments