
Commit 4906b00

Merge branch 'huggingface:main' into create-gh-action-for-copyright
2 parents 003b05c + 4c723d8

File tree

5 files changed: +37 -20 lines changed

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 15 additions & 12 deletions
@@ -453,14 +453,14 @@ def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu",
 
     def forward(self, x, feat_cache=None, feat_idx=[0]):
         # First residual block
-        x = self.resnets[0](x, feat_cache, feat_idx)
+        x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
 
         # Process through attention and residual blocks
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if attn is not None:
                 x = attn(x)
 
-            x = resnet(x, feat_cache, feat_idx)
+            x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
 
         return x
 
@@ -494,9 +494,9 @@ def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample
     def forward(self, x, feat_cache=None, feat_idx=[0]):
         x_copy = x.clone()
         for resnet in self.resnets:
-            x = resnet(x, feat_cache, feat_idx)
+            x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
         if self.downsampler is not None:
-            x = self.downsampler(x, feat_cache, feat_idx)
+            x = self.downsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
 
         return x + self.avg_shortcut(x_copy)
 
@@ -598,12 +598,12 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
         ## downsamples
         for layer in self.down_blocks:
             if feat_cache is not None:
-                x = layer(x, feat_cache, feat_idx)
+                x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = layer(x)
 
         ## middle
-        x = self.mid_block(x, feat_cache, feat_idx)
+        x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
 
         ## head
         x = self.norm_out(x)
@@ -694,13 +694,13 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
 
         for resnet in self.resnets:
             if feat_cache is not None:
-                x = resnet(x, feat_cache, feat_idx)
+                x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = resnet(x)
 
         if self.upsampler is not None:
             if feat_cache is not None:
-                x = self.upsampler(x, feat_cache, feat_idx)
+                x = self.upsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = self.upsampler(x)
 
@@ -767,13 +767,13 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
         """
         for resnet in self.resnets:
             if feat_cache is not None:
-                x = resnet(x, feat_cache, feat_idx)
+                x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = resnet(x)
 
         if self.upsamplers is not None:
             if feat_cache is not None:
-                x = self.upsamplers[0](x, feat_cache, feat_idx)
+                x = self.upsamplers[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = self.upsamplers[0](x)
         return x
@@ -885,11 +885,11 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
         x = self.conv_in(x)
 
         ## middle
-        x = self.mid_block(x, feat_cache, feat_idx)
+        x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
 
         ## upsamples
         for up_block in self.up_blocks:
-            x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
+            x = up_block(x, feat_cache=feat_cache, feat_idx=feat_idx, first_chunk=first_chunk)
 
         ## head
         x = self.norm_out(x)
@@ -961,6 +961,9 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
     """
 
     _supports_gradient_checkpointing = False
+    # keys to ignore when AlignDeviceHook moves inputs/outputs between devices
+    # these are shared mutable state modified in-place
+    _skip_keys = ["feat_cache", "feat_idx"]
 
     @register_to_config
     def __init__(
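For context on why these arguments are passed as keywords and why the class lists them in _skip_keys: feat_idx is a one-element list used as a mutable cursor into feat_cache, advanced in place by every block, so the whole stack shares a single counter object. A minimal sketch of that pattern (fake_block and its cache logic are illustrative stand-ins, not the real layers):

import torch

def fake_block(x, feat_cache=None, feat_idx=[0]):
    # feat_idx is a shared one-element list: reading feat_idx[0] and writing
    # it back mutates the caller's object, so the cursor survives across calls
    if feat_cache is not None:
        idx = feat_idx[0]
        feat_cache[idx] = x.detach()  # stash this block's features
        feat_idx[0] += 1              # advance the shared cursor in place
    return x * 2.0

x = torch.ones(1, 4)
cache, idx = {}, [0]
for _ in range(3):
    x = fake_block(x, feat_cache=cache, feat_idx=idx)
print(idx[0], sorted(cache))  # 3 [0, 1, 2]

If a device-placement hook copied feat_cache or feat_idx to another device, the copies would be advanced instead of the originals; skipping these keys keeps the in-place mutation visible to the caller.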

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 0 deletions
@@ -251,6 +251,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _repeated_blocks = []
     _parallel_config = None
     _cp_plan = None
+    _skip_keys = None
 
     def __init__(self):
         super().__init__()
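The base-class default above follows the usual class-attribute override pattern: ModelMixin declares _skip_keys = None, and models that need it (such as AutoencoderKLWan in this commit) shadow it at class level. A toy illustration with stand-in classes, not diffusers code:

class BaseModel:
    _skip_keys = None  # default: nothing is skipped during device dispatch

class WanVAE(BaseModel):
    _skip_keys = ["feat_cache", "feat_idx"]  # opt in per model

print(BaseModel._skip_keys)  # None
print(WanVAE._skip_keys)     # ['feat_cache', 'feat_idx']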

src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Lines changed: 9 additions & 7 deletions
@@ -744,11 +744,13 @@ def __call__(
         )
 
         if negative_prompt_embeds_qwen is None:
-            negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_cu_seqlens = self.encode_prompt(
-                prompt=negative_prompt,
-                max_sequence_length=max_sequence_length,
-                device=device,
-                dtype=dtype,
+            negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = (
+                self.encode_prompt(
+                    prompt=negative_prompt,
+                    max_sequence_length=max_sequence_length,
+                    device=device,
+                    dtype=dtype,
+                )
             )
 
         # 4. Prepare timesteps
@@ -780,8 +782,8 @@ def __call__(
         text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device)
 
         negative_text_rope_pos = (
-            torch.arange(negative_cu_seqlens.diff().max().item(), device=device)
-            if negative_cu_seqlens is not None
+            torch.arange(negative_prompt_cu_seqlens.diff().max().item(), device=device)
+            if negative_prompt_cu_seqlens is not None
             else None
         )
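The rename above (negative_cu_seqlens to negative_prompt_cu_seqlens) matches the positive-prompt naming, and both feed the same cu_seqlens convention: cumulative sequence lengths over a packed batch, where .diff() recovers each prompt's length and .max() sizes the RoPE position range. A small sketch with made-up values:

import torch

cu_seqlens = torch.tensor([0, 7, 12, 30])  # 3 prompts packed end to end
lengths = cu_seqlens.diff()                # tensor([ 7,  5, 18])
rope_pos = torch.arange(lengths.max().item())
print(lengths.tolist(), rope_pos.shape)    # [7, 5, 18] torch.Size([18])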

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 5 additions & 1 deletion
@@ -866,6 +866,9 @@ def load_sub_model(
     # remove hooks
     remove_hook_from_module(loaded_sub_model, recurse=True)
     needs_offloading_to_cpu = device_map[""] == "cpu"
+    skip_keys = None
+    if hasattr(loaded_sub_model, "_skip_keys") and loaded_sub_model._skip_keys is not None:
+        skip_keys = loaded_sub_model._skip_keys
 
     if needs_offloading_to_cpu:
         dispatch_model(
@@ -874,9 +877,10 @@ def load_sub_model(
             device_map=device_map,
             force_hooks=True,
             main_device=0,
+            skip_keys=skip_keys,
         )
     else:
-        dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
+        dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True, skip_keys=skip_keys)
 
     return loaded_sub_model
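skip_keys is forwarded to accelerate's dispatch_model, whose AlignDevicesHook otherwise moves every argument to the module's execution device; names listed in skip_keys are passed through untouched, which preserves in-place mutation of shared state such as feat_cache/feat_idx. A minimal sketch of the call (check your accelerate version for the exact signature; the device_map here is a deliberately trivial single-device example):

import torch
from accelerate import dispatch_model

model = torch.nn.Linear(4, 4)
model = dispatch_model(
    model,
    device_map={"": "cpu"},                # trivial map: whole model on CPU
    force_hooks=True,                      # attach hooks even for one device
    skip_keys=["feat_cache", "feat_idx"],  # these kwargs are left alone
)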

tests/pipelines/stable_cascade/test_stable_cascade_prior.py

Lines changed: 7 additions & 0 deletions
@@ -17,11 +17,13 @@
 import unittest
 
 import numpy as np
+import pytest
 import torch
 from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
 
 from diffusers import DDPMWuerstchenScheduler, StableCascadePriorPipeline
 from diffusers.models import StableCascadeUNet
+from diffusers.utils import is_transformers_version
 from diffusers.utils.import_utils import is_peft_available
 
 from ...testing_utils import (
@@ -154,6 +156,11 @@ def get_dummy_inputs(self, device, seed=0):
         }
         return inputs
 
+    @pytest.mark.xfail(
+        condition=is_transformers_version(">=", "4.57.1"),
+        reason="Test fails with the latest transformers version",
+        strict=False,
+    )
     def test_wuerstchen_prior(self):
         device = "cpu"
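The marker added above is the standard conditional-xfail pattern: the test still runs, but when the condition holds a failure is recorded as XFAIL, and strict=False lets an unexpected pass through as XPASS instead of an error. A toy example of the same pattern (the flag is a stand-in for the is_transformers_version check):

import pytest

DEP_TOO_NEW = True  # stand-in for is_transformers_version(">=", "4.57.1")

@pytest.mark.xfail(
    condition=DEP_TOO_NEW,
    reason="known breakage on newer dependency versions",
    strict=False,
)
def test_example():
    # fails when the dependency is too new, so pytest records XFAIL
    assert not DEP_TOO_NEW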
