Skip to content

Commit cc94647

Browse files
authored
Merge branch 'main' into pyramid-attention-broadcast
2 parents e4d8b12 + 980736b commit cc94647

File tree

5 files changed

+15
-3
lines changed

5 files changed

+15
-3
lines changed

examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,7 @@ def load_model_hook(models, input_dir):
765765
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
766766

767767
transformer_state_dict = {
768-
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
768+
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
769769
}
770770
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
771771
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")

src/diffusers/loaders/single_file_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@
186186
"inpainting": 512,
187187
"inpainting_v2": 512,
188188
"controlnet": 512,
189+
"instruct-pix2pix": 512,
189190
"v2": 768,
190191
"v1": 512,
191192
}

src/diffusers/pipelines/latte/pipeline_latte.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from ...utils import (
3232
BACKENDS_MAPPING,
3333
BaseOutput,
34+
deprecate,
3435
is_bs4_available,
3536
is_ftfy_available,
3637
is_torch_xla_available,
@@ -853,6 +854,13 @@ def __call__(
853854

854855
self._current_timestep = None
855856

857+
if output_type == "latents":
858+
deprecation_message = (
859+
"Passing `output_type='latents'` is deprecated. Please pass `output_type='latent'` instead."
860+
)
861+
deprecate("output_type_latents", "1.0.0", deprecation_message, standard_warn=False)
862+
output_type = "latent"
863+
856864
if not output_type == "latent":
857865
video = self.decode_latents(latents, video_length, decode_chunk_size=decode_chunk_size)
858866
video = self.video_processor.postprocess_video(video=video, output_type=output_type)

tests/single_file/single_file_testing_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def download_diffusers_config(repo_id, tmpdir):
4747

4848

4949
class SDSingleFileTesterMixin:
50+
single_file_kwargs = {}
51+
5052
def _compare_component_configs(self, pipe, single_file_pipe):
5153
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
5254
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
@@ -154,7 +156,7 @@ def test_single_file_components_with_original_config_local_files_only(
154156
self._compare_component_configs(pipe, single_file_pipe)
155157

156158
def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4):
157-
sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, safety_checker=None)
159+
sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, safety_checker=None, **self.single_file_kwargs)
158160
sf_pipe.unet.set_attn_processor(AttnProcessor())
159161
sf_pipe.enable_model_cpu_offload(device=torch_device)
160162

@@ -170,7 +172,7 @@ def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_d
170172

171173
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
172174

173-
assert max_diff < expected_max_diff
175+
assert max_diff < expected_max_diff, f"{image.flatten()} != {image_single_file.flatten()}"
174176

175177
def test_single_file_components_with_diffusers_config(
176178
self,

tests/single_file/test_stable_diffusion_single_file.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ class StableDiffusionInstructPix2PixPipelineSingleFileSlowTests(unittest.TestCas
132132
"https://raw.githubusercontent.com/timothybrooks/instruct-pix2pix/refs/heads/main/configs/generate.yaml"
133133
)
134134
repo_id = "timbrooks/instruct-pix2pix"
135+
single_file_kwargs = {"extract_ema": True}
135136

136137
def setUp(self):
137138
super().setUp()

0 commit comments

Comments
 (0)