Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions src/diffusers/models/autoencoders/autoencoder_kl_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,13 @@ def __init__(
self.gradient_checkpointing = False

def forward(self, x, feat_cache=None, feat_idx=[0]):
if torch.is_grad_enabled() and self.gradient_checkpointing:
return self._gradient_checkpointing_func(self._decode, x, feat_cache, feat_idx)
else:
return self._decode(x, feat_cache, feat_idx)

def _decode(self, x, in_cache=None, feat_idx=[0]):
feat_cache = in_cache.copy()
## conv1
if feat_cache is not None:
idx = feat_idx[0]
Expand Down Expand Up @@ -653,7 +660,8 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
feat_idx[0] += 1
else:
x = self.conv_out(x)
return x
feat_idx[0] = 0
return x, feat_cache


class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
Expand All @@ -665,7 +673,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for all models (such as downloading or saving).
"""

_supports_gradient_checkpointing = False
_supports_gradient_checkpointing = True

@register_to_config
def __init__(
Expand Down Expand Up @@ -799,9 +807,13 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out, self._feat_map = self.decoder(
x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx
)
else:
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out_, self._feat_map = self.decoder(
x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx
)
out = torch.cat([out, out_], 2)

out = torch.clamp(out, min=-1.0, max=1.0)
Expand Down
9 changes: 7 additions & 2 deletions tests/models/autoencoders/test_models_autoencoder_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,14 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

@unittest.skip("Gradient checkpointing has not been implemented yet")
def test_gradient_checkpointing_is_applied(self):
pass
expected_set = {
"WanDecoder3d",
"WanEncoder3d",
"WanMidBlock",
"WanUpBlock",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@unittest.skip("Test not supported")
def test_forward_with_norm_groups(self):
Expand Down