|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | | -from typing import Tuple, Union |
| 16 | +from typing import Optional, Tuple, Union |
17 | 17 |
|
18 | 18 | import torch |
19 | 19 | import torch.nn as nn |
20 | 20 | import torch.nn.functional as F |
21 | 21 |
|
22 | 22 | from ...configuration_utils import ConfigMixin, register_to_config |
23 | 23 | from ...loaders import FromOriginalModelMixin |
| 24 | +from ...utils.accelerate_utils import apply_forward_hook |
24 | 25 | from ..activations import get_activation |
25 | 26 | from ..attention_processor import SanaMultiscaleLinearAttention |
26 | 27 | from ..modeling_utils import ModelMixin |
27 | 28 | from ..normalization import RMSNorm, get_normalization |
28 | | -from .vae import DecoderOutput |
| 29 | +from .vae import DecoderOutput, EncoderOutput |
29 | 30 |
|
30 | 31 |
|
31 | 32 | class GLUMBConv(nn.Module): |
@@ -484,13 +485,148 @@ def __init__( |
484 | 485 | self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1) |
485 | 486 | self.temporal_compression_ratio = 1 |
486 | 487 |
|
487 | | - def encode(self, x: torch.Tensor) -> torch.Tensor: |
488 | | - x = self.encoder(x) |
489 | | - return x |
 | 488 | + # When decoding a batch of latents, one can save memory by slicing across the batch dimension |
 | 489 | + # to decode a single latent at a time. |
| 490 | + self.use_slicing = False |
490 | 491 |
|
491 | | - def decode(self, x: torch.Tensor) -> torch.Tensor: |
492 | | - x = self.decoder(x) |
493 | | - return x |
 | 492 | + # When decoding spatially large latents, the memory requirement is very high. By breaking the latent |
 | 493 | + # spatially into smaller tiles, performing multiple forward passes for decoding, and then blending the |
 | 494 | + # intermediate tiles together, the memory requirement can be lowered. |
| 495 | + self.use_tiling = False |
| 496 | + |
 | 497 | + # The minimum sample height and width required for spatial tiling to be used |
| 498 | + self.tile_sample_min_height = 512 |
| 499 | + self.tile_sample_min_width = 512 |
| 500 | + |
 | 501 | + # The stride between the starts of two consecutive spatial tiles |
| 502 | + self.tile_sample_stride_height = 448 |
| 503 | + self.tile_sample_stride_width = 448 |
| 504 | + |
| 505 | + def enable_tiling( |
| 506 | + self, |
| 507 | + tile_sample_min_height: Optional[int] = None, |
| 508 | + tile_sample_min_width: Optional[int] = None, |
 | 509 | + tile_sample_stride_height: Optional[int] = None, |
 | 510 | + tile_sample_stride_width: Optional[int] = None, |
| 511 | + ) -> None: |
| 512 | + r""" |
 | 513 | + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to |
 | 514 | + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and for |
 | 515 | + processing larger images. |
| 516 | +
|
| 517 | + Args: |
| 518 | + tile_sample_min_height (`int`, *optional*): |
| 519 | + The minimum height required for a sample to be separated into tiles across the height dimension. |
| 520 | + tile_sample_min_width (`int`, *optional*): |
| 521 | + The minimum width required for a sample to be separated into tiles across the width dimension. |
| 522 | + tile_sample_stride_height (`int`, *optional*): |
 | 523 | + The stride between two consecutive vertical tiles. This is to ensure that there are no tiling |
 | 524 | + artifacts produced across the height dimension. |
| 525 | + tile_sample_stride_width (`int`, *optional*): |
| 526 | + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling |
| 527 | + artifacts produced across the width dimension. |
| 528 | + """ |
| 529 | + self.use_tiling = True |
| 530 | + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height |
| 531 | + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width |
| 532 | + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height |
| 533 | + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width |
| 534 | + |
| 535 | + def disable_tiling(self) -> None: |
| 536 | + r""" |
| 537 | + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing |
| 538 | + decoding in one step. |
| 539 | + """ |
| 540 | + self.use_tiling = False |
| 541 | + |
| 542 | + def enable_slicing(self) -> None: |
| 543 | + r""" |
 | 544 | + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor into slices to |
| 545 | + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. |
| 546 | + """ |
| 547 | + self.use_slicing = True |
| 548 | + |
| 549 | + def disable_slicing(self) -> None: |
| 550 | + r""" |
| 551 | + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing |
| 552 | + decoding in one step. |
| 553 | + """ |
| 554 | + self.use_slicing = False |
| 555 | + |
| 556 | + def _encode(self, x: torch.Tensor) -> torch.Tensor: |
| 557 | + batch_size, num_channels, height, width = x.shape |
| 558 | + |
| 559 | + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): |
| 560 | + return self.tiled_encode(x, return_dict=False)[0] |
| 561 | + |
| 562 | + encoded = self.encoder(x) |
| 563 | + |
| 564 | + return encoded |
| 565 | + |
| 566 | + @apply_forward_hook |
 | 567 | + def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[EncoderOutput, Tuple[torch.Tensor]]: |
| 568 | + r""" |
| 569 | + Encode a batch of images into latents. |
| 570 | +
|
| 571 | + Args: |
| 572 | + x (`torch.Tensor`): Input batch of images. |
| 573 | + return_dict (`bool`, defaults to `True`): |
 | 574 | + Whether to return a [`~models.vae.EncoderOutput`] instead of a plain tuple. |
| 575 | +
|
| 576 | + Returns: |
 | 577 | + The latent representations of the encoded images. If `return_dict` is True, an |
 | 578 | + [`~models.vae.EncoderOutput`] is returned, otherwise a plain `tuple` is returned. |
| 579 | + """ |
| 580 | + if self.use_slicing and x.shape[0] > 1: |
| 581 | + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] |
| 582 | + encoded = torch.cat(encoded_slices) |
| 583 | + else: |
| 584 | + encoded = self._encode(x) |
| 585 | + |
| 586 | + if not return_dict: |
| 587 | + return (encoded,) |
| 588 | + return EncoderOutput(latent=encoded) |
| 589 | + |
| 590 | + def _decode(self, z: torch.Tensor) -> torch.Tensor: |
| 591 | + batch_size, num_channels, height, width = z.shape |
 | 592 | + # `z` is in latent space; `tile_sample_min_*` are defined in sample space, so rescale before comparing. |
 | 593 | + if self.use_tiling and (width * self.spatial_compression_ratio > self.tile_sample_min_width or height * self.spatial_compression_ratio > self.tile_sample_min_height): |
| 594 | + return self.tiled_decode(z, return_dict=False)[0] |
| 595 | + |
| 596 | + decoded = self.decoder(z) |
| 597 | + |
| 598 | + return decoded |
| 599 | + |
| 600 | + @apply_forward_hook |
 | 601 | + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]: |
| 602 | + r""" |
 | 603 | + Decode a batch of latent representations into images. |
| 604 | +
|
| 605 | + Args: |
| 606 | + z (`torch.Tensor`): Input batch of latent vectors. |
| 607 | + return_dict (`bool`, defaults to `True`): |
| 608 | + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. |
| 609 | +
|
| 610 | + Returns: |
| 611 | + [`~models.vae.DecoderOutput`] or `tuple`: |
| 612 | + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is |
| 613 | + returned. |
| 614 | + """ |
| 615 | + if self.use_slicing and z.size(0) > 1: |
 | 616 | + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] |
| 617 | + decoded = torch.cat(decoded_slices) |
| 618 | + else: |
| 619 | + decoded = self._decode(z) |
| 620 | + |
| 621 | + if not return_dict: |
| 622 | + return (decoded,) |
| 623 | + return DecoderOutput(sample=decoded) |
| 624 | + |
 | 625 | + def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[EncoderOutput, torch.Tensor]: |
| 626 | + raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.") |
| 627 | + |
| 628 | + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: |
| 629 | + raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.") |
494 | 630 |
|
495 | 631 | def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor: |
496 | 632 | z = self.encode(sample) |
|
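Below is a minimal usage sketch of the new slicing/tiling switches (not part of the diff). The checkpoint id is a placeholder, and because `tiled_encode`/`tiled_decode` still raise `NotImplementedError`, tiling is only toggled here without feeding inputs large enough to trigger it.

```python
import torch

from diffusers import AutoencoderDC

# Placeholder checkpoint id -- substitute any checkpoint compatible with AutoencoderDC.
vae = AutoencoderDC.from_pretrained("org/dc-ae-checkpoint-placeholder")
vae.enable_slicing()  # encode/decode one batch element at a time to reduce peak memory

images = torch.randn(2, 3, 256, 256)  # dummy batch, small enough that tiling never triggers

with torch.no_grad():
    latents = vae.encode(images).latent  # EncoderOutput.latent
    recon = vae.decode(latents).sample   # DecoderOutput.sample

# Tiling can be switched on and off the same way; it only changes behavior once
# tiled_encode/tiled_decode are implemented and the input exceeds tile_sample_min_*.
vae.enable_tiling(tile_sample_min_height=512, tile_sample_min_width=512)
vae.disable_tiling()
```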
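And a quick sanity check of the default tile geometry set in `__init__` (illustration only; the actual tile splitting and blending lives in the yet-to-be-implemented `tiled_*` methods):

```python
# Defaults from __init__: tiles are 512 px high/wide in sample space and consecutive tiles
# start 448 px apart, so adjacent tiles overlap by 64 px -- the region a tiled decoder would blend.
tile_sample_min = 512
tile_sample_stride = 448
overlap = tile_sample_min - tile_sample_stride
assert overlap == 64

# Tiles needed along one spatial dimension of a 2048-px sample:
sample_size = 2048
num_tiles = 1 + max(0, -(-(sample_size - tile_sample_min) // tile_sample_stride))
assert num_tiles == 5  # tiles start at 0, 448, 896, 1344, 1792
```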