
Commit dab9132

Merge branch 'main' into hidream-followup
2 parents: 12afa54 + b4be422

File tree: 9 files changed (+6546, -28 lines)

examples/community/pipeline_controlnet_xl_kolors.py

Lines changed: 1355 additions & 0 deletions
Large diffs are not rendered by default.

examples/community/pipeline_controlnet_xl_kolors_img2img.py

Lines changed: 1557 additions & 0 deletions
Large diffs are not rendered by default.

examples/community/pipeline_controlnet_xl_kolors_inpaint.py

Lines changed: 1871 additions & 0 deletions
Large diffs are not rendered by default.

examples/community/pipeline_kolors_inpainting.py

Lines changed: 1728 additions & 0 deletions
Large diffs are not rendered by default.

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 2 additions & 2 deletions
@@ -433,7 +433,7 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
         ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
         if not is_sparse:
             # down_weight is copied to each split
-            ait_sd.update({k: down_weight for k in ait_down_keys})
+            ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))

             # up_weight is split to each split
             ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
@@ -923,7 +923,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
         ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]

         # down_weight is copied to each split
-        ait_sd.update({k: down_weight for k in ait_down_keys})
+        ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))

         # up_weight is split to each split
         ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
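Both hunks above are mechanical: a dict comprehension that assigns the same `down_weight` to every key is replaced by `dict.fromkeys`, which expresses exactly that. A minimal sketch of the equivalence, using made-up key names and a random tensor in place of the real LoRA weights:

```python
import torch

# Hypothetical stand-ins for ait_down_keys / down_weight from the conversion code.
ait_down_keys = ["to_q.lora_A.weight", "to_k.lora_A.weight", "to_v.lora_A.weight"]
down_weight = torch.randn(4, 16)

old_style = {k: down_weight for k in ait_down_keys}    # before the change
new_style = dict.fromkeys(ait_down_keys, down_weight)  # after the change

assert old_style.keys() == new_style.keys()
# Both forms map every key to the *same* tensor object; dict.fromkeys does not copy the value.
assert all(old_style[k] is new_style[k] for k in ait_down_keys)
```

The only caveat with `dict.fromkeys` is that the value is shared across keys, which matters for mutable values; that is the same sharing the comprehension already produced here, so behavior is unchanged.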

src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py

Lines changed: 29 additions & 22 deletions
@@ -800,17 +800,20 @@ def __call__(
             )
             height, width = control_image.shape[-2:]

-            control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
-            control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
-
-            height_control_image, width_control_image = control_image.shape[2:]
-            control_image = self._pack_latents(
-                control_image,
-                batch_size * num_images_per_prompt,
-                num_channels_latents,
-                height_control_image,
-                width_control_image,
-            )
+            # xlab controlnet has a input_hint_block and instantx controlnet does not
+            controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
+            if self.controlnet.input_hint_block is None:
+                control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
+                control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+                height_control_image, width_control_image = control_image.shape[2:]
+                control_image = self._pack_latents(
+                    control_image,
+                    batch_size * num_images_per_prompt,
+                    num_channels_latents,
+                    height_control_image,
+                    width_control_image,
+                )

             if control_mode is not None:
                 control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
@@ -819,7 +822,9 @@ def __call__(
         elif isinstance(self.controlnet, FluxMultiControlNetModel):
             control_images = []

-            for control_image_ in control_image:
+            # xlab controlnet has a input_hint_block and instantx controlnet does not
+            controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
+            for i, control_image_ in enumerate(control_image):
                 control_image_ = self.prepare_image(
                     image=control_image_,
                     width=width,
@@ -831,17 +836,18 @@ def __call__(
                 )
                 height, width = control_image_.shape[-2:]

-                control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
-                control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+                if self.controlnet.nets[0].input_hint_block is None:
+                    control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
+                    control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor

-                height_control_image, width_control_image = control_image_.shape[2:]
-                control_image_ = self._pack_latents(
-                    control_image_,
-                    batch_size * num_images_per_prompt,
-                    num_channels_latents,
-                    height_control_image,
-                    width_control_image,
-                )
+                    height_control_image, width_control_image = control_image_.shape[2:]
+                    control_image_ = self._pack_latents(
+                        control_image_,
+                        batch_size * num_images_per_prompt,
+                        num_channels_latents,
+                        height_control_image,
+                        width_control_image,
+                    )

                 control_images.append(control_image_)

@@ -955,6 +961,7 @@ def __call__(
                 img_ids=latent_image_ids,
                 joint_attention_kwargs=self.joint_attention_kwargs,
                 return_dict=False,
+                controlnet_blocks_repeat=controlnet_blocks_repeat,
             )[0]

             latents_dtype = latents.dtype
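As the added comment notes, XLabs-style Flux ControlNets ship their own `input_hint_block` and take the raw control image, while InstantX-style ones have no hint block and expect the pipeline to VAE-encode and pack the control image into latents first; the new `controlnet_blocks_repeat` flag records which case applies and is forwarded to the transformer call. Below is a rough, self-contained sketch of that dispatch; the class and function names are invented for illustration, only `input_hint_block` and `controlnet_blocks_repeat` come from the diff, and the VAE encode / `_pack_latents` step is replaced by a string tag so the sketch runs on its own:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeFluxControlNet:
    # Stand-in for FluxControlNetModel: XLabs-style checkpoints define an
    # input_hint_block module, InstantX-style checkpoints leave it as None.
    input_hint_block: Optional[object] = None


def prepare_control_input(controlnet: FakeFluxControlNet, control_image: str):
    """Mirror the branching added by the diff, minus the real VAE/packing calls."""
    controlnet_blocks_repeat = controlnet.input_hint_block is not None
    if controlnet.input_hint_block is None:
        # InstantX-style: the pipeline must turn the hint into packed latents itself.
        control_image = f"packed_latents({control_image})"
    # XLabs-style: the raw hint passes through; the controlnet's own hint block consumes it.
    return control_image, controlnet_blocks_repeat


print(prepare_control_input(FakeFluxControlNet(), "hint_image"))
# -> ('packed_latents(hint_image)', False)   # InstantX path
print(prepare_control_input(FakeFluxControlNet(input_hint_block=object()), "hint_image"))
# -> ('hint_image', True)                    # XLabs path
```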

src/diffusers/pipelines/pipeline_flax_utils.py

Lines changed: 1 addition & 1 deletion
@@ -469,7 +469,7 @@ def load_module(name, value):
                class_obj = import_flax_or_no_model(pipeline_module, class_name)

                importable_classes = ALL_IMPORTABLE_CLASSES
-               class_candidates = {c: class_obj for c in importable_classes.keys()}
+               class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
            else:
                # else we just import it from the library.
                library = importlib.import_module(library_name)

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 2 additions & 2 deletions
@@ -341,13 +341,13 @@ def get_class_obj_and_candidates(
         pipeline_module = getattr(pipelines, library_name)

         class_obj = getattr(pipeline_module, class_name)
-        class_candidates = {c: class_obj for c in importable_classes.keys()}
+        class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
     elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
         # load custom component
         class_obj = get_class_from_dynamic_module(
             component_folder, module_file=library_name + ".py", class_name=class_name
         )
-        class_candidates = {c: class_obj for c in importable_classes.keys()}
+        class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
     else:
         # else we just import it from the library.
         library = importlib.import_module(library_name)
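The same comprehension-to-`dict.fromkeys` swap appears in both pipeline_flax_utils.py and pipeline_loading_utils.py: every importable class name is mapped to the single class object that was just resolved. A small illustration with invented stand-ins for `importable_classes` and `class_obj` (the real code uses ALL_IMPORTABLE_CLASSES and the looked-up class):

```python
# Invented stand-ins for this example only.
importable_classes = {"ModelMixin": ["save_pretrained"], "SchedulerMixin": ["save_pretrained"]}


class DummyPipelineClass:
    pass


class_obj = DummyPipelineClass

before = {c: class_obj for c in importable_classes.keys()}   # old comprehension
after = dict.fromkeys(importable_classes.keys(), class_obj)  # new form

assert before == after  # identical mapping: every candidate name points to class_obj
```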

tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py

Lines changed: 1 addition & 1 deletion
@@ -205,7 +205,7 @@ def test_save_load_optional_components(self):
         # set all optional components to None and update pipeline config accordingly
         for optional_component in pipe._optional_components:
             setattr(pipe, optional_component, None)
-        pipe.register_modules(**{optional_component: None for optional_component in pipe._optional_components})
+        pipe.register_modules(**dict.fromkeys(pipe._optional_components))

         inputs = self.get_dummy_inputs(torch_device)
         output = pipe(**inputs)[0]
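Here `dict.fromkeys` is used with no second argument, so every optional component name maps to `None`; unpacking it with `**` is equivalent to calling `register_modules(name_a=None, name_b=None, ...)`. A quick illustration with made-up component names and a stub in place of the real method:

```python
# Made-up optional component names; a real pipeline exposes its own _optional_components.
optional_components = ["safety_checker", "feature_extractor"]


def register_modules(**kwargs):
    # Stub standing in for DiffusionPipeline.register_modules; it just echoes its kwargs.
    return kwargs


print(register_modules(**dict.fromkeys(optional_components)))
# -> {'safety_checker': None, 'feature_extractor': None}
```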
