
Commit 08a61f6

Merge branch 'main' into kandinsky2_2-xpu
2 parents 2b60562 + 01abfc8 commit 08a61f6

14 files changed: 277 additions & 70 deletions

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 8 additions & 0 deletions
@@ -1704,3 +1704,11 @@ def get_alpha_scales(down_weight, key):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

     return converted_state_dict
+
+
+def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
+    if not all(k.startswith(non_diffusers_prefix) for k in state_dict):
+        raise ValueError("Invalid LoRA state dict for HiDream.")
+    converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
+    converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
+    return converted_state_dict
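
The new helper strips the non-diffusers prefix and re-roots every key under transformer.. Applied to a hypothetical two-key checkpoint (key names are illustrative, not taken from a real HiDream LoRA):

import torch

# hypothetical non-diffusers HiDream LoRA keys (illustrative only)
state_dict = {
    "diffusion_model.double_stream_blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 16),
    "diffusion_model.double_stream_blocks.0.attn.to_q.lora_B.weight": torch.zeros(16, 4),
}

converted = {k.removeprefix("diffusion_model."): v for k, v in state_dict.items()}
converted = {f"transformer.{k}": v for k, v in converted.items()}
print(list(converted)[0])  # transformer.double_stream_blocks.0.attn.to_q.lora_A.weight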

src/diffusers/loaders/lora_pipeline.py

Lines changed: 5 additions & 1 deletion
@@ -43,6 +43,7 @@
     _convert_hunyuan_video_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
     _convert_musubi_wan_lora_to_diffusers,
+    _convert_non_diffusers_hidream_lora_to_diffusers,
     _convert_non_diffusers_lora_to_diffusers,
     _convert_non_diffusers_lumina2_lora_to_diffusers,
     _convert_non_diffusers_wan_lora_to_diffusers,

@@ -5371,7 +5372,6 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):

     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],

@@ -5465,6 +5465,10 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

+        is_non_diffusers_format = any("diffusion_model" in k for k in state_dict)
+        if is_non_diffusers_format:
+            state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
+
         return state_dict

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
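
The # Copied from marker is dropped because the method now diverges from the CogVideoX original. With this hook in place, lora_state_dict transparently converts checkpoints whose keys carry the diffusion_model prefix. A hedged sketch of calling the classmethod with an in-memory dict, which its signature permits (the single key is illustrative):

import torch

from diffusers.loaders.lora_pipeline import HiDreamImageLoraLoaderMixin

# illustrative one-key checkpoint; real LoRAs carry many more keys
raw = {"diffusion_model.blocks.0.attn.lora_A.weight": torch.zeros(4, 16)}
state_dict = HiDreamImageLoraLoaderMixin.lora_state_dict(raw)
print(list(state_dict))  # ['transformer.blocks.0.attn.lora_A.weight']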

src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py

Lines changed: 11 additions & 8 deletions
@@ -152,9 +152,19 @@ def __init__(

         # 1. Latent and condition embedders
         self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
+
+        # Framepack history projection embedder
+        self.clean_x_embedder = None
+        if has_clean_x_embedder:
+            self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)
+
         self.context_embedder = HunyuanVideoTokenRefiner(
             text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
         )
+
+        # Framepack image-conditioning embedder
+        self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None
+
         self.time_text_embed = HunyuanVideoConditionEmbedding(
             inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
         )

@@ -186,14 +196,7 @@ def __init__(
         self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)

-        # Framepack specific modules
-        self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None
-
-        self.clean_x_embedder = None
-        if has_clean_x_embedder:
-            self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)
-
-        self.use_gradient_checkpointing = False
+        self.gradient_checkpointing = False

     def forward(
         self,
src/diffusers/pipelines/ltx/pipeline_ltx.py

Lines changed: 1 addition & 0 deletions
@@ -789,6 +789,7 @@ def __call__(
                 ]
                 latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise

+            latents = latents.to(self.vae.dtype)
            video = self.vae.decode(latents, timestep, return_dict=False)[0]
            video = self.video_processor.postprocess_video(video, output_type=output_type)

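The added cast guards against a dtype mismatch when the denoising loop runs in a different precision than the VAE (for example, an fp16 transformer with an fp32 VAE). A standalone illustration of the failure mode, using a stand-in module rather than the LTX VAE:

import torch
import torch.nn as nn

vae_like = nn.Conv2d(4, 3, kernel_size=1)  # fp32 weights standing in for the VAE
latents = torch.randn(1, 4, 8, 8, dtype=torch.float16)

# vae_like(latents) would raise an "Input type ... and weight type ... should be the same" error;
# casting first, as the pipeline now does, avoids it
frames = vae_like(latents.to(vae_like.weight.dtype))
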
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import HunyuanVideoFramepackTransformer3DModel
+from diffusers.utils.testing_utils import (
+    enable_full_determinism,
+    torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+    model_class = HunyuanVideoFramepackTransformer3DModel
+    main_input_name = "hidden_states"
+    uses_custom_attn_processor = True
+    model_split_percents = [0.5, 0.7, 0.9]
+
+    @property
+    def dummy_input(self):
+        batch_size = 1
+        num_channels = 4
+        num_frames = 3
+        height = 4
+        width = 4
+        text_encoder_embedding_dim = 16
+        image_encoder_embedding_dim = 16
+        pooled_projection_dim = 8
+        sequence_length = 12
+
+        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
+        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
+        image_embeds = torch.randn((batch_size, sequence_length, image_encoder_embedding_dim)).to(torch_device)
+        indices_latents = torch.ones((3,)).to(torch_device)
+        latents_clean = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device)
+        indices_latents_clean = torch.ones((num_frames - 1,)).to(torch_device)
+        latents_history_2x = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device)
+        indices_latents_history_2x = torch.ones((num_frames - 1,)).to(torch_device)
+        latents_history_4x = torch.randn((batch_size, num_channels, (num_frames - 1) * 4, height, width)).to(
+            torch_device
+        )
+        indices_latents_history_4x = torch.ones(((num_frames - 1) * 4,)).to(torch_device)
+        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+        guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+        return {
+            "hidden_states": hidden_states,
+            "timestep": timestep,
+            "encoder_hidden_states": encoder_hidden_states,
+            "pooled_projections": pooled_projections,
+            "encoder_attention_mask": encoder_attention_mask,
+            "guidance": guidance,
+            "image_embeds": image_embeds,
+            "indices_latents": indices_latents,
+            "latents_clean": latents_clean,
+            "indices_latents_clean": indices_latents_clean,
+            "latents_history_2x": latents_history_2x,
+            "indices_latents_history_2x": indices_latents_history_2x,
+            "latents_history_4x": latents_history_4x,
+            "indices_latents_history_4x": indices_latents_history_4x,
+        }
+
+    @property
+    def input_shape(self):
+        return (4, 3, 4, 4)
+
+    @property
+    def output_shape(self):
+        return (4, 3, 4, 4)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "in_channels": 4,
+            "out_channels": 4,
+            "num_attention_heads": 2,
+            "attention_head_dim": 10,
+            "num_layers": 1,
+            "num_single_layers": 1,
+            "num_refiner_layers": 1,
+            "patch_size": 2,
+            "patch_size_t": 1,
+            "guidance_embeds": True,
+            "text_embed_dim": 16,
+            "pooled_projection_dim": 8,
+            "rope_axes_dim": (2, 4, 4),
+            "image_condition_type": None,
+            "has_image_proj": True,
+            "image_proj_dim": 16,
+            "has_clean_x_embedder": True,
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"HunyuanVideoFramepackTransformer3DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
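
As a quick sanity check, the tiny configuration exercised by prepare_init_args_and_inputs_for_common can be instantiated directly (assuming a diffusers build that ships this model class):

from diffusers import HunyuanVideoFramepackTransformer3DModel

model = HunyuanVideoFramepackTransformer3DModel(
    in_channels=4,
    out_channels=4,
    num_attention_heads=2,
    attention_head_dim=10,
    num_layers=1,
    num_single_layers=1,
    num_refiner_layers=1,
    patch_size=2,
    patch_size_t=1,
    guidance_embeds=True,
    text_embed_dim=16,
    pooled_projection_dim=8,
    rope_axes_dim=(2, 4, 4),
    image_condition_type=None,
    has_image_proj=True,
    image_proj_dim=16,
    has_clean_x_embedder=True,
)
print(sum(p.numel() for p in model.parameters()))  # parameter count of the toy config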

tests/pipelines/consisid/test_consisid.py

Lines changed: 8 additions & 7 deletions
@@ -24,9 +24,10 @@
 from diffusers import AutoencoderKLCogVideoX, ConsisIDPipeline, ConsisIDTransformer3DModel, DDIMScheduler
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )

@@ -316,19 +317,19 @@ def test_vae_tiling(self, expected_diff_max: float = 0.4):


 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class ConsisIDPipelineIntegrationTests(unittest.TestCase):
     prompt = "A painting of a squirrel eating a burger."

     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_consisid(self):
         generator = torch.Generator("cpu").manual_seed(0)

@@ -338,8 +339,8 @@ def test_consisid(self):

         prompt = self.prompt
         image = load_image("https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/2.png?raw=true")
-        id_vit_hidden = [torch.ones([1, 2, 2])] * 1
-        id_cond = torch.ones(1, 2)
+        id_vit_hidden = [torch.ones([1, 577, 1024])] * 5
+        id_cond = torch.ones(1, 1280)

         videos = pipe(
             image=image,

@@ -357,5 +358,5 @@ def test_consisid(self):
         video = videos[0]
         expected_video = torch.randn(1, 16, 480, 720, 3).numpy()

-        max_diff = numpy_cosine_similarity_distance(video, expected_video)
+        max_diff = numpy_cosine_similarity_distance(video.cpu(), expected_video)
         assert max_diff < 1e-3, f"Max diff is too high. got {video}"
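
The recurring swap of torch.cuda.empty_cache() for backend_empty_cache(torch_device) is what lets these integration suites run on non-CUDA accelerators such as XPU, the target of this merge. A simplified sketch of the dispatch idea; the real testing_utils helper covers more backends:

import torch

def backend_empty_cache_sketch(device: str) -> None:
    # simplified dispatch; diffusers' helper also handles other accelerators
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "xpu":
        torch.xpu.empty_cache()  # requires a torch build with XPU support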

tests/pipelines/dit/test_dit.py

Lines changed: 18 additions & 7 deletions
@@ -21,7 +21,15 @@

 from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DiTTransformer2DModel, DPMSolverMultistepScheduler
 from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import enable_full_determinism, load_numpy, nightly, require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    enable_full_determinism,
+    load_numpy,
+    nightly,
+    numpy_cosine_similarity_distance,
+    require_torch_accelerator,
+    torch_device,
+)

 from ..pipeline_params import (
     CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS,

@@ -107,23 +115,23 @@ def test_xformers_attention_forwardGenerator_pass(self):


 @nightly
-@require_torch_gpu
+@require_torch_accelerator
 class DiTPipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_dit_256(self):
         generator = torch.manual_seed(0)

         pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256")
-        pipe.to("cuda")
+        pipe.to(torch_device)

         words = ["vase", "umbrella", "white shark", "white wolf"]
         ids = pipe.get_label_ids(words)

@@ -139,7 +147,7 @@ def test_dit_256(self):
     def test_dit_512(self):
         pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-512")
         pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
-        pipe.to("cuda")
+        pipe.to(torch_device)

         words = ["vase", "umbrella"]
         ids = pipe.get_label_ids(words)

@@ -152,4 +160,7 @@ def test_dit_512(self):
             f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{word}_512.npy"
         )

-        assert np.abs((expected_image - image).max()) < 1e-1
+        expected_slice = expected_image.flatten()
+        output_slice = image.flatten()
+
+        assert numpy_cosine_similarity_distance(expected_slice, output_slice) < 1e-2
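
The rewritten assertion compares the direction of the flattened outputs rather than the worst-case pixel difference, which is less brittle across devices. In sketch form, a cosine-distance check computes the following (the actual testing_utils helper may differ in detail):

import numpy as np

def cosine_distance_sketch(a: np.ndarray, b: np.ndarray) -> float:
    a = a.flatten().astype(np.float64)
    b = b.flatten().astype(np.float64)
    # 1 - cosine similarity: 0.0 for identical direction, up to 2.0 for opposite
    return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))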

tests/pipelines/easyanimate/test_easyanimate.py

Lines changed: 5 additions & 4 deletions
@@ -27,9 +27,10 @@
     FlowMatchEulerDiscreteScheduler,
 )
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )

@@ -256,19 +257,19 @@ def test_encode_prompt_works_in_isolation(self):


 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class EasyAnimatePipelineIntegrationTests(unittest.TestCase):
     prompt = "A painting of a squirrel eating a burger."

     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_EasyAnimate(self):
         generator = torch.Generator("cpu").manual_seed(0)

tests/pipelines/mochi/test_mochi.py

Lines changed: 6 additions & 6 deletions
@@ -27,8 +27,8 @@
     enable_full_determinism,
     nightly,
     numpy_cosine_similarity_distance,
-    require_big_gpu_with_torch_cuda,
-    require_torch_gpu,
+    require_big_accelerator,
+    require_torch_accelerator,
     torch_device,
 )

@@ -266,9 +266,9 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):


 @nightly
-@require_torch_gpu
-@require_big_gpu_with_torch_cuda
-@pytest.mark.big_gpu_with_torch_cuda
+@require_torch_accelerator
+@require_big_accelerator
+@pytest.mark.big_accelerator
 class MochiPipelineIntegrationTests(unittest.TestCase):
     prompt = "A painting of a squirrel eating a burger."

@@ -302,5 +302,5 @@ def test_mochi(self):
         video = videos[0]
         expected_video = torch.randn(1, 19, 480, 848, 3).numpy()

-        max_diff = numpy_cosine_similarity_distance(video, expected_video)
+        max_diff = numpy_cosine_similarity_distance(video.cpu(), expected_video)
         assert max_diff < 1e-3, f"Max diff is too high. got {video}"
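
The video.cpu() added in the last hunk (mirrored in the ConsisID test above) is needed because NumPy-based comparisons require host tensors; .numpy() raises on an accelerator tensor. In miniature:

import torch

t = torch.randn(3)  # imagine this living on an accelerator device
arr = t.cpu().numpy()  # .numpy() only works on CPU tensors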
