Skip to content

Commit e5ddf02

Browse files
authored
Merge branch 'main' into hunyuanvideo_lora
2 parents c69674c + 0ed09a1 commit e5ddf02

File tree

4 files changed

+69
-6
lines changed

4 files changed

+69
-6
lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,11 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
643643
old_state_dict,
644644
new_state_dict,
645645
old_key,
646-
[f"transformer.single_transformer_blocks.{block_num}.norm.linear"],
646+
[
647+
f"transformer.single_transformer_blocks.{block_num}.attn.to_q",
648+
f"transformer.single_transformer_blocks.{block_num}.attn.to_k",
649+
f"transformer.single_transformer_blocks.{block_num}.attn.to_v",
650+
],
647651
)
648652

649653
if "down" in old_key:

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,12 @@
3535
)
3636
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
3737
from .flux import (
38+
FluxControlImg2ImgPipeline,
39+
FluxControlInpaintPipeline,
3840
FluxControlNetImg2ImgPipeline,
3941
FluxControlNetInpaintPipeline,
4042
FluxControlNetPipeline,
43+
FluxControlPipeline,
4144
FluxImg2ImgPipeline,
4245
FluxInpaintPipeline,
4346
FluxPipeline,
@@ -125,6 +128,7 @@
125128
("pixart-sigma-pag", PixArtSigmaPAGPipeline),
126129
("auraflow", AuraFlowPipeline),
127130
("flux", FluxPipeline),
131+
("flux-control", FluxControlPipeline),
128132
("flux-controlnet", FluxControlNetPipeline),
129133
("lumina", LuminaText2ImgPipeline),
130134
("cogview3", CogView3PlusPipeline),
@@ -150,6 +154,7 @@
150154
("lcm", LatentConsistencyModelImg2ImgPipeline),
151155
("flux", FluxImg2ImgPipeline),
152156
("flux-controlnet", FluxControlNetImg2ImgPipeline),
157+
("flux-control", FluxControlImg2ImgPipeline),
153158
]
154159
)
155160

@@ -168,6 +173,7 @@
168173
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
169174
("flux", FluxInpaintPipeline),
170175
("flux-controlnet", FluxControlNetInpaintPipeline),
176+
("flux-control", FluxControlInpaintPipeline),
171177
("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
172178
]
173179
)
@@ -401,16 +407,20 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
401407

402408
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
403409
orig_class_name = config["_class_name"]
410+
if "ControlPipeline" in orig_class_name:
411+
to_replace = "ControlPipeline"
412+
else:
413+
to_replace = "Pipeline"
404414

405415
if "controlnet" in kwargs:
406416
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
407-
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetUnionPipeline")
417+
orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
408418
else:
409-
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
419+
orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
410420
if "enable_pag" in kwargs:
411421
enable_pag = kwargs.pop("enable_pag")
412422
if enable_pag:
413-
orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")
423+
orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
414424

415425
text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
416426

@@ -694,8 +704,14 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
694704

695705
# the `orig_class_name` can be:
696706
# `- *Pipeline` (for regular text-to-image checkpoint)
707+
# - `*ControlPipeline` (for Flux tools specific checkpoint)
697708
# `- *Img2ImgPipeline` (for refiner checkpoint)
698-
to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline"
709+
if "Img2Img" in orig_class_name:
710+
to_replace = "Img2ImgPipeline"
711+
elif "ControlPipeline" in orig_class_name:
712+
to_replace = "ControlPipeline"
713+
else:
714+
to_replace = "Pipeline"
699715

700716
if "controlnet" in kwargs:
701717
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
@@ -707,6 +723,9 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
707723
if enable_pag:
708724
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
709725

726+
if to_replace == "ControlPipeline":
727+
orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline")
728+
710729
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
711730

712731
kwargs = {**load_config_kwargs, **kwargs}
@@ -994,8 +1013,14 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
9941013

9951014
# The `orig_class_name`` can be:
9961015
# `- *InpaintPipeline` (for inpaint-specific checkpoint)
1016+
# - `*ControlPipeline` (for Flux tools specific checkpoint)
9971017
# - or *Pipeline (for regular text-to-image checkpoint)
998-
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
1018+
if "Inpaint" in orig_class_name:
1019+
to_replace = "InpaintPipeline"
1020+
elif "ControlPipeline" in orig_class_name:
1021+
to_replace = "ControlPipeline"
1022+
else:
1023+
to_replace = "Pipeline"
9991024

10001025
if "controlnet" in kwargs:
10011026
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
@@ -1006,6 +1031,8 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
10061031
enable_pag = kwargs.pop("enable_pag")
10071032
if enable_pag:
10081033
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
1034+
if to_replace == "ControlPipeline":
1035+
orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
10091036
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
10101037

10111038
kwargs = {**load_config_kwargs, **kwargs}

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16+
import enum
1617
import fnmatch
1718
import importlib
1819
import inspect
@@ -811,6 +812,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
811812
# in this case they are already instantiated in `kwargs`
812813
# extract them here
813814
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
815+
expected_types = pipeline_class._get_signature_types()
814816
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
815817
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
816818
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
@@ -833,6 +835,26 @@ def load_module(name, value):
833835

834836
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
835837

838+
for key in init_dict.keys():
839+
if key not in passed_class_obj:
840+
continue
841+
if "scheduler" in key:
842+
continue
843+
844+
class_obj = passed_class_obj[key]
845+
_expected_class_types = []
846+
for expected_type in expected_types[key]:
847+
if isinstance(expected_type, enum.EnumMeta):
848+
_expected_class_types.extend(expected_type.__members__.keys())
849+
else:
850+
_expected_class_types.append(expected_type.__name__)
851+
852+
_is_valid_type = class_obj.__class__.__name__ in _expected_class_types
853+
if not _is_valid_type:
854+
logger.warning(
855+
f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
856+
)
857+
836858
# Special case: safety_checker must be loaded separately when using `from_flax`
837859
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
838860
raise NotImplementedError(

tests/pipelines/test_pipelines.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,6 +1802,16 @@ def test_pipe_same_device_id_offload(self):
18021802
sd.maybe_free_model_hooks()
18031803
assert sd._offload_gpu_id == 5
18041804

1805+
def test_wrong_model(self):
1806+
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
1807+
with self.assertRaises(ValueError) as error_context:
1808+
_ = StableDiffusionPipeline.from_pretrained(
1809+
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer
1810+
)
1811+
1812+
assert "is of type" in str(error_context.exception)
1813+
assert "but should be" in str(error_context.exception)
1814+
18051815

18061816
@slow
18071817
@require_torch_gpu

0 commit comments

Comments
 (0)