Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 2 additions & 23 deletions src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -410,7 +410,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return h


class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKLHunyuanImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model for 2D images with spatial tiling support.

Expand Down Expand Up @@ -486,27 +486,6 @@ def enable_tiling(
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio

def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False

def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False

def _encode(self, x: torch.Tensor):

batch_size, num_channels, height, width = x.shape
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -584,7 +584,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
class AutoencoderKLHunyuanImageRefiner(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
HunyuanImage-2.1 Refiner.
Expand Down Expand Up @@ -685,27 +685,6 @@ def enable_tiling(
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor

def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False

def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False

def _encode(self, x: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = x.shape

Expand Down
25 changes: 2 additions & 23 deletions src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -625,7 +625,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin):
class AutoencoderKLHunyuanVideo15(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
HunyuanVideo-1.5.
Expand Down Expand Up @@ -723,27 +723,6 @@ def enable_tiling(
self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor

def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False

def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False

def _encode(self, x: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = x.shape

Expand Down
Loading