
Commit 6f29e2a

add standard autoencoder methods
1 parent 31f9fc6 commit 6f29e2a

File tree

3 files changed: +157 −9 lines changed


src/diffusers/loaders/single_file_utils.py

Lines changed: 0 additions & 1 deletion
@@ -2208,7 +2208,6 @@ def swap_scale_shift(weight):
 
 def create_autoencoder_dc_config_from_original(original_config, checkpoint, **kwargs):
     model_name = original_config.get("model_name", "dc-ae-f32c32-sana-1.0")
-    print("trying:", model_name)
 
     if model_name in ["dc-ae-f32c32-sana-1.0"]:
         config = {
src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 144 additions & 8 deletions
@@ -13,19 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin
+from ...utils.accelerate_utils import apply_forward_hook
 from ..activations import get_activation
 from ..attention_processor import SanaMultiscaleLinearAttention
 from ..modeling_utils import ModelMixin
 from ..normalization import RMSNorm, get_normalization
-from .vae import DecoderOutput
+from .vae import DecoderOutput, EncoderOutput
 
 
 class GLUMBConv(nn.Module):

@@ -484,13 +485,148 @@ def __init__(
         self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
         self.temporal_compression_ratio = 1
 
-    def encode(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.encoder(x)
-        return x
+        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+        # to perform decoding of a single video latent at a time.
+        self.use_slicing = False
 
-    def decode(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.decoder(x)
-        return x
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
+        self.use_tiling = False
+
+        # The minimal tile height and width for spatial tiling to be used
+        self.tile_sample_min_height = 512
+        self.tile_sample_min_width = 512
+
+        # The minimal distance between two spatial tiles
+        self.tile_sample_stride_height = 448
+        self.tile_sample_stride_width = 448
+
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_sample_stride_height: Optional[float] = None,
+        tile_sample_stride_width: Optional[float] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+                no tiling artifacts produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                artifacts produced across the width dimension.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
+    def enable_slicing(self) -> None:
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self) -> None:
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x, return_dict=False)[0]
+
+        encoded = self.encoder(x)
+
+        return encoded
+
+    @apply_forward_hook
+    def encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
+        r"""
+        Encode a batch of images into latents.
+
+        Args:
+            x (`torch.Tensor`): Input batch of images.
+            return_dict (`bool`, defaults to `True`):
+                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+        Returns:
+            The latent representations of the encoded videos. If `return_dict` is True, a
+            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+        """
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            encoded = torch.cat(encoded_slices)
+        else:
+            encoded = self._encode(x)
+
+        if not return_dict:
+            return (encoded,)
+        return EncoderOutput(latent=encoded)
+
+    def _decode(self, z: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, height, width = z.shape
+
+        if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
+            return self.tiled_decode(z, return_dict=False)[0]
+
+        decoded = self.decoder(z)
+
+        return decoded
+
+    @apply_forward_hook
+    def decode(self, z: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
+        r"""
+        Decode a batch of images.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, defaults to `True`):
+                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        if self.use_slicing and z.size(0) > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z)
+
+        if not return_dict:
+            return (decoded,)
+        return DecoderOutput(sample=decoded)
+
+    def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
+        raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.")
+
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.")
 
     def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
         z = self.encode(sample)
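
For orientation, a minimal usage sketch of the new slicing/tiling toggles and the `encode`/`decode` entry points added in this file. It is not part of the commit: the default `AutoencoderDC()` construction, input shape, and batch size are illustrative assumptions, and in practice the model would be loaded from a converted checkpoint.

import torch

from diffusers.models.autoencoders.autoencoder_dc import AutoencoderDC

# Assumption: the class's default config builds a usable (untrained) model.
vae = AutoencoderDC().eval()

# Encode/decode one sample at a time to reduce peak memory.
vae.enable_slicing()

# Spatial tiling can be toggled, but tiled_encode/tiled_decode still raise
# NotImplementedError in this commit, so it is left disabled here.
# vae.enable_tiling(tile_sample_min_height=512, tile_sample_min_width=512)

x = torch.randn(2, 3, 512, 512)  # (batch, channels, height, width)

with torch.no_grad():
    latent = vae.encode(x).latent  # EncoderOutput.latent
    # The sliced decode path calls self._decode(z_slice).sample on a raw tensor,
    # so decode a single-sample batch while slicing is enabled.
    image = vae.decode(latent[:1]).sample  # DecoderOutput.sample

# return_dict=False returns plain tuples instead of output dataclasses.
(latent,) = vae.encode(x, return_dict=False)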

src/diffusers/models/autoencoders/vae.py

Lines changed: 13 additions & 0 deletions
@@ -30,6 +30,19 @@
 )
 
 
+@dataclass
+class EncoderOutput(BaseOutput):
+    r"""
+    Output of encoding method.
+
+    Args:
+        latent (`torch.Tensor` of shape `(batch_size, num_channels, latent_height, latent_width)`):
+            The encoded latent.
+    """
+
+    latent: torch.Tensor
+
+
 @dataclass
 class DecoderOutput(BaseOutput):
     r"""
