diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index f95c4cf37475..e19911e44545 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -453,14 +453,14 @@ def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu",

     def forward(self, x, feat_cache=None, feat_idx=[0]):
         # First residual block
-        x = self.resnets[0](x, feat_cache, feat_idx)
+        x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx)

         # Process through attention and residual blocks
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if attn is not None:
                 x = attn(x)

-            x = resnet(x, feat_cache, feat_idx)
+            x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)

         return x

@@ -494,9 +494,9 @@ def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample
     def forward(self, x, feat_cache=None, feat_idx=[0]):
         x_copy = x.clone()
         for resnet in self.resnets:
-            x = resnet(x, feat_cache, feat_idx)
+            x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
         if self.downsampler is not None:
-            x = self.downsampler(x, feat_cache, feat_idx)
+            x = self.downsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)

         return x + self.avg_shortcut(x_copy)

@@ -598,12 +598,12 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
         ## downsamples
         for layer in self.down_blocks:
             if feat_cache is not None:
-                x = layer(x, feat_cache, feat_idx)
+                x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = layer(x)

         ## middle
-        x = self.mid_block(x, feat_cache, feat_idx)
+        x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)

         ## head
         x = self.norm_out(x)
@@ -694,13 +694,13 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):

         for resnet in self.resnets:
             if feat_cache is not None:
-                x = resnet(x, feat_cache, feat_idx)
+                x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = resnet(x)

         if self.upsampler is not None:
             if feat_cache is not None:
-                x = self.upsampler(x, feat_cache, feat_idx)
+                x = self.upsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = self.upsampler(x)

@@ -767,13 +767,13 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
         """
         for resnet in self.resnets:
             if feat_cache is not None:
-                x = resnet(x, feat_cache, feat_idx)
+                x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = resnet(x)

         if self.upsamplers is not None:
             if feat_cache is not None:
-                x = self.upsamplers[0](x, feat_cache, feat_idx)
+                x = self.upsamplers[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = self.upsamplers[0](x)
         return x
@@ -885,11 +885,11 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
             x = self.conv_in(x)

         ## middle
-        x = self.mid_block(x, feat_cache, feat_idx)
+        x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)

         ## upsamples
         for up_block in self.up_blocks:
-            x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
+            x = up_block(x, feat_cache=feat_cache, feat_idx=feat_idx, first_chunk=first_chunk)

         ## head
         x = self.norm_out(x)
@@ -961,6 +961,9 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     """

     _supports_gradient_checkpointing = False
+    # keys to ignore when AlignDevicesHook moves inputs/outputs between devices;
+    # these are shared mutable state modified in-place
+    _skip_keys = ["feat_cache", "feat_idx"]

     @register_to_config
     def __init__(
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index ce26785f63ea..91daca1ad809 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -251,6 +251,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _repeated_blocks = []
     _parallel_config = None
     _cp_plan = None
+    _skip_keys = None

     def __init__(self):
         super().__init__()
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index dd542145d3fa..2169700ceae0 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -866,6 +866,9 @@ def load_sub_model(
         # remove hooks
         remove_hook_from_module(loaded_sub_model, recurse=True)
         needs_offloading_to_cpu = device_map[""] == "cpu"
+        skip_keys = None
+        if hasattr(loaded_sub_model, "_skip_keys") and loaded_sub_model._skip_keys is not None:
+            skip_keys = loaded_sub_model._skip_keys

         if needs_offloading_to_cpu:
             dispatch_model(
@@ -874,9 +877,10 @@ def load_sub_model(
                 device_map=device_map,
                 force_hooks=True,
                 main_device=0,
+                skip_keys=skip_keys,
             )
         else:
-            dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
+            dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True, skip_keys=skip_keys)

     return loaded_sub_model
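
For context, here is a minimal sketch of what `skip_keys` buys us. Accelerate's `AlignDevicesHook` rebuilds container arguments when moving them to the execution device, which silently drops in-place updates to `feat_cache`/`feat_idx`; kwargs named in `skip_keys` are passed through as the same Python objects. The `ToyBlock` module below is hypothetical and only stands in for the Wan VAE blocks; it is not part of this diff.

```python
# Hedged sketch, not part of the PR: ToyBlock mutates feat_idx in-place the way the
# real Wan feature cache does, and skip_keys keeps that mutation visible to the caller.
import torch
import torch.nn as nn
from accelerate import dispatch_model


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        feat_idx[0] += 1  # in-place update the caller is expected to observe
        return self.proj(x)


device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = dispatch_model(
    ToyBlock(),
    device_map={"": device},
    force_hooks=True,
    skip_keys=["feat_cache", "feat_idx"],  # AlignDevicesHook leaves these kwargs untouched
)

feat_idx = [0]
model(torch.randn(2, 4), feat_cache=[None], feat_idx=feat_idx)
print(feat_idx[0])  # 1 — the original list object reached forward(), so the increment survives
```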
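
And a usage sketch of the path this wires up (checkpoint name, dtype, and generation settings below are assumptions, not taken from the diff): any Wan checkpoint loaded with a pipeline-level `device_map` goes through `load_sub_model`, so the VAE decoder's `feat_cache`/`feat_idx` now survive the hooks that dispatching installs.

```python
# Hedged example: loading a Wan pipeline with an accelerate device map.
import torch
from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    torch_dtype=torch.bfloat16,
    device_map="balanced",  # each sub-model is dispatched with AlignDevicesHooks, including the VAE
)

# The VAE's forward receives feat_cache/feat_idx as the original lists because they are
# listed in AutoencoderKLWan._skip_keys, so chunk-by-chunk cache updates are preserved.
frames = pipe("a corgi running on the beach", num_frames=33, num_inference_steps=20).frames[0]
```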